08-pytorch(批数据训练)

Posted liu247

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了08-pytorch(批数据训练)相关的知识,希望对你有一定的参考价值。

import torch
import torch.utils.data as Data
BATCH_SIZE = 5
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

这个是打乱数据,然后在 依次的慢慢的按步伐的取出,当不足够的时候,就吧剩下的取出来(自适应)

# 格式下x,y
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    # 数据
    dataset=torch_dataset,
    # 尺寸
    batch_size = BATCH_SIZE,
    # 是否事先打乱数据
    shuffle = True,
    # 采用的线程数目
    num_workers=2,
)
for epoch in range(3):
    # 第一个是次数,第二个是值
    for step,(batch_x,batch_y) in enumerate(loader):
        # training...
        print('Epoch:',epoch,'|Step',step,'|Batch x:',batch_x.numpy(),
             '|batch y:',batch_y.numpy())
Epoch: 0 |Step 0 |Batch x: [ 8. 10.  7.  9.  1.] |batch y: [ 3.  1.  4.  2. 10.]
Epoch: 0 |Step 1 |Batch x: [3. 4. 5. 6. 2.] |batch y: [8. 7. 6. 5. 9.]
Epoch: 1 |Step 0 |Batch x: [3. 9. 5. 6. 7.] |batch y: [8. 2. 6. 5. 4.]
Epoch: 1 |Step 1 |Batch x: [10.  4.  2.  1.  8.] |batch y: [ 1.  7.  9. 10.  3.]
Epoch: 2 |Step 0 |Batch x: [10.  5.  7.  1.  6.] |batch y: [ 1.  6.  4. 10.  5.]
Epoch: 2 |Step 1 |Batch x: [4. 3. 2. 9. 8.] |batch y: [7. 8. 9. 2. 3.]

以上是关于08-pytorch(批数据训练)的主要内容,如果未能解决你的问题,请参考以下文章

Pytorch中如何使用DataLoader对数据集进行批训练

MAE实现及预训练可视化 (CIFAR-Pytorch)

MAE实现及预训练可视化 (CIFAR-Pytorch)

MAE实现及预训练可视化 (CIFAR-Pytorch)

PyTorch学习批训练

pytorch学习-5:批训练+Optimizer 优化器