关于 Theano 中的 flatten 函数的说明

Posted

技术标签:

【中文标题】关于 Theano 中的 flatten 函数的说明【英文标题】:Clarification about flatten function in Theano 【发布时间】:2016-04-15 01:52:33 【问题描述】:

在 [http://deeplearning.net/tutorial/lenet.html#lenet] 它说:

This will generate a matrix of shape (batch_size, nkerns[1] * 4 * 4),
# or (500, 50 * 4 * 4) = (500, 800) with the default values.
layer2_input = layer1.output.flatten(2)

当我在 numpy 3d 数组上使用 flatten 函数时,我得到一个 1D 数组。但在这里它说我得到了一个矩阵。在 theano 中 flatten(2) 是如何工作的?

numpy 上的一个类似示例生成一维数组:

     a= array([[[ 1,  2,  3],
    [ 4,  5,  6],
    [ 7,  8,  9]],

   [[10, 11, 12],
    [13, 14, 15],
    [16, 17, 18]],

   [[19, 20, 21],
    [22, 23, 24],
    [25, 26, 27]]])

   a.flatten(2)=array([ 1, 10, 19,  4, 13, 22,  7, 16, 25,  2, 11, 20,  5, 14, 23,  8, 17,
   26,  3, 12, 21,  6, 15, 24,  9, 18, 27])

【问题讨论】:

【参考方案1】:

numpy 不支持仅对某些维度进行展平,但 Theano 支持。

所以如果a 是一个numpy 数组,a.flatten(2) 没有任何意义。它运行没有错误,但只是因为 2 作为 order 参数传递,这似乎导致 numpy 坚持使用 C 的默认顺序。

Theano 的flatten 确实支持轴规范。 The documentation 解释了它是如何工作的。

Parameters:
    x (any TensorVariable (or compatible)) – variable to be flattened
    outdim (int) – the number of dimensions in the returned variable

Return type:
    variable with same dtype as x and outdim dimensions

Returns:
    variable with the same shape as x in the leading outdim-1 dimensions,
    but with all remaining dimensions of x collapsed into the last dimension.

例如,如果我们将形状为 (2, 3, 4, 5) 的张量用 flatten(x, outdim=2),那么我们将有相同的 (2-1=1) 前导 尺寸 (2,),其余尺寸已折叠。所以 此示例中的输出将具有形状 (2, 60)。

一个简单的 Theano 演示:

import numpy
import theano
import theano.tensor as tt


def compile():
    x = tt.tensor3()
    return theano.function([x], x.flatten(2))


def main():
    a = numpy.arange(2 * 3 * 4).reshape((2, 3, 4))
    f = compile()
    print a.shape, f(a).shape


main()

打印

(2L, 3L, 4L) (2L, 12L)

【讨论】:

代码出错了。我的你的代码版本(对于像我这样的新手):import numpy as np import theano import theano.tensor as tt from numpy import dtype x = tt.tensor4() f=theano.function([x], x.flatten(2)) a = np.asarray(np.arange(2 * 3 * 4*5).reshape((2, 3, 4, 5)),dtype=theano.config.floatX) print a print f(a) 错误是什么?如果您已经搜索过但无法找到现有答案,最好开始一个新问题。

以上是关于关于 Theano 中的 flatten 函数的说明的主要内容,如果未能解决你的问题,请参考以下文章

如果我们使用索引矩阵,我们是不是需要在 Theano 中使用 flatten 和 reshape ?

BigQuery 中的表函数和 FLATTEN

theano 中的扫描函数,循环神经网络

Theano入门笔记2:scan函数等

Theano中的导数

Groovy集合遍历 ( 集合中有集合元素时调用 flatten 函数拉平集合元素 | 代码示例 )