AI基础之训练网络
Posted by wangyueyyy
tags:
篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了AI基础之训练网络相关的知识，希望对你有一定的参考价值。
在以前的分享中，我们已经构建好了全连接网络和卷积神经网络，接下来就是训练网络了。下面的代码在 CIFAR-10 数据集上训练卷积神经网络 CNNet，每个 epoch 结束后在测试集上计算精度并保存模型。不多说了，直接上代码。
from MyNet import CNNet
import torch.nn as nn
import torch
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
import torchvision.transforms as trans
import matplotlib.pyplot as plt
class Trainer:
    """Train a ``CNNet`` on CIFAR-10 and report test accuracy after every epoch.

    The network is trained with mean-squared error between its raw outputs
    and one-hot encoded labels, optimised with Adam.  The author's sketch of
    a single layer's forward pass and of the loss being minimised:

        z = w.T * x + b
        A = ReLU(z)
        loss = 1/m * sum((A - Y)^2)
    """

    def __init__(self):
        # Prefer the GPU (matrix arithmetic is much faster there), but fall
        # back to the CPU so the script still runs on machines without CUDA.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = CNNet().to(device)
        self.loss_func = nn.MSELoss()
        # Attribute keeps its original spelling ("optimier") so any external
        # code that reads it keeps working.
        self.optimier = torch.optim.Adam(self.net.parameters())
        # (train_dataset, test_dataset) tuple returned by get_dataset().
        self.dataset = self.get_dataset()

    def get_dataset(self):
        """Return the CIFAR-10 ``(trainData, testData)`` datasets.

        Images are converted to tensors and normalised per channel.  The
        dataset files must already exist under ``datasets/`` because
        ``download=False``.
        """
        transform = trans.Compose([
            trans.ToTensor(),
            trans.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        trainData = CIFAR10(root="datasets/", train=True, download=False, transform=transform)
        testData = CIFAR10(root="datasets/", train=False, download=False, transform=transform)
        return trainData, testData

    def load_data(self, trainData, testData, batch_size=500):
        """Wrap the two datasets in DataLoaders of ``batch_size`` samples.

        The training loader is reshuffled every epoch.  The test loader is
        not shuffled, because evaluation order does not affect accuracy.
        """
        trainloader = DataLoader(dataset=trainData, batch_size=batch_size, shuffle=True)
        testloader = DataLoader(dataset=testData, batch_size=batch_size, shuffle=False)
        return trainloader, testloader

    def train(self, epochs=50, save_path="models/net.pth"):
        """Train for ``epochs`` epochs, evaluating and checkpointing after each.

        Args:
            epochs: number of full passes over the training set.
            save_path: file the whole network is written to (``torch.save``)
                after every epoch; its parent directory is created if missing.

        Returns:
            Test-set accuracy (a fraction in ``[0, 1]``) after the last epoch,
            or ``None`` when ``epochs`` is 0.
        """
        import os

        trainloader, testloader = self.load_data(self.dataset[0], self.dataset[1])
        accuracy = None
        for epoch in range(epochs):
            print("epochs:{}".format(epoch))
            self._train_epoch(trainloader)
            accuracy = self._evaluate(testloader)
            print("精度:{}".format(str(accuracy * 100) + "%"))
            # Create the checkpoint directory up front; otherwise a missing
            # "models/" would crash only after a full epoch of training.
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(self.net, save_path)
        return accuracy

    def _device(self):
        """Return the device the network's parameters live on."""
        return next(self.net.parameters()).device

    def _train_epoch(self, trainloader):
        """Run one optimisation pass over every batch of ``trainloader``."""
        device = self._device()
        self.net.train()
        for index, (images, labels) in enumerate(trainloader):
            images = images.to(device)
            labels = labels.to(device)
            output = self.net(images)
            # Size the encoding by the network's output width.  Without
            # num_classes, one_hot uses (largest label in this batch) + 1,
            # which mismatches the output whenever the top class is absent.
            target = torch.nn.functional.one_hot(labels, num_classes=output.shape[1])
            loss = self.loss_func(output, target.float())
            if index % 10 == 0:
                print("{}/{},loss:{}".format(index, len(trainloader), loss.item()))
            self.optimier.zero_grad()
            loss.backward()  # compute gradients
            self.optimier.step()  # apply the parameter update

    def _evaluate(self, testloader):
        """Return the fraction of ``testloader`` samples classified correctly."""
        device = self._device()
        self.net.eval()
        correct = 0
        # No autograd graph is needed for evaluation; this saves memory and time.
        with torch.no_grad():
            for images, labels in testloader:
                images = images.to(device)
                labels = labels.to(device)
                predict = torch.argmax(self.net(images), dim=1)
                correct += (predict == labels).sum().item()
        return correct / len(testloader.dataset)
if __name__ == '__main__':
    # Script entry point: build the trainer (loads CIFAR-10 from datasets/)
    # and run the full training loop.
    t = Trainer()
    t.train()
    # The per-channel mean/std passed to trans.Normalize can be recomputed by
    # concatenating the train and test sets, loading all 60000 images as one
    # batch of shape (60000, 3, 32, 32), and taking torch.mean / torch.std
    # over dims (0, 2, 3).  Load them with only trans.ToTensor() applied;
    # measuring data that already went through Normalize gives wrong stats.
以上是关于AI基础之训练网络的主要内容，如果未能解决你的问题，请参考以下文章。