pytorch使用总结

Posted ayanwan

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了pytorch使用总结相关的知识,希望对你有一定的参考价值。

torch.Tensor - 一个多维数组

autograd.Variable - 改变Tensor并且记录下来操作的历史记录。和Tensor拥有相同的API,以及backward()的一些API。同时包含着和张量相关的梯度。

nn.Module - 神经网络模块。便捷的数据封装,能够将运算移往GPU,还包括一些输入输出的东西。

nn.Parameter - 一种变量,当将任何值赋予Module时自动注册为一个参数。

autograd.Function - 实现了使用自动求导方法的前馈和后馈的定义。每个Variable的操作都会生成至少一个独立的Function节点,与生成了Variable的函数相连之后记录下操作历史。


1、Tensors与numpy之间转换

# 此处演示tensor和numpy数据结构的相互转换
a = torch.ones(5)
b = a.numpy()

# 此处演示当修改numpy数组之后,与之相关联的tensor也会相应的被修改
a.add_(1)
print(a)
print(b)

# 将numpy的Array转换为torch的Tensor
import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)

2、autograd.Variable 这是这个包中最核心的类。 它包装了一个Tensor,并且几乎支持所有的定义在其上的操作。一旦完成了你的运算,你可以调用 .backward()来自动计算出所有的梯度。

可以通过属性 .data 来访问原始的tensor,而关于这一Variable的梯度则集中于 .grad 属性中。


3、torch.nn 只接受小批量的数据

        整个torch.nn包只接受那种小批量样本的数据,而非单个样本。 例如,nn.Conv2d能够结构一个四维的TensornSamples x nChannels x Height x Width。如果你拿的是单个样本,使用input.unsqueeze(0)来加一个假维度就可以了。


4、数据读入

通常来讲,当你处理图像,声音,文本,视频时需要使用python中其他独立的包来将他们转换为numpy中的数组,之后再转换为torch.*Tensor。

(1)图像的话,可以用Pillow, OpenCV。

(2)声音处理可以用scipy和librosa。

(3)文本的处理使用原生Python或者Cython以及NLTK和SpaCy都可以。

特别的对于图像,我们有torchvision这个包可用,其中包含了一些现成的数据集如:Imagenet, CIFAR10, MNIST等等。同时还有一些转换图像用的工具。 这非常的方便并且避免了写样板代码。


5、模型的保存与加载

        torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存年整个神经网络的的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict()。

    torch.save(net1, '7-net.pth')                     # 保存整个神经网络的结构和模型参数  
    torch.save(net1.state_dict(), '7-net_params.pth') # 只保存神经网络的模型参数  
        对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pth’)直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先导入对应的网络,通过net.load_state_dict(torch.load('.pth'))完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。

# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))




以上是关于pytorch使用总结的主要内容,如果未能解决你的问题,请参考以下文章

开发 | PyTorch vs. TensorFlow月度使用体验总结

pytorch使用CIFAR10数据集进行图片分类 + pytorch基本入门知识总结

pytorch使用CIFAR10数据集进行图片分类 + pytorch基本入门知识总结

pytorch使用CIFAR10数据集进行图片分类 + pytorch基本入门知识总结

我是土堆 - Pytorch教程 知识点 学习总结笔记

Pytorch自动混合精度(AMP)的使用总结