在Tensorflow中限制多类分类中的输出类

Posted

技术标签:

【中文标题】在Tensorflow中限制多类分类中的输出类【英文标题】:Restricting output classes in multi-class classification in Tensorflow 【发布时间】:2016-11-29 01:17:41 【问题描述】:

我正在构建一个双向 LSTM 来进行多类句子分类。 我总共有 13 个类可供选择,我将 LSTM 网络的输出乘以一个维度为 [2*num_hidden_unit,num_classes] 的矩阵,然后应用 softmax 来获得句子落入 13 个类中的一个的概率。

因此,如果我们将output[-1] 视为网络输出:

W_output = tf.Variable(tf.truncated_normal([2*num_hidden_unit,num_classes])) result = tf.matmul(output[-1],W_output) + bias

我得到了我的[1, 13] 矩阵(假设我暂时没有处理批次)。

现在,我还知道给定句子肯定不属于给定类别,我想限制给定句子考虑的类别数量。例如,假设对于一个给定的句子,我知道它只能分为 6 个类,所以输出应该是一个维度矩阵 [1,6]

我正在考虑的一个选项是在result 矩阵上放置一个掩码,在该矩阵中,我将对应于我想要保留的类的行乘以 1,我想要丢弃的行乘以 0,通过这种方式我只会丢失一些信息而不是重定向它。

有人知道在这种情况下该怎么办吗?

【问题讨论】:

【参考方案1】:

我认为你最好的选择是,正如你所描述的,使用加权交叉熵损失函数,其中“不可能的类”的权重为 0,而其他可能的类的权重为 1。 Tensorflow 具有加权交叉熵损失函数。

另一种有趣但可能不太有效的方法是提供您现在掌握的任何信息,说明您的句子在某个时间点(可能接近尾声)可以/不能落入网络中。

【讨论】:

以上是关于在Tensorflow中限制多类分类中的输出类的主要内容,如果未能解决你的问题,请参考以下文章

tensorflow 增强树分类器多类

为多类分类python tensorflow keras设置偏差

使用深度学习防止在多类分类中过度拟合特定类

基于图像和文本特征的 TensorFlow 训练模型,具有多类输出

多类图像分类中如何获取权重图

支持向量机多类分类方法及特点