PyTorch使用中需要注意的地方

Posted yezzz

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch使用中需要注意的地方相关的知识,希望对你有一定的参考价值。

参考博客:

https://blog.csdn.net/u011276025/article/details/73826562/

 

1. 把Label要转成LongTensor格式

self.y = torch.LongTensor(y)

完整使用代码如下:

技术图片
 1 class ImgDataset(Dataset):
 2     def __init__(self, x, y=None, transform=None):
 3         self.x = x
 4         # label is required to be a LongTensor
 5         self.y = y
 6         if y is not None:
 7             self.y = torch.LongTensor(y)
 8         self.transform = transform
 9     def __len__(self):
10         return len(self.x)
11     def __getitem__(self, index):
12         X = self.x[index]
13         if self.transform is not None:
14             X = self.transform(X)
15         if self.y is not None:
16             Y = self.y[index]
17             return X, Y
18         else:
19             return X
View Code
技术图片
 1 class ImgDataset(Dataset):
 2     def __init__(self, x, y=None, transform=None):
 3         self.x = x
 4         # label is required to be a LongTensor
 5         self.y = y
 6         if y is not None:
 7             self.y = torch.LongTensor(y)
 8         self.transform = transform
 9     def __len__(self):
10         return len(self.x)
11     def __getitem__(self, index):
12         X = self.x[index]
13         if self.transform is not None:
14             X = self.transform(X)
15         if self.y is not None:
16             Y = self.y[index]
17             return X, Y
18         else:
19             return X
View Code

需要保证target类型为torch.cuda.LongTensor,需要在数据读取的迭代其中把target的类型转换为int64位的:target = target.astype(np.int64),这样,输出的target类型为torch.cuda.LongTensor。(或者在使用前使用Tensor.type(torch.LongTensor)进行转换)。

*LongTensor其实就是int64,有符号整型

 

 

2. 做预测时,没有y值,从dataloader中传入给model的直接是data,而不再是data[0]了

model_best.eval()
prediction = []
with torch.no_grad():
    for i, data in enumerate(test_loader):
        #print(data[0].size())
        # 特别要注意的是,这里直接传入data,因为已经没有y值了,所以无需data[0]。
        # 如果传了data[0]反而导致没有传入整个batch,计算错误
        test_pred = model_best(data.cuda())
        test_label = np.argmax(test_pred.cpu().data.numpy(), axis=1)
        for y in test_label:
            prediction.append(y)

 

未完待续。。。

以上是关于PyTorch使用中需要注意的地方的主要内容,如果未能解决你的问题,请参考以下文章

使用Pytorch实现VGG的一般版本

使用移动CMPP2.0协议关于企业代码字段需要注意的地方

关于pytorch中inplace运算需要注意的问题

torchline:让Pytorch使用的更加顺滑

Pytorch中多GPU训练指北

在片段中使用 enableAutoManage()