PyTorch教程 读写文件 #yyds干货盘点#

Posted LolitaAnn

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch教程 读写文件 #yyds干货盘点#相关的知识,希望对你有一定的参考价值。

读写文件有什么必要呢?

读写文件其实不是读取数据集。

是当你的训练时要定期存储中间结果,以确保在服务器电源不小心被断掉,或者出现其他情况的时候,损失掉你前几天的计算结果。

这一节要做的就是如何存储权重向量和整个模型。

import torch
from torch import nn
from torch.nn import functional as F

loadsave

对于单个张量,我们可以直接调用loadsave函数分别读写它们。

  • torch.saves

    torch.save(obj, f, pickle_module=<module pickle from .../pickle.py>, pickle_protocol=2)

    参数:

    • obj – 保存对象
    • f - 字符串,文件名
    • pickle_module – 用于pickling元数据和对象的模块
    • pickle_protocol – 指定pickle protocal 可以覆盖默认参数
  • torch.load

    torch.load(f, map_location=None, pickle_module=<module pickle from .../pickle.py>)

    从磁盘文件中读取一个通过torch.save()保存的对象。

    参数:

    • f – 字符串,文件名
    • map_location – 一个函数或字典规定如何remap存储位置
    • pickle_module – 用于unpickling元数据和对象的模块 (必须匹配序列化文件时的pickle_module )
x = torch.arange(4)
torch.save(x, x-file)
x2 = torch.load(x-file)
print(x2)

初始化一个x

将x存储到当前文件夹下并命名为x-file,此时你会发现当前文件夹下边多出来一个同名的文件。
当然打开之后可能不是 0 1 2 3 ,因为编码方式不同,所以不用纠结打开以后看到的是什么。

声明一个x2再从文件中读回来,会发现结果就是tensor([0, 1, 2, 3]),结果没错就可以了。

y = torch.zeros(4)
torch.save([x, y],x-file)
x2, y2 = torch.load(x-file)
print(x2, y2)
>>
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

存储一个张量列表,然后把它们读回内存。

y = torch.zeros(4)
torch.save(y[:2],x-file)
x2, y2 = torch.load(x-file)
(x2, y2)
print(x2, y2)

切片也是可以的。

mydict = {x: x, y: y}
torch.save(mydict, x-file)
mydict2 = torch.load(x-file)
mydict2
>>
{x: tensor([0, 1, 2, 3]), y: tensor([0., 0., 0., 0.])}

存储字典也可以。

加载和保存模型参数

深度学习框架提供了内置函数来保存和加载整个网络。

但是是保存模型的参数而不是保存整个模型。

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

还记得我们手写的多层感知机吗,用这个实现一下子。

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

现在生成一个net,用它计算X,并将其赋值给Y。

torch.save(net.state_dict(), x-file)

将net的参数保存起来。

net_ = MLP()
net_.load_state_dict(torch.load(x-file))
net_.eval()

生成一个net_也是多层感知机,net的参数直接加载文件中的参数。

net_.eval()是将模型的模式改为评价模式。

Y_clone = net_(X)
print(Y_clone == Y)
>>
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

将新网络赋值给Y_clone,可以看到Y_clone和Y是相同的。

当然换成pytorch自己的层也是可以的


MLP = nn.Sequential(nn.Linear(20,256),nn.Linear(256,10),nn.ReLU())

def init(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight)
        nn.init.zeros_(m.bias)

net = MLP
X = torch.randn(size=(2, 20))
Y = net(X)

torch.save(net.state_dict(), x-file)

net_ = MLP
net_.load_state_dict(torch.load(x-file))
net_.eval()

Y_clone = net_(X)
Y_clone == Y
>>
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])

以上是关于PyTorch教程 读写文件 #yyds干货盘点#的主要内容,如果未能解决你的问题,请参考以下文章

#yyds干货盘点#动力节点王鹤Springboot教程笔记ORM操作MySQL

#yyds干货盘点#数据分析从零开始实战,Pandas读写CSV数据

# yyds干货盘点 # Pandas入门教程

#yyds干货盘点#SpringBoot + MyBatis + MySQL 实现读写分离!

#yyds干货盘点# 滴滴二面:Kafka是如何读写副本消息的?

#yyds干货盘点#windows server 2012 R2磁盘数据组织