torch加载参数
Posted junzhaoliang
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch加载参数相关的知识,希望对你有一定的参考价值。
1 from torch.utils.data import DataLoader
2 from torchvision import datasets
3 from PIL import Image as img
4
5 dataPath = ‘./data/imgs/‘
6
7 dataset = datasets.ImageFolder(‘./data/‘, loader=img.open)
8 dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
9
10 # 方式一
11 for epoch in range(100):
12 for i, (img, _)in enumerate(dataloader):
13 # do training
14
15 # 方式二
16
17 def data_gen(data_loader):
18 while True:
19 for (images, _) in enumerate(data_loader):
20 yield images
21
22 gen_img = data_gen(dataloader)
23
24 for iter in range(100):
25 imgs = gen_img.__next__()
26 # do training
以上是关于torch加载参数的主要内容,如果未能解决你的问题,请参考以下文章