如何在 TensorFlow 中选择 2D 张量的某些列?

Posted

技术标签:

【中文标题】如何在 TensorFlow 中选择 2D 张量的某些列?【英文标题】:How do I select certain columns of a 2D tensor in TensorFlow? 【发布时间】:2016-10-06 20:32:06 【问题描述】:

在this issue 中进行广义切片时,实现二维张量(矩阵)的运算收集列的最佳方法是什么?例如,对于张量t

1 2 3 4
5 6 7 8 

和索引 [1,3],我想得到:

2 4
6 8

相当于 numpy t[:, [1,3]]

【问题讨论】:

【参考方案1】:

同时gather 方法有一个axis 参数。

import tensorflow as tf
params = tf.constant([[1,2,3],[4,5,6]])
indices = [0,2]
op = tf.gather(params, indices, axis=1)

产生输出

[[1 3]
 [4 6]]

【讨论】:

【参考方案2】:

有一个名为tf.nn.embedding_lookup(params, ind) 的函数检索params 张量的

为了实现你想要的,我们可以先转置你想要从中选择某些列的张量t。然后查找tf.transpose(t) 的行(t 的列)。选择后,我们将结果转回。

import tensorflow as tf


t = tf.constant([[1, 2, 3], 
                 [4, 5, 6]])
ind = tf.constant([0, 2])

result = tf.transpose(tf.nn.embedding_lookup(tf.transpose(t), ind))

with tf.Session() as sess:
    print(sess.run(result))

【讨论】:

如果你想转置,为什么不直接使用gather呢?我虽然在 TF 中转置很昂贵。【参考方案3】:

到目前为止,我通过展平输入并使用 gather 创建了一个解决方法:

def gather_cols(params, indices, name=None):
    """Gather columns of a 2D tensor.

    Args:
        params: A 2D tensor.
        indices: A 1D tensor. Must be one of the following types: ``int32``, ``int64``.
        name: A name for the operation (optional).

    Returns:
        A 2D Tensor. Has the same type as ``params``.
    """
    with tf.op_scope([params, indices], name, "gather_cols") as scope:
        # Check input
        params = tf.convert_to_tensor(params, name="params")
        indices = tf.convert_to_tensor(indices, name="indices")
        try:
            params.get_shape().assert_has_rank(2)
        except ValueError:
            raise ValueError('\'params\' must be 2D.')
        try:
            indices.get_shape().assert_has_rank(1)
        except ValueError:
            raise ValueError('\'indices\' must be 1D.')

        # Define op
        p_shape = tf.shape(params)
        p_flat = tf.reshape(params, [-1])
        i_flat = tf.reshape(tf.reshape(tf.range(0, p_shape[0]) * p_shape[1],
                                       [-1, 1]) + indices, [-1])
        return tf.reshape(tf.gather(p_flat, i_flat),
                          [p_shape[0], -1])

为:

params = tf.constant([[1, 2, 3],
                      [4, 5, 6]])
indices = [0, 2]
op = gather_cols(params, indices)

产生预期的输出:

[[1 3]
 [4 6]]

【讨论】:

以上是关于如何在 TensorFlow 中选择 2D 张量的某些列?的主要内容,如果未能解决你的问题,请参考以下文章

在 Tensorflow.js 中获取张量中项目的值

如何使用 TensorFlow 连接两个具有不同形状的张量?

3d次2d矩阵Tensorflow

TensorFlow.js:那两个张量相等吗?

深度学习TensorFlow面试题:什么是TensorFlow?你对张量了解多少?TensorFlow有什么优势?TensorFlow比PyTorch有什么不同?该如何选择?

Tensorflow - 为图像张量中的每个像素查找最大3个相邻像素