TensorFlow: tf.argmax and slicing

Posted: 2018-01-31 20:07:17

Question:

I want to build this loss function:

sum((y[argmax(y_)] - y_[argmax(y_)])²)
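For reference, here is a minimal NumPy sketch of the intended loss. It assumes y and y_ are both [batch, Na] arrays and that argmax(y_) is taken per row, which is how the answer below interprets it; the helper name reference_loss and the sample values are hypothetical. NumPy's integer-array indexing y[rows, k] picks exactly one element per row, which is the operation the question is trying to express in TensorFlow.

```python
import numpy as np

def reference_loss(y, y_):
    """Sum over rows i of (y[i, k_i] - y_[i, k_i])**2, where k_i = argmax_j y_[i, j]."""
    k = np.argmax(y_, axis=1)            # column of the max in each row of y_
    rows = np.arange(y_.shape[0])        # row indices 0 .. batch-1
    diff = y[rows, k] - y_[rows, k]      # one element per row (integer-array indexing)
    return np.sum(diff ** 2)

# Hypothetical values: row 0 picks column 1, row 1 picks column 0.
y = np.array([[0.1, 0.5, 0.2],
              [0.3, 0.0, 0.9]])
y_ = np.array([[0.0, 1.0, 0.0],
               [1.0, 0.0, 0.0]])
print(reference_loss(y, y_))
```

Note that y[k] alone would select whole rows, and y[:, k] would build a [batch, batch] matrix; pairing each row index with its own column index is what makes the result one value per example.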

I can't find a way to express y[argmax(y_)]. I tried y[k], y[:,k] and y[None,k], but none of them work. Here is my code:

    import tensorflow as tf

    Na = 3
    x = tf.placeholder(tf.float32, [None, 2])
    W = tf.Variable(tf.zeros([2, Na]))
    b = tf.Variable(tf.zeros([Na]))
    y = tf.nn.relu(tf.matmul(x, W) + b)
    y_ = tf.placeholder(tf.float32, [None, 3])
    k = tf.argmax(y_, 1)
    diff = y[k] - y_[k]  # this line raises the ValueError below
    loss = tf.reduce_sum(tf.square(diff))

And the error:

  File "/home/ncarrara/phd/code/cython/robotnavigation/ftq/cftq19.py", line 156, in <module>
    diff = y[k] - y_[k]
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 499, in _SliceHelper
    name=name)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 663, in strided_slice
    shrink_axis_mask=shrink_axis_mask)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3515, in strided_slice
    shrink_axis_mask=shrink_axis_mask, name=name)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2508, in create_op
    set_shapes_for_outputs(ret)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1873, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1823, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)
  File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 676, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice' (op: 'StridedSlice') with input shapes: [?,3], [1,?], [1,?], [1].

Answer 1:

This can be done with tf.gather_nd:

    import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)

    Na = 3
    x = tf.placeholder(tf.float32, [None, 2])
    W = tf.Variable(tf.zeros([2, Na]))
    b = tf.Variable(tf.zeros([Na]))
    y = tf.nn.relu(tf.matmul(x, W) + b)
    y_ = tf.placeholder(tf.float32, [None, 3])
    k = tf.argmax(y_, 1)  # column of the max in each row of y_ (int64)
    # Make index tensor with row and column indices
    num_examples = tf.cast(tf.shape(x)[0], dtype=k.dtype)  # cast so it matches k's dtype
    idx = tf.stack([tf.range(num_examples), k], axis=-1)  # shape [batch, 2]: (row, col) pairs
    diff = tf.gather_nd(y, idx) - tf.gather_nd(y_, idx)
    loss = tf.reduce_sum(tf.square(diff))
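The code above uses the TensorFlow 1.x graph API (tf.placeholder is not part of the TF 2 API). A minimal sketch of the same loss under TensorFlow 2.x eager execution follows; the function names argmax_loss and argmax_loss_batch_gather and the sample values are illustrative assumptions. The second variant uses tf.gather with batch_dims=1, which pairs each row with its own column index so the (row, col) matrix does not have to be built by hand; both should give the same value.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

def argmax_loss(y, y_):
    """Sum over rows of (y[i, k_i] - y_[i, k_i])**2, where k_i = argmax_j y_[i, j]."""
    k = tf.argmax(y_, axis=1)                           # shape [batch], dtype int64
    rows = tf.range(tf.shape(y_, out_type=k.dtype)[0])  # 0 .. batch-1, same dtype as k
    idx = tf.stack([rows, k], axis=-1)                  # shape [batch, 2]: (row, col) pairs
    diff = tf.gather_nd(y, idx) - tf.gather_nd(y_, idx)
    return tf.reduce_sum(tf.square(diff))

def argmax_loss_batch_gather(y, y_):
    """Same loss, letting tf.gather pick column k[i] from row i via batch_dims=1."""
    k = tf.argmax(y_, axis=1)
    diff = tf.gather(y, k, batch_dims=1) - tf.gather(y_, k, batch_dims=1)
    return tf.reduce_sum(tf.square(diff))

# Illustrative values: row 0 picks column 1, row 1 picks column 0.
y = tf.constant([[0.1, 0.5, 0.2],
                 [0.3, 0.0, 0.9]])
y_ = tf.constant([[0.0, 1.0, 0.0],
                  [1.0, 0.0, 0.0]])
print(float(argmax_loss(y, y_)), float(argmax_loss_batch_gather(y, y_)))
```

The argmax is only taken over the targets y_, and tf.gather_nd / tf.gather are differentiable with respect to y, so either version can be used as a training loss.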

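Explanation:

Here, the idea of tf.gather_nd is to build a matrix (2-D tensor) in which each row holds the row index and column index of one element to pick from the input. For example, if I have a matrix a containing: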

| 1 2 3 |
| 4 5 6 |
| 7 8 9 |

and a matrix i containing:

| 1 2 |
| 0 1 |
| 2 2 |
| 1 0 |

then tf.gather_nd(a, i) returns a vector (1-D tensor) containing:

| 6 |
| 2 |
| 9 |
| 4 |
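The example above can be checked directly. A minimal sketch, assuming TensorFlow 2.x so the result can be printed eagerly:

```python
import tensorflow as tf

# The 3x3 matrix a and the index matrix i from the example above.
a = tf.constant([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
i = tf.constant([[1, 2],
                 [0, 1],
                 [2, 2],
                 [1, 0]])
# Each row of i is a (row, column) pair, so the result has one element per row of i.
result = tf.gather_nd(a, i)
print(result.numpy())  # [6 2 9 4]
```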

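In this case, the column indices are given by k, the output of tf.argmax: for each row, it tells you which column holds the highest value. You now only need to pair each of them with its row index. The first element of k is the column of the maximum in row 0, the next element is the one for row 1, and so on. num_examples is simply the number of rows in x, and tf.range(num_examples) gives you a vector from 0 to the number of rows in x minus 1 (that is, the sequence of all row indices). The cast to k.dtype is needed because tf.shape returns int32 while tf.argmax returns int64 by default, and tf.stack requires both inputs to have the same dtype. Finally, tf.stack puts the row indices and k side by side, and the resulting idx is the argument to tf.gather_nd.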
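The index construction described above can be traced on a small concrete input. This is a minimal sketch assuming TensorFlow 2.x eager execution; the 3x3 target matrix y_ is made up for illustration, and the row count is taken from y_ instead of x so the snippet is self-contained.

```python
import tensorflow as tf

# Hypothetical targets: each row has its maximum in a different column.
y_ = tf.constant([[0.0, 1.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.2, 0.3, 0.5]])
k = tf.argmax(y_, axis=1)                          # column of the max per row: [1 0 2]
num_examples = tf.cast(tf.shape(y_)[0], k.dtype)   # 3, cast to int64 to match k
rows = tf.range(num_examples)                      # [0 1 2]
idx = tf.stack([rows, k], axis=-1)                 # one (row, col) pair per row
print(idx.numpy().tolist())                        # [[0, 1], [1, 0], [2, 2]]
print(tf.gather_nd(y_, idx).numpy().tolist())      # [1.0, 1.0, 0.5]
```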

Comments:

Looks good, but I'm not sure yet whether I can validate your answer. Thanks anyway!
@nicolascarrara I've added some explanation.
