torch加载参数

Posted junzhaoliang

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了torch加载参数相关的知识,希望对你有一定的参考价值。

 1 from torch.utils.data import DataLoader
 2 from torchvision import datasets
 3 from PIL import Image as img
 4 
 5 dataPath = ./data/imgs/
 6 
 7 dataset = datasets.ImageFolder(./data/, loader=img.open)
 8 dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
 9 
10 # 方式一
11 for epoch in range(100):
12     for i, (img, _)in enumerate(dataloader):
13         # do training
14 
15 # 方式二
16 
17 def data_gen(data_loader):
18     while True:
19         for (images, _) in enumerate(data_loader):
20             yield images
21 
22 gen_img = data_gen(dataloader)
23 
24 for iter in range(100):
25     imgs = gen_img.__next__()
26     # do training

 

以上是关于torch加载参数的主要内容,如果未能解决你的问题,请参考以下文章

在android中动态创建选项卡并使用传入的参数加载片段

PyTorch保存和加载模型

加载训练的模型参数并继续训练

Torch.load()使用方式

python如何导入pth模型

每天讲解一点PyTorch 15model.load_state_dict torch.load torch.save