如何在 TensorFlow 中对批次进行切片并在每个切片上应用操作

Posted

技术标签:

【中文标题】如何在 TensorFlow 中对批次进行切片并在每个切片上应用操作【英文标题】:How to slice a batch and apply an operation on each slice in TensorFlow 【发布时间】:2016-06-05 04:45:16 【问题描述】:

我是 TensorFlow 的初学者,我正在尝试实现一个将批处理作为输入的函数。它必须将此批次切成几份,对它们应用一些操作,然后将它们连接起来以构建一个新的张量以返回。通过阅读,我发现有一些已实现的功能,例如 input_slice_producer 和 batch_join,但我没有使用它们。我在下面附上了我发现的解决方案,但它有点慢,不合适并且无法检测当前的批次大小。有没有人知道更好的方法?

def model(x):

    W_1 = tf.Variable(tf.random_normal([6,1]),name="W_1")
    x_size = x.get_shape().as_list()[0]
    # x is a batch of bigger input of shape [None,6], so I couldn't 
    # get the proper size of the batch when feeding it 
    if x_size == None:
        x_size= batch_size
    #intialize the y_res
    dummy_x = tf.slice(x,[0,0],[1,6])
    result = tf.reduce_sum(tf.mul(dummy_x,W_1))
    y_res = tf.zeros([1], tf.float32)
    y_res = result
    #go throw all slices and concatenate them to get result
    for i in range(1,x_size): 
        dummy_x = tf.slice(x,[i,0],[1,6])
        result = tf.reduce_sum(tf.mul(dummy_x,W_1))
        y_res = tf.concat(0, [y_res, result])

    return y_res

【问题讨论】:

您可以将所有 y_res 累积在一个列表中并一次性创建最终张量,否则您将一遍又一遍地复制数据 谢谢,但我认为将 y_res 设为列表不会有太大变化,因为我将在每次迭代中进行“追加”。这与使用基本上与“追加”做同样事情的函数“concat”没有太大区别。我的挑战是:自动获取批处理的当前大小并在函数内部对计算进行多线程处理 【参考方案1】:

TensorFlow 函数 tf.map_fn(fn, elems) 允许您将函数 (fn) 应用于张量 (elems) 的每个切片。例如,您可以如下表达您的程序:

def model(x):
    W_1 = tf.Variable(tf.random_normal([6, 1]), name="W_1")

    def fn(x_slice):
        return tf.reduce_sum(x_slice, W_1)

    return tf.map_fn(fn, x)

还可以通过使用NumPy broadcasting semantics 的tf.mul() 运算符上的广播和tf.reduce_sum()axis 参数更简洁地实现您的操作。

【讨论】:

以上是关于如何在 TensorFlow 中对批次进行切片并在每个切片上应用操作的主要内容,如果未能解决你的问题,请参考以下文章

如何在 Tensorflow 中对无维度的张量进行切片

如何在 TensorFlow 中处理具有可变长度序列的批次?

如何在 Javascript 中对对象进行切片?

如何在 Bash 中对数组进行切片

如何在 Swift 中对字符串进行切片? [复制]

如何在 Excel VBA 中对数组进行切片?