2. PyTorch Tutorial 1
Posted Shannnon_sun
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了2. PyTorch Tutorial 1相关的知识,希望对你有一定的参考价值。
文章目录
PyTorch Tutorial 1
An machine learning framework
1. Dataset&DataLoader
from torch.utils.data import Dataset,DataLoader
class MyDataset(Dataset):
def _init_(self,file):
def _getitem_(self,index):
def _len_(self):
2. Tensors
-
x.shape()
-
创建Tensor
-
运算:x.transpose(0,1) 转置 x.squeeze(0) torch.cat([x,y,z],dim=0) 在某一维数上连接
-
Type:
-
Device:
x=x.to(‘cpu’) x=x.to(‘gpu’)
3. 定义模型:
import torch.nn as nn
class Mymodole(nn.Module):
def _init_(self):
super(Mymodel,self)._init_()
self.net=nn.sequential(
nn.Linear(10,32)
nn.Sigmoid()
nn.Linear(32,1)
)
4. 定义Loss Function
在torch.nn
nn.MSELoss()
nn.CrossEntropyLoss()
5. 优化
在torch.optim
torch.optim.SGD(model.parameters(),lr,momentum=0)
6. Neural Network Training Setup
dataset = MyDataset(file)
tr_set = DataLoader(dataset,16,shuffle=True) #dataset 放入DataLoader
model = MyModel().to(device) #模型放到(cpu/cude)
criterion = nn.MSELoss() #定义损失函数
optimizer = torch.optim.SGD(model.parameters(),0.1) #优化器
7. Neural Network Training Loop
for epoch in range(n_epoches):
model.train() #训练
for x,y in tr_set:
optimizer.zero_grad()
x,y=x.to(device),y.to(device)
pred=model(x) #前向传播计算输出
loss = criterion(pred,y) #计算Loss
loss.backward() #反向传播计算梯度
optimizer.step() #更新参数
8. Neural Network Validation Loop
model.eval() #模型调到测试模式
total_loss = 0
for x,y in dv_set:
x,y = x.to(device),y.to(device)
with torch.no_grad(): #disable gradient calculation
pred = model(x)
loss = ceiterion(pred,y)
total_loss +=loss.cpu().item()*len(x)
avg_loss = total_loss/len(dv_set.dataset)
9. Neural Network Testing Loop
为什么把梯度的计算关掉?
- 计算快一点
- 确保不会更新模型
10. Save/Load Trained Models
- Save
torch.save(model.state_dict(),path)
-
Load
ckpt = torch.load(path) model.load_state_dict(ckpt)
以上是关于2. PyTorch Tutorial 1的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch-11 进行神经风格迁移neural style tutorial
pytorch例子学习——TRANSFER LEARNING TUTORIAL
PyTorch in Action: A Step by Step Tutorial
PyTorch in Action: A Step by Step Tutorial
PyTorch 计算机视觉的迁移学习教程代码详解 (TRANSFER LEARNING FOR COMPUTER VISION TUTORIAL )