PyTorch Review
Posted by 钟钟终
CIFAR10 Model Structure
import torch
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        # Three conv + max-pool stages, then flatten into two linear layers
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

model = Model()
input = torch.ones((64, 3, 32, 32))
output = model(input)
print(output.shape)  # torch.Size([64, 10])
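To see where the 1024 input features of the first Linear layer come from, a quick sanity check (a minimal sketch, not part of the original, continuing the script above) can walk the Sequential layer by layer and print the intermediate shapes:

x = torch.ones((64, 3, 32, 32))
# 3x32x32 -> 32x32x32 -> 32x16x16 -> 32x16x16 -> 32x8x8 -> 64x8x8 -> 64x4x4 -> 1024 -> 64 -> 10
for layer in model.model1:
    x = layer(x)
    print(layer.__class__.__name__, tuple(x.shape))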
Loss Function and Backpropagation
import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten, CrossEntropyLoss
from torch.utils.data import DataLoader

# Load the dataset
dataset = torchvision.datasets.CIFAR10("./dataset", train=False,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=64)

# Build the model
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

model = Model()
loss = CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        output = model(imgs)
        result_loss = loss(output, targets)
        # Zero the gradients left over from the previous step
        optim.zero_grad()
        # Backpropagation
        result_loss.backward()
        # Optimizer update
        optim.step()
        running_loss = running_loss + result_loss.item()
    print(running_loss)
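For intuition about what result_loss.backward() actually produces, a minimal sketch (not in the original, run on a fresh Model() instance with the dataloader and loss defined above) inspects the gradient stored on the first convolution's weights:

fresh_model = Model()
imgs, targets = next(iter(dataloader))
result_loss = loss(fresh_model(imgs), targets)
print(fresh_model.model1[0].weight.grad)        # None: no backward pass has run yet
result_loss.backward()
print(fresh_model.model1[0].weight.grad.shape)  # torch.Size([32, 3, 5, 5]): d(loss)/d(weight)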
Modifying the VGG16 Model
import torchvision
from torch import nn

vgg16_false = torchvision.models.vgg16(pretrained=False)  # randomly initialized weights
vgg16_true = torchvision.models.vgg16(pretrained=True)    # ImageNet-pretrained weights
print(vgg16_true)

dataset = torchvision.datasets.CIFAR10("./dataset", train=False, download=True,
                                       transform=torchvision.transforms.ToTensor())

# Append an extra linear layer mapping the 1000 ImageNet classes to CIFAR10's 10 classes
vgg16_true.classifier.add_module("add_linear", nn.Linear(1000, 10))
print(vgg16_true)

# Alternatively, replace the last classifier layer directly
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)
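When the pretrained weights are kept, a common companion step (an addition, not in the original) is to freeze the convolutional feature extractor and train only the modified classifier. A minimal sketch:

import torch

# Freeze the pretrained feature extractor so its weights are not updated
for param in vgg16_true.features.parameters():
    param.requires_grad = False

# Hand only the still-trainable parameters to the optimizer
optimizer = torch.optim.SGD(
    (p for p in vgg16_true.parameters() if p.requires_grad), lr=0.01)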
Saving and Loading Models
import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)

# Saving method 1: save the whole model (structure + parameters)
torch.save(vgg16, "vgg16_method1.pth")

# Saving method 2: save only the parameters as a state_dict (officially recommended)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

# Pitfall: with method 1, the model's class definition must be available when loading
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv(x)
        return x

# ---- loading script ----
import torch
import torchvision
from torch import nn

# Method 1 -> load the whole model (the class definition must be importable)
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv(x)
        return x

model = torch.load("vgg16_method1.pth")

# Method 2 -> load only the parameters into a freshly built model
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
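Beyond the two methods above, a common pattern (an addition, not part of the original; the file name and epoch value are illustrative) is to save a checkpoint dictionary that also records the optimizer state, so training can be resumed later:

optimizer = torch.optim.SGD(vgg16.parameters(), lr=0.01)
checkpoint = {
    "epoch": 10,                                   # illustrative value
    "model_state_dict": vgg16.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(checkpoint, "checkpoint.pth")           # hypothetical file name

# Resuming later
checkpoint = torch.load("checkpoint.pth")
vgg16.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"]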
A Complete Model Training Workflow
import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# from model import *  # not needed here: the model class is defined below

# Prepare the datasets
train_data = torchvision.datasets.CIFAR10("./dataset", train=True, download=True,
                                          transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10("./dataset", train=False, download=True,
                                         transform=torchvision.transforms.ToTensor())

# Dataset sizes
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"Training set size: {train_data_size}")
print(f"Test set size: {test_data_size}")

# Load the data with DataLoader
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

# Build the neural network
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, stride=1, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, stride=1, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, stride=1, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

# Create the network model
model = Model()

# Loss function
loss_fn = nn.CrossEntropyLoss()

# Optimizer
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Training bookkeeping
total_train_step = 0
total_test_step = 0
epoch = 10

# TensorBoard writer
writer = SummaryWriter("./logs_train")

for i in range(epoch):
    print(f"----------- Epoch {i + 1} starts -----------")

    # Training phase
    model.train()
    for data in train_dataloader:
        imgs, targets = data
        outputs = model(imgs)
        loss = loss_fn(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_step += 1
        if total_train_step % 100 == 0:
            print(f"Training step: {total_train_step}, loss: {loss.item()}")
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # Evaluation phase
    model.eval()
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = model(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            # Number of correct predictions in this batch
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy
    rate = total_accuracy / test_data_size
    print(f"Total loss on the test set: {total_test_loss}")
    print(f"Accuracy on the test set: {rate}")
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", rate, total_test_step)
    total_test_step += 1

    torch.save(model, f"model_{i}.pth")
    print("Model saved!")

writer.close()
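The loop above runs on the CPU. A minimal sketch of the usual GPU variant (an extension, not in the original script) moves the model, the loss function, and every batch to a device before the same zero_grad / backward / step sequence:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
loss_fn = nn.CrossEntropyLoss().to(device)

for data in train_dataloader:
    imgs, targets = data
    imgs, targets = imgs.to(device), targets.to(device)
    outputs = model(imgs)
    loss = loss_fn(outputs, targets)
    # ...followed by optimizer.zero_grad(), loss.backward(), optimizer.step() as above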
Validating the Trained Model on a Test Image
from PIL import Image
import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, Flatten
from torch.utils.data import DataLoader

# Load and preprocess a single test image
image_path = "./imgs/img_1.png"
image = Image.open(image_path)
image = image.convert("RGB")
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
image = transform(image)

# Build the neural network (the class definition must match the saved model)
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model1 = Sequential(
            Conv2d(3, 32, 5, stride=1, padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, stride=1, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, stride=1, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model1(x)
        return x

# Load the model saved during training and run inference on the CPU
model = torch.load("./model_9.pth", map_location=torch.device("cpu"))
image = torch.reshape(image, (1, 3, 32, 32))
model.eval()
with torch.no_grad():
    output = model(image)
print(output)
print(output.argmax(1))
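output.argmax(1) only returns a class index. To map it back to a human-readable label, the CIFAR10 class names can be used; a minimal sketch, assuming the standard order reported by torchvision's CIFAR10.classes:

classes = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]
predicted_index = output.argmax(1).item()
print(f"Predicted class: {classes[predicted_index]}")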