PyTorch使用中需要注意的地方
Posted yezzz
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch使用中需要注意的地方相关的知识,希望对你有一定的参考价值。
参考博客:
https://blog.csdn.net/u011276025/article/details/73826562/
1. 把Label要转成LongTensor格式
self.y = torch.LongTensor(y)
完整使用代码如下:
1 class ImgDataset(Dataset): 2 def __init__(self, x, y=None, transform=None): 3 self.x = x 4 # label is required to be a LongTensor 5 self.y = y 6 if y is not None: 7 self.y = torch.LongTensor(y) 8 self.transform = transform 9 def __len__(self): 10 return len(self.x) 11 def __getitem__(self, index): 12 X = self.x[index] 13 if self.transform is not None: 14 X = self.transform(X) 15 if self.y is not None: 16 Y = self.y[index] 17 return X, Y 18 else: 19 return X
1 class ImgDataset(Dataset): 2 def __init__(self, x, y=None, transform=None): 3 self.x = x 4 # label is required to be a LongTensor 5 self.y = y 6 if y is not None: 7 self.y = torch.LongTensor(y) 8 self.transform = transform 9 def __len__(self): 10 return len(self.x) 11 def __getitem__(self, index): 12 X = self.x[index] 13 if self.transform is not None: 14 X = self.transform(X) 15 if self.y is not None: 16 Y = self.y[index] 17 return X, Y 18 else: 19 return X
需要保证target类型为torch.cuda.LongTensor,需要在数据读取的迭代其中把target的类型转换为int64位的:target = target.astype(np.int64),这样,输出的target类型为torch.cuda.LongTensor。(或者在使用前使用Tensor.type(torch.LongTensor)
进行转换)。
*LongTensor其实就是int64,有符号整型
2. 做预测时,没有y值,从dataloader中传入给model的直接是data,而不再是data[0]了
model_best.eval() prediction = [] with torch.no_grad(): for i, data in enumerate(test_loader): #print(data[0].size()) # 特别要注意的是,这里直接传入data,因为已经没有y值了,所以无需data[0]。 # 如果传了data[0]反而导致没有传入整个batch,计算错误 test_pred = model_best(data.cuda()) test_label = np.argmax(test_pred.cpu().data.numpy(), axis=1) for y in test_label: prediction.append(y)
未完待续。。。
以上是关于PyTorch使用中需要注意的地方的主要内容,如果未能解决你的问题,请参考以下文章