Pytorch transforms.RandomRotation() 在 Google Colab 上不起作用
Posted
技术标签:
【中文标题】Pytorch transforms.RandomRotation() 在 Google Colab 上不起作用【英文标题】:Pytorch transforms.RandomRotation() does not work on Google Colab 【发布时间】:2020-05-29 01:02:16 【问题描述】:通常我在计算机上进行字母和数字识别,我想将我的项目移动到 Colab,但不幸的是出现了错误(您可以在下面看到错误)。 经过一些调试,我发现哪一行给了我错误。
transforms.RandomRotation(degrees=(90, -90))
下面我编写了简单的抽象代码来显示此错误。此代码在 colab 上不起作用,但在我自己的计算机环境中运行良好。问题可能与我计算机上的 1.3.1 版本的 pytorch 库的不同版本有关而 colab 使用的是 1.4.0 版本。
import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
transformOpt = transforms.Compose([
transforms.RandomRotation(degrees=(90, -90)),
transforms.ToTensor()
])
train_set = datasets.MNIST(
root='', train=True, transform=transformOpt, download=True)
test_set = datasets.MNIST(
root='', train=False, transform=transformOpt, download=True)
train_loader = torch.utils.data.DataLoader(
dataset=train_set,
batch_size=100,
shuffle=True)
test_loader = torch.utils.data.DataLoader(
dataset=test_set,
batch_size=100,
shuffle=False)
images, labels = next(iter(train_loader))
plt.imshow(images[0].view(28, 28), cmap="gray")
plt.show()
我在 Google Colab 上执行上述示例代码时遇到的完整错误。
TypeError Traceback (most recent call last)
<ipython-input-1-8409db422154> in <module>()
24 shuffle=False)
25
---> 26 images, labels = next(iter(train_loader))
27 plt.imshow(images[0].view(28, 28), cmap="gray")
28 plt.show()
10 frames
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
343
344 def __next__(self):
--> 345 data = self._next_data()
346 self._num_yielded += 1
347 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
383 def _next_data(self):
384 index = self._next_index() # may raise StopIteration
--> 385 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
386 if self._pin_memory:
387 data = _utils.pin_memory.pin_memory(data)
/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py in __getitem__(self, index)
95
96 if self.transform is not None:
---> 97 img = self.transform(img)
98
99 if self.target_transform is not None:
/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
68 def __call__(self, img):
69 for t in self.transforms:
---> 70 img = t(img)
71 return img
72
/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img) 1001 angle = self.get_params(self.degrees) 1002
-> 1003 return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill) 1004 1005 def
__repr__(self):
/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in rotate(img, angle, resample, expand, center, fill)
727 fill = tuple([fill] * 3)
728
--> 729 return img.rotate(angle, resample, expand, center, fillcolor=fill)
730
731
/usr/local/lib/python3.6/dist-packages/PIL/Image.py in rotate(self, angle, resample, expand, center, translate, fillcolor) 2003 w, h = nw, nh 2004
-> 2005 return self.transform((w, h), AFFINE, matrix, resample, fillcolor=fillcolor) 2006 2007 def save(self, fp, format=None, **params):
/usr/local/lib/python3.6/dist-packages/PIL/Image.py in transform(self, size, method, data, resample, fill, fillcolor) 2297 raise ValueError("missing method data") 2298
-> 2299 im = new(self.mode, size, fillcolor) 2300 if method == MESH: 2301 # list of quads
/usr/local/lib/python3.6/dist-packages/PIL/Image.py in new(mode, size, color) 2503 im.palette = ImagePalette.ImagePalette() 2504 color = im.palette.getcolor(color)
-> 2505 return im._new(core.fill(mode, size, color)) 2506 2507
TypeError: function takes exactly 1 argument (3 given)
【问题讨论】:
您能粘贴完整的错误跟踪吗? @kHarshit 谢谢,我添加了完整的错误跟踪。顺便说一句,您可以将此示例代码粘贴到您自己的 colab 中并自己查看错误。我发现错误来自我提到的那一行,但我不知道如何解决它。 【参考方案1】:你完全正确。 torchvision 0.5 在fill
参数中的RandomRotation()
中有一个错误,可能是由于Pillow 版本不兼容。此issue 现已修复 (PR#1760),将在下一个版本中解决。
暂时,您将fill=(0,)
添加到RandomRotation
转换以修复它。
transforms.RandomRotation(degrees=(90, -90), fill=(0,))
【讨论】:
嗨 kHarshit,您使用的是哪个版本的 PIL...我有 '5.0.0' 并且通过 fill=(0,) 并不能解决我的问题...谢谢 @LuisCandanedo 您现在可以升级到 torchvision v0.6.0,或查看 github 问题页面。我不确定使用的是什么版本的 PIL。 我只想说这个错误在 2020 年仍然是一个问题,但至少您在此处列出的解决方法仍然有效。以上是关于Pytorch transforms.RandomRotation() 在 Google Colab 上不起作用的主要内容,如果未能解决你的问题,请参考以下文章