PyTorch教程 读写文件 #yyds干货盘点#
Posted LolitaAnn
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch教程 读写文件 #yyds干货盘点#相关的知识,希望对你有一定的参考价值。
读写文件有什么必要呢?
读写文件其实不是读取数据集。
是当你的训练时要定期存储中间结果,以确保在服务器电源不小心被断掉,或者出现其他情况的时候,损失掉你前几天的计算结果。
这一节要做的就是如何存储权重向量和整个模型。
import torch
from torch import nn
from torch.nn import functional as F
load
和save
对于单个张量,我们可以直接调用load
和save
函数分别读写它们。
-
torch.saves
torch.save(obj, f, pickle_module=<module pickle from .../pickle.py>, pickle_protocol=2)
参数:
- obj – 保存对象
- f - 字符串,文件名
- pickle_module – 用于pickling元数据和对象的模块
- pickle_protocol – 指定pickle protocal 可以覆盖默认参数
-
torch.load
torch.load(f, map_location=None, pickle_module=<module pickle from .../pickle.py>)
从磁盘文件中读取一个通过
torch.save()
保存的对象。参数:
- f – 字符串,文件名
- map_location – 一个函数或字典规定如何remap存储位置
- pickle_module – 用于unpickling元数据和对象的模块 (必须匹配序列化文件时的pickle_module )
x = torch.arange(4)
torch.save(x, x-file)
x2 = torch.load(x-file)
print(x2)
初始化一个x
将x存储到当前文件夹下并命名为x-file
,此时你会发现当前文件夹下边多出来一个同名的文件。
当然打开之后可能不是 0 1 2 3 ,因为编码方式不同,所以不用纠结打开以后看到的是什么。
声明一个x2再从文件中读回来,会发现结果就是tensor([0, 1, 2, 3]),结果没错就可以了。
y = torch.zeros(4)
torch.save([x, y],x-file)
x2, y2 = torch.load(x-file)
print(x2, y2)
>>
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
存储一个张量列表,然后把它们读回内存。
y = torch.zeros(4)
torch.save(y[:2],x-file)
x2, y2 = torch.load(x-file)
(x2, y2)
print(x2, y2)
切片也是可以的。
mydict = {x: x, y: y}
torch.save(mydict, x-file)
mydict2 = torch.load(x-file)
mydict2
>>
{x: tensor([0, 1, 2, 3]), y: tensor([0., 0., 0., 0.])}
存储字典也可以。
加载和保存模型参数
深度学习框架提供了内置函数来保存和加载整个网络。
但是是保存模型的参数而不是保存整个模型。
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.hidden = nn.Linear(20, 256)
self.output = nn.Linear(256, 10)
def forward(self, x):
return self.output(F.relu(self.hidden(x)))
还记得我们手写的多层感知机吗,用这个实现一下子。
net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)
现在生成一个net,用它计算X,并将其赋值给Y。
torch.save(net.state_dict(), x-file)
将net的参数保存起来。
net_ = MLP()
net_.load_state_dict(torch.load(x-file))
net_.eval()
生成一个net_也是多层感知机,net的参数直接加载文件中的参数。
net_.eval()
是将模型的模式改为评价模式。
Y_clone = net_(X)
print(Y_clone == Y)
>>
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
将新网络赋值给Y_clone,可以看到Y_clone和Y是相同的。
当然换成pytorch自己的层也是可以的
MLP = nn.Sequential(nn.Linear(20,256),nn.Linear(256,10),nn.ReLU())
def init(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight)
nn.init.zeros_(m.bias)
net = MLP
X = torch.randn(size=(2, 20))
Y = net(X)
torch.save(net.state_dict(), x-file)
net_ = MLP
net_.load_state_dict(torch.load(x-file))
net_.eval()
Y_clone = net_(X)
Y_clone == Y
>>
tensor([[True, True, True, True, True, True, True, True, True, True],
[True, True, True, True, True, True, True, True, True, True]])
以上是关于PyTorch教程 读写文件 #yyds干货盘点#的主要内容,如果未能解决你的问题,请参考以下文章
#yyds干货盘点#动力节点王鹤Springboot教程笔记ORM操作MySQL
#yyds干货盘点#数据分析从零开始实战,Pandas读写CSV数据
#yyds干货盘点#SpringBoot + MyBatis + MySQL 实现读写分离!