python 重现原始deeplabv2结果(76.3)

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了python 重现原始deeplabv2结果(76.3)相关的知识,希望对你有一定的参考价值。

import cv2
import numpy as np
import torch
from torch.autograd import Variable
import os
import torch.nn as nn
from libs.models import DeepLabV2_ResNet101_MSC
from tqdm import tqdm
from libs.utils import scores


model_path = 'data/models/deeplab_resnet101/voc12/deeplabv2_resnet101_VOC2012_trainaug.pth'
im_path = '/data1/jayzjwang/data/datasets/voc/img/ori'
gt_path = '/data1/jayzjwang/data/datasets/voc/gt'
img_list = open('/data1/jayzjwang/data/datasets/voc/list/val.txt').readlines()

model = DeepLabV2_ResNet101_MSC(n_classes=21)
model.load_state_dict(
    torch.load(model_path, map_location=lambda storage, loc: storage)
)
model.eval()
model.cuda()

gts, outputs = [], []

for idx, i in tqdm(
    enumerate(img_list),
    total=len(img_list),
    leave=False,
    dynamic_ncols=True,
):
    img = np.zeros((513,513,3))
    img_temp = cv2.imread(os.path.join(im_path,i[:-1]+'.jpg')).astype(float)
    img_temp[:,:,0] = img_temp[:,:,0] - 104.008
    img_temp[:,:,1] = img_temp[:,:,1] - 116.669
    img_temp[:,:,2] = img_temp[:,:,2] - 122.675
    img[:img_temp.shape[0],:img_temp.shape[1],:] = img_temp

    gt = cv2.imread(os.path.join(gt_path,i[:-1]+'.png'),0)

    with torch.no_grad():
        output = model(Variable(torch.from_numpy(img[np.newaxis, :].transpose(0,3,1,2)).float().cuda()))

    interp = nn.Upsample(size=(513, 513), mode='bilinear', align_corners=True)
    output = interp(output).cpu().data[0].numpy()
    output = output[:,:img_temp.shape[0],:img_temp.shape[1]]
    
    output = output.transpose(1,2,0)
    output = np.argmax(output,axis = 2)
    
    outputs.append(output)
    gts.append(gt)

miou, _ = scores(gts, outputs, n_class=21)
print("miou = {}".format(miou["Mean IoU"]))

以上是关于python 重现原始deeplabv2结果(76.3)的主要内容,如果未能解决你的问题,请参考以下文章

DeepLabV2网络简析

使用 Iris 数据集使用 Python 在 R 中重现 LASSO / Logistic 回归结果

RandomizedSearchCV 和 GridsearchCV 结果不可重现

Jenkins one_at_a_time hash - 试图让 Python 代码重现 JavaScript 代码

使用 Python Pandas 合并距离矩阵结果和原始索引

使用 Python 绘制 PCA 结果,包括带有散点图的原始数据