paddlepaddle模型训练和预测之基础API
Posted 修炼之路
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了paddlepaddle模型训练和预测之基础API相关的知识,希望对你有一定的参考价值。
导读
在上篇文章paddlepaddle模型训练和预测之高级API我们已经介绍过了如何通过高级API来搭建网络模型进行模型的训练和预测,这篇文章我们主要来介绍如何通过paddlepaddle的基础API来实现模型的训练,其实高级的API接口也是基于基础的API接口进行封装的。
模型训练
这里我们将上篇文章中的高级API拆分成基础的API进行介绍
- 模型构建
在高级API中,我们是采用Sequential
来构建网络模型的,它能够直接将卷积
、激活函数
、池化
等操作直接进行组合。当,我们需要设计比较复杂的结构时,就需要用到Layer
来实现了
import paddle
from paddle import nn,vision
import numpy as np
class LeNet(nn.Layer):
def __init__(self,num_classes=10):
super(LeNet, self).__init__()
self.num_classes = num_classes
self.features = nn.Sequential(
nn.Conv2D(1,6,3,stride=1,padding=1),
nn.ReLU(),
nn.MaxPool2D(2,2),
nn.Conv2D(6,16,5,stride=1,padding=0),
nn.ReLU(),
nn.MaxPool2D(2,2)
)
if self.num_classes > 0:
self.fc = nn.Sequential(
nn.Linear(400,120),
nn.Linear(120,84),
nn.Linear(84,self.num_classes)
)
def forward(self,inputs):
x = self.features(inputs)
if self.num_classes > 0:
x = paddle.flatten(x,1)
x = self.fc(x)
return x
- 超参设置
#模型参数设置
epochs = 10
batch_size = 64
learning_rate = 0.001
num_classes = 10
#设置输出日志的频率
freq_batch = 900
- 数据加载
在高级API中,我们是直接使用datasets
将它添加到fit
或evaluate
接口中。而,在这里我们还需要一个DataLoader
在模型进行训练和评估的时候,能够通过它从里面取数据。
#输入图片的预处理
transform = vision.Normalize(mean=[127.5],std=[127.5],data_format="CHW")
#加载数据集
train_datasets = vision.datasets.FashionMNIST(mode="train",transform=transform)
test_datasets = vision.datasets.FashionMNIST(mode="test",transform=transform)
#数据加载器
train_dataloader = paddle.io.DataLoader(train_datasets,batch_size=batch_size,shuffle=True)
test_dataloader = paddle.io.DataLoader(test_datasets,batch_size=batch_size,shuffle=False)
- 模型训练和评估
在模型训练一个epoch之后,我们会在测试集上评估模型的准确率,并且只保存best
模型
#构建模型
model = LeNet(num_classes)
#设置梯度下降的优化算法
optim = paddle.optimizer.Adam(learning_rate=learning_rate,parameters=model.parameters())
#设置损失函数
loss_fn = paddle.nn.CrossEntropyLoss()
#用来保存最好的准确率的模型
best_acc = 0
#训练模型
for epoch in range(epochs):
# 将模型切换到训练模式
model.train()
#使用训练集进行模型训练
for batch_id,batch_data in enumerate(train_dataloader):
batch_x,batch_y = batch_data
#获取模型的预测结果
batch_pred_y = model(batch_x)
#计算损失
loss = loss_fn(batch_pred_y,batch_y)
#计算准确率
batch_acc = paddle.metric.accuracy(batch_pred_y,batch_y)
#反向传播
loss.backward()
#更新网络参数
optim.step()
optim.clear_grad()
#打印日志
if batch_id % freq_batch == 0:
print("epoch:,batch id:,loss::.2f,acc::.4f".format(epoch,batch_id+1,
loss.numpy()[0],batch_acc.numpy()[0]))
#将模型切换到预测模型
model.eval()
#用来保存loss和准确率
loss_list = []
acc_list = []
for batch_id,batch_data in enumerate(test_dataloader):
batch_x,batch_y = batch_data
pred_batch_y = model(batch_x)
#计算损失值
batch_loss = loss_fn(pred_batch_y,batch_y)
#计算准确率
batch_acc = paddle.metric.accuracy(pred_batch_y,batch_y)
loss_list.append(batch_loss.numpy()[0])
acc_list.append(batch_acc.numpy()[0])
val_loss = np.mean(loss_list)
val_acc = np.mean(acc_list)
print("validate data,loss::.2f,acc::.4f".format(val_loss,val_acc))
if val_acc > best_acc:
best_acc = val_acc
#保存模型参数
paddle.save(model.state_dict(),"save/best.pdparams")
#保存模型训练中设置的参数
paddle.save(optim.state_dict(),"save/best.pdopt")
epoch:0,batch id:1,loss:2.85,acc:0.0156
epoch:0,batch id:901,loss:0.28,acc:0.9375
validate data,loss:0.40,acc:0.8550
epoch:1,batch id:1,loss:0.30,acc:0.8438
epoch:1,batch id:901,loss:0.40,acc:0.8594
validate data,loss:0.36,acc:0.8690
epoch:2,batch id:1,loss:0.41,acc:0.8125
epoch:2,batch id:901,loss:0.34,acc:0.8438
validate data,loss:0.35,acc:0.8761
epoch:3,batch id:1,loss:0.47,acc:0.7969
epoch:3,batch id:901,loss:0.19,acc:0.9531
validate data,loss:0.33,acc:0.8825
epoch:4,batch id:1,loss:0.18,acc:0.9688
epoch:4,batch id:901,loss:0.17,acc:0.9531
validate data,loss:0.32,acc:0.8835
epoch:5,batch id:1,loss:0.41,acc:0.8750
epoch:5,batch id:901,loss:0.19,acc:0.9375
validate data,loss:0.31,acc:0.8885
epoch:6,batch id:1,loss:0.20,acc:0.9375
epoch:6,batch id:901,loss:0.24,acc:0.9062
validate data,loss:0.32,acc:0.8849
epoch:7,batch id:1,loss:0.23,acc:0.8594
epoch:7,batch id:901,loss:0.25,acc:0.8750
validate data,loss:0.32,acc:0.8882
epoch:8,batch id:1,loss:0.23,acc:0.8906
epoch:8,batch id:901,loss:0.35,acc:0.8906
validate data,loss:0.32,acc:0.8886
epoch:9,batch id:1,loss:0.17,acc:0.9375
epoch:9,batch id:901,loss:0.24,acc:0.8750
validate data,loss:0.31,acc:0.8883
- 模型加载和预测
from matplotlib import pyplot as plt
#加载模型参数
model_params = paddle.load("save/best.pdparams")
#加载模型参数
model.set_state_dict(model_params)
#显示预测的结果
plt.figure(figsize=(8,8))
#设置行和列的数量
row_num = 4
col_num = 4
for i in range(16):
img,true_label = test_datasets[i]
img_tensor = paddle.to_tensor([img])
output = model(img_tensor)
pred_label = output.argmax().numpy()
plt.subplot(row_num,col_num,i+1)
img = img[0]
plt.imshow(img)
plt.title("ture:\\n predict:".format(true_label[0],pred_label[0]))
#隐藏x轴和y轴
plt.xticks([])
plt.yticks([])
plt.show()
以上是关于paddlepaddle模型训练和预测之基础API的主要内容,如果未能解决你的问题,请参考以下文章