PyTorch: batch training (batch_train)
Posted by dhname
This post shows how to split a dataset into mini-batches for training in PyTorch, using torch.utils.data.DataLoader.
import torch
import torch.utils.data as Data

torch.manual_seed(1)    # reproducible

BATCH_SIZE = 5
# BATCH_SIZE = 8

x = torch.linspace(1, 10, 10)       # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)       # this is y data (torch tensor)

torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=True,               # random shuffle for training
    num_workers=2,              # subprocesses for loading data
)


def show_batch():
    for epoch in range(3):   # train entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())


if __name__ == '__main__':
    show_batch()
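A note on the commented-out BATCH_SIZE = 8: since the dataset above holds 10 samples and 8 does not divide 10 evenly, DataLoader yields one full batch of 8 followed by a leftover batch of 2. The sketch below (a minimal illustration, reusing the same 10-sample dataset; the drop_last flag shown is a standard DataLoader parameter) demonstrates this, and how to discard the incomplete final batch:

```python
import torch
import torch.utils.data as Data

torch.manual_seed(1)

x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
dataset = Data.TensorDataset(x, y)

# With batch_size=8 over 10 samples, the last batch holds only 2 samples.
loader = Data.DataLoader(dataset=dataset, batch_size=8, shuffle=True)
sizes = [len(batch_x) for batch_x, batch_y in loader]
print(sizes)  # [8, 2]

# drop_last=True discards the incomplete final batch instead.
loader2 = Data.DataLoader(dataset=dataset, batch_size=8, shuffle=True,
                          drop_last=True)
sizes2 = [len(batch_x) for batch_x, batch_y in loader2]
print(sizes2)  # [8]
```

Dropping the last batch can be useful when a model (e.g. one using BatchNorm) behaves poorly on very small batches.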