Pytorch transforms.RandomRotation() 在 Google Colab 上不起作用

Posted

技术标签:

【中文标题】Pytorch transforms.RandomRotation() 在 Google Colab 上不起作用【英文标题】:Pytorch transforms.RandomRotation() does not work on Google Colab 【发布时间】:2020-05-29 01:02:16 【问题描述】:

通常我在计算机上进行字母和数字识别,我想将我的项目移动到 Colab,但不幸的是出现了错误(您可以在下面看到错误)。 经过一些调试,我发现哪一行给了我错误。

transforms.RandomRotation(degrees=(90, -90))

下面我编写了简单的抽象代码来显示此错误。此代码在 colab 上不起作用,但在我自己的计算机环境中运行良好。问题可能与我计算机上的 1.3.1 版本的 pytorch 库的不同版本有关而 colab 使用的是 1.4.0 版本。

import torch
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt   
    transformOpt = transforms.Compose([
            transforms.RandomRotation(degrees=(90, -90)),
            transforms.ToTensor()
        ])

    train_set = datasets.MNIST(
        root='', train=True, transform=transformOpt, download=True)
    test_set = datasets.MNIST(
        root='', train=False, transform=transformOpt, download=True)


    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=100,
        shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=100,
        shuffle=False)

    images, labels = next(iter(train_loader))
    plt.imshow(images[0].view(28, 28), cmap="gray")
    plt.show()

我在 Google Colab 上执行上述示例代码时遇到的完整错误。

TypeError                                 Traceback (most recent call last)

<ipython-input-1-8409db422154> in <module>()
     24     shuffle=False)
     25 
---> 26 images, labels = next(iter(train_loader))
     27 plt.imshow(images[0].view(28, 28), cmap="gray")
     28 plt.show()

10 frames

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py in __getitem__(self, index)
     95 
     96         if self.transform is not None:
---> 97             img = self.transform(img)
     98 
     99         if self.target_transform is not None:

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
     68     def __call__(self, img):
     69         for t in self.transforms:
---> 70             img = t(img)
     71         return img
     72 

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)    1001         angle = self.get_params(self.degrees)    1002 
-> 1003         return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill)    1004     1005     def
__repr__(self):

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in rotate(img, angle, resample, expand, center, fill)
    727         fill = tuple([fill] * 3)
    728 
--> 729     return img.rotate(angle, resample, expand, center, fillcolor=fill)
    730 
    731 

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in rotate(self, angle, resample, expand, center, translate, fillcolor)    2003         w, h = nw, nh    2004 
-> 2005         return self.transform((w, h), AFFINE, matrix, resample, fillcolor=fillcolor)    2006     2007     def save(self,    fp, format=None, **params):

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in transform(self, size, method, data, resample, fill, fillcolor)    2297             raise ValueError("missing method data")    2298 
-> 2299         im = new(self.mode, size, fillcolor)    2300         if method == MESH:    2301             # list of quads

/usr/local/lib/python3.6/dist-packages/PIL/Image.py in new(mode, size, color)    2503         im.palette = ImagePalette.ImagePalette()    2504         color = im.palette.getcolor(color)
-> 2505     return im._new(core.fill(mode, size, color))    2506     2507 

TypeError: function takes exactly 1 argument (3 given)

【问题讨论】:

您能粘贴完整的错误跟踪吗? @kHarshit 谢谢,我添加了完整的错误跟踪。顺便说一句,您可以将此示例代码粘贴到您自己的 colab 中并自己查看错误。我发现错误来自我提到的那一行,但我不知道如何解决它。 【参考方案1】:

你完全正确。 torchvision 0.5 在fill 参数中的RandomRotation() 中有一个错误,可能是由于Pillow 版本不兼容。此issue 现已修复 (PR#1760),将在下一个版本中解决。

暂时,您将fill=(0,) 添加到RandomRotation 转换以修复它。

transforms.RandomRotation(degrees=(90, -90), fill=(0,))

【讨论】:

嗨 kHarshit,您使用的是哪个版本的 PIL...我有 '5.0.0' 并且通过 fill=(0,) 并不能解决我的问题...谢谢 @LuisCandanedo 您现在可以升级到 torchvision v0.6.0,或查看 github 问题页面。我不确定使用的是什么版本的 PIL。 我只想说这个错误在 2020 年仍然是一个问题,但至少您在此处列出的解决方法仍然有效。

以上是关于Pytorch transforms.RandomRotation() 在 Google Colab 上不起作用的主要内容,如果未能解决你的问题,请参考以下文章

pytorch 中的常用矩阵操作

Pytorch Note1 Pytorch介绍

pytorch_geometric + MinkowskiEngine

1. PyTorch是什么?

1. PyTorch是什么?

对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码