深度学习入门比赛——街景字符识别
Posted wushupei
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了深度学习入门比赛——街景字符识别相关的知识,希望对你有一定的参考价值。
这是比赛的第四阶段,模型的相关训练与验证
选好模型之后,需要建立训练集与验证集进行模型的效果验证,保证模型的预测结果正确符合,以及不过拟合训练与验证主要有以下几种方法:
交叉验证法
交叉验证法的作用就是尝试利用不同的训练集/测试集划分来对模型做多组不同的训练/测试,来应对单词测试结果过于片面以及训练数据不足的问题。交叉验证的做法就是将数据集粗略地分为比较均等不相交的k份,即然后取其中的一份进行测试,另外的k-1份进行训练,然后求得error的平均值作为最终的评价,具体算法流程西瓜书中的插图如下:
主要代码:
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=10,
shuffle=True,
num_workers=10,
)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=10,
shuffle=False,
num_workers=10,
)
model = SVHN_Model1()
criterion = nn.CrossEntropyLoss (size_average=False)
optimizer = torch.optim.Adam(model.parameters(), 0.001)
best_loss = 1000.0
for epoch in range(20):
print(‘Epoch: ‘, epoch)
train(train_loader, model, criterion, optimizer, epoch)
val_loss = validate(val_loader, model, criterion)
# 记录下验证集精度
if val_loss < best_loss:
best_loss = val_loss
torch.save(model.state_dict(), ‘./model.pt‘)
参考资料:
https://zhuanlan.zhihu.com/p/35394638
以上是关于深度学习入门比赛——街景字符识别的主要内容,如果未能解决你的问题,请参考以下文章