PyTorch基础(15)-- torch.flatten()方法

Posted 奋斗丶

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了PyTorch基础(15)-- torch.flatten()方法相关的知识,希望对你有一定的参考价值。

前言

最近在复现论文中一个块的时候需要使用到torch.flatten()这个方法,这个方法其实很简单,但其中有一些细节可能需要注意,且有个关键点很容易忘记,故在此记录以备查阅。

方法解析

flatten的中文含义为“扁平化”,具体怎么理解呢?我们可以尝试这么理解,假设你的数据为1维数据,那么这个数据天然就已经扁平化了,如果是2维数据,那么扁平化就是将2维数据变为1维数据,如果是3维数据,那么就要根据你自己所选择的“扁平化程度”来进行操作,假设需要全部扁平化,那么就直接将3维数据变为1维数据,如果只需要部分扁平化,那么有一维的数据不会进行扁平操作,具体看下面的案例分析。

可以看到,torch.flatten()方法有三个参数,分别:

  • input tensor:该方法的输入
  • start_dim:开始flatten的维度
  • end_dim:结束flatten的维度

案例解析

  • 导包
import numpy as np 
import torch
  • 案例1 – 全部扁平化
x = np.arange(27)
x = np.reshape(x, (3,3,3))
x = torch.from_numpy(x)
print('before flatten', x)
x = torch.flatten(x)  # 默认扁平化程度为最高
print('after flatten', x)

  • 案例2 – 部分扁平化
x = np.arange(27)
x = np.reshape(x, (3, 3, 3))
x = torch.from_numpy(x)
print('before flatten', x)
x = torch.flatten(x, start_dim=0, end_dim=1)
print('after flatten', x)

  • 案例3 – 部分扁平化
x = np.arange(27)
x = np.reshape(x, (3, 3, 3))
x = torch.from_numpy(x)
print('before flatten', x)
print(x.shape)
x = torch.flatten(x, start_dim=1, end_dim=2)
print('after flatten', x)

以上是关于PyTorch基础(15)-- torch.flatten()方法的主要内容,如果未能解决你的问题,请参考以下文章

PyTorch基础(15)-- torch.flatten()方法

PyTorch基础(15)-- torch.flatten()方法

PyTorch基础(15)-- torch.flatten()方法

PyTorch基础(15)-- torch.flatten()方法

PyTorch基础教程15循环神经网络RNN(学不会来打我啊)

Pytorch学习笔记3.深度学习基础