pytorch使用总结
Posted ayanwan
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch使用总结相关的知识,希望对你有一定的参考价值。
torch.Tensor - 一个多维数组
autograd.Variable - 改变Tensor并且记录下来操作的历史记录。和Tensor拥有相同的API,以及backward()的一些API。同时包含着和张量相关的梯度。
nn.Module - 神经网络模块。便捷的数据封装,能够将运算移往GPU,还包括一些输入输出的东西。
nn.Parameter - 一种变量,当将任何值赋予Module时自动注册为一个参数。
autograd.Function - 实现了使用自动求导方法的前馈和后馈的定义。每个Variable的操作都会生成至少一个独立的Function节点,与生成了Variable的函数相连之后记录下操作历史。
1、Tensors与numpy之间转换
# 此处演示tensor和numpy数据结构的相互转换
a = torch.ones(5)
b = a.numpy()
# 此处演示当修改numpy数组之后,与之相关联的tensor也会相应的被修改
a.add_(1)
print(a)
print(b)
# 将numpy的Array转换为torch的Tensor
import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)
2、autograd.Variable 这是这个包中最核心的类。 它包装了一个Tensor,并且几乎支持所有的定义在其上的操作。一旦完成了你的运算,你可以调用 .backward()来自动计算出所有的梯度。
可以通过属性 .data 来访问原始的tensor,而关于这一Variable的梯度则集中于 .grad 属性中。
3、torch.nn 只接受小批量的数据
整个torch.nn包只接受那种小批量样本的数据,而非单个样本。 例如,nn.Conv2d能够结构一个四维的TensornSamples x nChannels x Height x Width。如果你拿的是单个样本,使用input.unsqueeze(0)来加一个假维度就可以了。
4、数据读入
通常来讲,当你处理图像,声音,文本,视频时需要使用python中其他独立的包来将他们转换为numpy中的数组,之后再转换为torch.*Tensor。
(1)图像的话,可以用Pillow, OpenCV。
(2)声音处理可以用scipy和librosa。
(3)文本的处理使用原生Python或者Cython以及NLTK和SpaCy都可以。
特别的对于图像,我们有torchvision这个包可用,其中包含了一些现成的数据集如:Imagenet, CIFAR10, MNIST等等。同时还有一些转换图像用的工具。 这非常的方便并且避免了写样板代码。
5、模型的保存与加载
torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存年整个神经网络的的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict()。
torch.save(net1, '7-net.pth') # 保存整个神经网络的结构和模型参数
torch.save(net1.state_dict(), '7-net_params.pth') # 只保存神经网络的模型参数
对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pth’)直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先导入对应的网络,通过net.load_state_dict(torch.load('.pth'))完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。
# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))
以上是关于pytorch使用总结的主要内容,如果未能解决你的问题,请参考以下文章
开发 | PyTorch vs. TensorFlow月度使用体验总结
pytorch使用CIFAR10数据集进行图片分类 + pytorch基本入门知识总结
pytorch使用CIFAR10数据集进行图片分类 + pytorch基本入门知识总结