在Tensorflow中限制多类分类中的输出类
Posted
技术标签:
【中文标题】在Tensorflow中限制多类分类中的输出类【英文标题】:Restricting output classes in multi-class classification in Tensorflow 【发布时间】:2016-11-29 01:17:41 【问题描述】:我正在构建一个双向 LSTM 来进行多类句子分类。
我总共有 13 个类可供选择,我将 LSTM 网络的输出乘以一个维度为 [2*num_hidden_unit,num_classes]
的矩阵,然后应用 softmax 来获得句子落入 13 个类中的一个的概率。
因此,如果我们将output[-1]
视为网络输出:
W_output = tf.Variable(tf.truncated_normal([2*num_hidden_unit,num_classes]))
result = tf.matmul(output[-1],W_output) + bias
我得到了我的[1, 13]
矩阵(假设我暂时没有处理批次)。
现在,我还知道给定句子肯定不属于给定类别,我想限制给定句子考虑的类别数量。例如,假设对于一个给定的句子,我知道它只能分为 6 个类,所以输出应该是一个维度矩阵 [1,6]
。
我正在考虑的一个选项是在result
矩阵上放置一个掩码,在该矩阵中,我将对应于我想要保留的类的行乘以 1,我想要丢弃的行乘以 0,通过这种方式我只会丢失一些信息而不是重定向它。
有人知道在这种情况下该怎么办吗?
【问题讨论】:
【参考方案1】:我认为你最好的选择是,正如你所描述的,使用加权交叉熵损失函数,其中“不可能的类”的权重为 0,而其他可能的类的权重为 1。 Tensorflow 具有加权交叉熵损失函数。
另一种有趣但可能不太有效的方法是提供您现在掌握的任何信息,说明您的句子在某个时间点(可能接近尾声)可以/不能落入网络中。
【讨论】:
以上是关于在Tensorflow中限制多类分类中的输出类的主要内容,如果未能解决你的问题,请参考以下文章
为多类分类python tensorflow keras设置偏差