Pytorch之数据处理
Posted betterthanever_victor
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Pytorch之数据处理相关的知识,希望对你有一定的参考价值。
使用TensorDataset和DataLoader来简化
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
?
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
?
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
def get_data(train_ds, valid_ds, bs):
return (
DataLoader(train_ds, batch_size=bs, shuffle=True),
DataLoader(valid_ds, batch_size=bs * 2),
)
- 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
- 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np
?
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
for step in range(steps):
model.train()
for xb, yb in train_dl:
loss_batch(model, loss_func, xb, yb, opt)
?
model.eval()
with torch.no_grad():
losses, nums = zip(
*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
)
val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
print(‘当前step:‘+str(step), ‘验证集损失:‘+str(val_loss))
from torch import optim
def get_model():
model = Mnist_NN()
return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):
loss = loss_func(model(xb), yb)
?
if opt is not None:
loss.backward()
opt.step()
opt.zero_grad()
?
return loss.item(), len(xb)
三行搞定!
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
?
以上是关于Pytorch之数据处理的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch 之 强大的 hub 模块和搭建神经网络进行气温预测
19.初识Pytorch之完整的模型套路-整理后的代码 Complete model routine - compiled code