08-pytorch(批数据训练)
Posted liu247
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了08-pytorch(批数据训练)相关的知识,希望对你有一定的参考价值。
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)
这个是打乱数据,然后在 依次的慢慢的按步伐的取出,当不足够的时候,就吧剩下的取出来(自适应)
# 格式下x,y
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
# 数据
dataset=torch_dataset,
# 尺寸
batch_size = BATCH_SIZE,
# 是否事先打乱数据
shuffle = True,
# 采用的线程数目
num_workers=2,
)
for epoch in range(3):
# 第一个是次数,第二个是值
for step,(batch_x,batch_y) in enumerate(loader):
# training...
print('Epoch:',epoch,'|Step',step,'|Batch x:',batch_x.numpy(),
'|batch y:',batch_y.numpy())
Epoch: 0 |Step 0 |Batch x: [ 8. 10. 7. 9. 1.] |batch y: [ 3. 1. 4. 2. 10.]
Epoch: 0 |Step 1 |Batch x: [3. 4. 5. 6. 2.] |batch y: [8. 7. 6. 5. 9.]
Epoch: 1 |Step 0 |Batch x: [3. 9. 5. 6. 7.] |batch y: [8. 2. 6. 5. 4.]
Epoch: 1 |Step 1 |Batch x: [10. 4. 2. 1. 8.] |batch y: [ 1. 7. 9. 10. 3.]
Epoch: 2 |Step 0 |Batch x: [10. 5. 7. 1. 6.] |batch y: [ 1. 6. 4. 10. 5.]
Epoch: 2 |Step 1 |Batch x: [4. 3. 2. 9. 8.] |batch y: [7. 8. 9. 2. 3.]
以上是关于08-pytorch(批数据训练)的主要内容,如果未能解决你的问题,请参考以下文章