如何以与 np.where 相同的方式使用 Tensorflow.where?

Posted

技术标签:

【中文标题】如何以与 np.where 相同的方式使用 Tensorflow.where?【英文标题】:How To Use Tensorflow.where in the same way as np.where? 【发布时间】:2021-10-13 12:05:28 【问题描述】:

我正在尝试制作一个计算 MSE 的自定义损失函数,但忽略所有真实值低于某个阈值(接近 0)的点。我可以通过以下方式使用 numpy 数组来实现这一点。

import numpy as np

a = np.random.normal(size=(4,4))
b = np.random.normal(size=(4,4))
temp_a = a[np.where(a>0.5)] # Your threshold condition
temp_b = b[np.where(a>0.5)]
mse = mean_squared_error(temp_a, temp_b)

但我不知道如何使用 keras 后端来做到这一点。我的自定义损失函数不起作用,因为 numpy 无法对张量进行操作。

def customMSE(y_true, y_pred):
    '''
    Correct predictions of 0 do not affect performance.
    '''
    y_true_ = y_true[tf.where(y_true>0.1)] # Your threshold condition
    y_pred_ = y_pred[tf.where(y_true>0.1)]
    mse = K.mean(K.square(y_pred_ - y_true_), axis=1)
    return mse

但是当我这样做时,我会返回错误

ValueError: Shape must be rank 1 but is rank 3 for 'node customMSE/strided_slice = StridedSlice[Index=DT_INT64, T=DT_FLOAT, begin_mask=0, ellipsis_mask=0, end_mask=0, new_axis_mask=0, shrink_axis_mask=1](cond_2/Identity_1, customMSE/strided_slice/stack, customMSE/strided_slice/stack_1, customMSE/strided_slice/Cast)' with input shapes: [?,?,?,?], [1,?,4], [1,?,4], [1].```

【问题讨论】:

Loss 函数将在图形模式下执行,numpy 函数在那里不可用。请改用tf.where (import tensorflow as tf)。 哦。在第一次调用 tf.where 时,我返回一个值错误 Shape must be rank 1 but is rank 3。不知道该怎么做。它与y_true[tf.where(y_true>01.)]@Kaveh 有关 我已经用 tf.where 完全替换了 np.where。那么我是否必须重塑输入张量,使用 tf 为 1D? @Kaveh 你想在自定义损失函数中做什么? @Kaveh 我想计算 MSE,但仅适用于真相不是 0 或接近 0 的预测。我想忽略这些。 【参考方案1】:

您可以在损失函数中使用tf.where 代替np.where

如果你想丢失对应的ground truth值低于阈值的预测,你可以编写如下自定义函数:

def my_loss_threshold(threshold):
    def my_loss(y_true,y_pred):
        # keep predictions pixels where their corresponding y_true is above a threshold
        y_pred = tf.gather_nd(y_pred, tf.where(y_true>=threshold))
        # keep image pixels where they're above a threshold
        y_true = tf.gather_nd(y_true, tf.where(y_true>=threshold))
        # compute MSE between filtered pixels
        loss = tf.square(y_true-y_pred)
        # return mean of losses
        return tf.reduce_mean(loss)
    return my_loss 

model.compile(loss=my_loss_threshold(threshold=0.1), optimizer="adam")

我将损失函数包装到另一个函数中,因此您可以将阈值作为超参数传递给模型编译。

【讨论】:

并不是我们要设置loss为0,我们根本不想去统计。它不应该进入损失的总和,它应该被简单地忽略。如果图像上有 100 个像素,并且有 10 个非零像素,则应仅对这 10 个像素求和。这样做吗? 如果你设置loss = tf.where(y_true&lt;0.1)就可以了。是否只有当 y_true 和 y_pred loss = tf.where(y_true<0.1 and y_pred<0.1) 抛出了一个错误。 @mmont 我编辑了我的答案,在某种程度上你可以忽略预测的丢失,它们对应的y_true 低于阈值。 这看起来是对的,但我不在我的电脑前。但是,如果我想根据两个条件过滤掉:1. y_true @mmont 如果你想得到一个准确的答案,你应该提供一个关于你想要什么的准确问题。在收到回复的任何时候更改问题都不是一个好主意。明确你想要做什么,然后提出你的问题。

以上是关于如何以与 np.where 相同的方式使用 Tensorflow.where?的主要内容,如果未能解决你的问题,请参考以下文章

如何以与 iPhoto 相同的方式订购 ALAssetsGroup

如何以与“TABLE name;”相同的方式格式化 postgres 结果做?

如何以与 VBScript 相同的方式解析命令行

如何以与 WhatsApp 相同的方式实施 APNS?

pandas比较两个dataframe特定数据列的数值是否相同并给出差值:使用np.where函数

如何以与 toad 相同的方式在 PL/SQL 中导出数据库对象?