PyTorch基础(15)-- torch.flatten()方法
Posted 奋斗丶
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch基础(15)-- torch.flatten()方法相关的知识,希望对你有一定的参考价值。
前言
最近在复现论文中一个块的时候需要使用到torch.flatten()这个方法,这个方法其实很简单,但其中有一些细节可能需要注意,且有个关键点很容易忘记,故在此记录以备查阅。
方法解析
flatten的中文含义为“扁平化”,具体怎么理解呢?我们可以尝试这么理解,假设你的数据为1维数据,那么这个数据天然就已经扁平化了,如果是2维数据,那么扁平化就是将2维数据变为1维数据,如果是3维数据,那么就要根据你自己所选择的“扁平化程度”来进行操作,假设需要全部扁平化,那么就直接将3维数据变为1维数据,如果只需要部分扁平化,那么有一维的数据不会进行扁平操作,具体看下面的案例分析。
可以看到,torch.flatten()方法有三个参数,分别:
- input tensor:该方法的输入
- start_dim:开始flatten的维度
- end_dim:结束flatten的维度
案例解析
- 导包
import numpy as np
import torch
- 案例1 – 全部扁平化
x = np.arange(27)
x = np.reshape(x, (3,3,3))
x = torch.from_numpy(x)
print('before flatten', x)
x = torch.flatten(x) # 默认扁平化程度为最高
print('after flatten', x)
- 案例2 – 部分扁平化
x = np.arange(27)
x = np.reshape(x, (3, 3, 3))
x = torch.from_numpy(x)
print('before flatten', x)
x = torch.flatten(x, start_dim=0, end_dim=1)
print('after flatten', x)
- 案例3 – 部分扁平化
x = np.arange(27)
x = np.reshape(x, (3, 3, 3))
x = torch.from_numpy(x)
print('before flatten', x)
print(x.shape)
x = torch.flatten(x, start_dim=1, end_dim=2)
print('after flatten', x)
以上是关于PyTorch基础(15)-- torch.flatten()方法的主要内容,如果未能解决你的问题,请参考以下文章
PyTorch基础(15)-- torch.flatten()方法
PyTorch基础(15)-- torch.flatten()方法
PyTorch基础(15)-- torch.flatten()方法
PyTorch基础(15)-- torch.flatten()方法