2.3 Tensor类型
Posted 王小小小草
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了2.3 Tensor类型相关的知识,希望对你有一定的参考价值。
欢迎订阅本专栏:《PyTorch深度学习实践》
订阅地址:https://blog.csdn.net/sinat_33761963/category_9720080.html
- 第二章:认识Tensor的类型、创建、存储、api等,打好Tensor的基础,是进行PyTorch深度学习实践的重中之重的基础。
- 第三章:学习PyTorch如何读入各种外部数据
- 第四章:利用PyTorch从头到尾创建、训练、评估一个模型,理解与熟悉PyTorch实现模型的每个步骤,用到的模块与方法。
- 第五章:学习如何利用PyTorch提供的3种方法去创建各种模型结构。
- 第六章:利用PyTorch实现简单与经典的模型全过程:简单二分类、手写字体识别、词向量的实现、自编码器实现。
- 第七章利用PyTorch实现复杂模型:翻译机(nlp领域)、生成对抗网络(GAN)、强化学习(RL)、风格迁移(cv领域)。
- 第八章:PyTorch的其他高级用法:模型在不同框架之间的迁移、可视化、多个GPU并行计算。
类型列表
知道了创建Tensor的各种方法,现在来看看Tensor有什么数据类型,下表是官网中给出的信息,在CPU和GPU上各有9种类型。这些类型是特地和NumPy的参数名称一致的,以方便大家认知。
在tensor的类型,我们常常会用到以下这些操作:
(1)创建Tensor时用参数指明数据类型
import torch
double_points = torch.ones((10, 2), dtype=torch.double)
short_points = torch.tensor([[1,2],[3,4]], dtype=torch.short)
(2)获取tensor的数据类型
short_points.dtype
torch.int16
(3)转换tensor的数据类型
# (1)直接在tensor后面接.dtype()进行转换
double_points = torch.zeros(10,2).double()
# (2)使用to进行转换
double_points = torch.zeros(10,2).to(torch.double)
# (3)使用type()进行转换
double_points = torch.zeros(10,2).type(torch.short)
(4)设置/获取默认Tensor类型
# 指定
torch.set_default_tensor_type(torch.double)
# 获取
torch.get_default_tensor_type()
以上是关于2.3 Tensor类型的主要内容,如果未能解决你的问题,请参考以下文章