如何收集未知的第一(批)维度的张量?
Posted
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了如何收集未知的第一(批)维度的张量?相关的知识,希望对你有一定的参考价值。
我有一个形状(?, 3, 2, 5)
的张量。我想提供成对的索引来从张量的第一维和第二维中进行选择,其形状为(3, 2)
。
如果我提供4对这样的对,我希望得到的形状是(?, 4, 5)
。我认为这就是batch_gather
的用途:在第一个(批量)维度上“广播”收集指数。但这不是它正在做的事情:
import tensorflow as tf
data = tf.placeholder(tf.float32, (None, 3, 2, 5))
indices = tf.constant([
[2, 1],
[2, 0],
[1, 1],
[0, 1]
], tf.int32)
tf.batch_gather(data, indices)
这导致<tf.Tensor 'Reshape_3:0' shape=(4, 2, 2, 5) dtype=float32>
而不是我期待的形状。
如果没有明确索引批次(具有未知大小),我该怎么做我想要的?
我想避免使用transpose
和Python循环,我认为这很有效。这是设置:
import numpy as np
import tensorflow as tf
shape = None, 3, 2, 5
data = tf.placeholder(tf.int32, shape)
idxs_list = [
[2, 1],
[2, 0],
[1, 1],
[0, 1]
]
idxs = tf.constant(idxs_list, tf.int32)
这允许我们收集结果:
batch_size, num_idxs, num_channels = tf.shape(data)[0], tf.shape(idxs)[0], shape[-1]
batch_idxs = tf.math.floordiv(tf.range(0, batch_size * num_idxs), num_idxs)[:, None]
nd_idxs = tf.concat([batch_idxs, tf.tile(idxs, (batch_size, 1))], axis=1)
gathered = tf.reshape(tf.gather_nd(data, nd_idxs), (batch_size, num_idxs, num_channels))
当我们运行批量大小的4
时,我们得到一个形状为(4, 4, 5)
的结果,即(batch_size, num_idxs, num_channels)
。
vals_shape = 4, *shape[1:]
vals = np.arange(int(np.prod(vals_shape))).reshape(vals_shape)
with tf.Session() as sess:
result = gathered.eval(feed_dict={data: vals})
与numpy
索引的关系:
x, y = zip(*idxs_list)
assert np.array_equal(result, vals[:, x, y])
基本上,gather_nd
想要第一维中的批量索引,并且对于每个索引对必须重复一次(即,如果有4个索引对,则为[0, 0, 0, 0, 1, 1, 1, 1, 2, ...]
)。
由于似乎没有tf.repeat
,我使用了range
和floordiv
,然后qazxs使用所需的(x,y)索引(它们本身平铺concat
次)的批次索引。
使用batch_size
,tf.batch_gather
形状的主要尺寸应与tensor
张量形状的前导尺寸相匹配。
indice
你最想要的是使用import tensorflow as tf
data = tf.placeholder(tf.float32, (2, 3, 2, 5))
print(data.shape) // (2, 3, 2, 5)
# shape of indices, [2, 3]
indices = tf.constant([
[1, 1, 1],
[0, 0, 1]
])
print(tf.batch_gather(data, indices).shape) # (2, 3, 2, 5)
# if shape of indice was (2, 3, 1) the output would be 2, 3, 1, 5
如下
tf.gather_nd
以上是关于如何收集未知的第一(批)维度的张量?的主要内容,如果未能解决你的问题,请参考以下文章