python test_custom_transform.py

Preface: This post was compiled by the editors of 小常识网 (cha138.com) and covers python test_custom_transform.py. We hope you find it useful.

import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np

from custom_transforms import NRandomCrop


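def imshow(img):
    # Undo Normalize(mean=0.5, std=0.5): x / 2 + 0.5 maps [-1, 1] back to [0, 1]
    img = img / 2 + 0.5
    npimg = img.numpy()
    # matplotlib expects (H, W, C); PyTorch image tensors are (C, H, W)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()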


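mean, sd = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean, sd)


def crops_to_tensor(crops):
    # Convert each PIL crop to a normalized (C, H, W) tensor and stack them
    # into one (n, C, H, W) tensor. A named function, unlike a lambda, can be
    # pickled, which DataLoader workers need on spawn-based platforms.
    return torch.stack([normalize(to_tensor(crop)) for crop in crops])


transform = transforms.Compose([
    # n=5 random 32x32 crops per image, with padding=4 (see custom_transforms)
    NRandomCrop(size=32, n=5, padding=4),
    transforms.Lambda(crops_to_tensor),
])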


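if __name__ == '__main__':
    # The main guard is required when num_workers > 0 on Windows and macOS,
    # where DataLoader worker processes are started with "spawn".
    train_data = datasets.CIFAR10(root='./data',
                                  train=True,
                                  download=True,
                                  transform=transform)
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=4)
    classes = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

    dataiter = iter(train_loader)
    # DataLoader iterators have no .next() method in Python 3; use next()
    # images: (batch_size, n, C, H, W) = (1, 5, 3, 32, 32)
    images, labels = next(dataiter)

    # show the 5 crops of the single sampled image in one grid
    imshow(torchvision.utils.make_grid(images.squeeze(0)))
    # print the label of every image in the batch (range(0) printed nothing)
    print(' '.join('%5s' % classes[labels[j].item()] for j in range(len(labels))))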
