Pytorch之数据处理

Posted betterthanever_victor

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch之数据处理相关的知识,希望对你有一定的参考价值。

使用TensorDataset和DataLoader来简化

 
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
?
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
?
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
 
def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )
 
 
 
  • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
  • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout

 

 

import numpy as np
?
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)
?
        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print(‘当前step:‘+str(step), ‘验证集损失:‘+str(val_loss))
 
 
from torch import optim
def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)
 
 
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)
?
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()
?
    return loss.item(), len(xb)
 
 

 

 

三行搞定!

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
 
 
 
 
 
?

 

以上是关于Pytorch之数据处理的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch 之 神经网络 Mnist 分类任务

PyTorch 之 强大的 hub 模块和搭建神经网络进行气温预测

19.初识Pytorch之完整的模型套路-整理后的代码 Complete model routine - compiled code

pytorch 学习笔记之编写 C 扩展,又涨姿势了

基于Pytorch的神经网络之autoencoder

小白学习之pytorch框架之实战Kaggle比赛:房价预测(K折交叉验证*args**kwargs)