为啥有时 CNN 模型只能预测所有其他类别中的一类?
Posted
技术标签:
【中文标题】为啥有时 CNN 模型只能预测所有其他类别中的一类?【英文标题】:Why do sometimes CNN models predict just one class out of all others?为什么有时 CNN 模型只能预测所有其他类别中的一类? 【发布时间】:2021-04-22 01:38:29 【问题描述】:我对深度学习领域比较陌生,所以请不要像 Reddit 那样刻薄!这似乎是一个普遍的问题,所以我不会在这里提供我的代码,因为它似乎没有必要(如果是,这里是 colab 的链接)
关于数据的一点:可以找到原始数据here。它是 82 GB 原始数据集的缩小版本。
一旦我在这方面训练了我的 CNN,它每次都预测“无糖尿病视网膜病变”(无 DR),准确率达到 73%。这是因为大量的 No DR 图像还是其他原因?我不知道!我预测的 5 个类是 ["Mild", "Moderate", "No DR", "Proliferative DR", "Severe"]
。
这可能只是糟糕的代码,希望你们能帮忙
【问题讨论】:
您可以在训练中删除一些 No DR 图像,并测试您的理论! 我确实尝试过这样做!虽然准确率下降到 35%,不到一半,但该模型至少可以预测其他类别,而且我不介意在 6 层 CNN 上的准确率如此之低;) 【参考方案1】:正如 Ivan 已经指出的那样,您有一个班级不平衡问题。这可以通过以下方式解决:
在线硬负挖掘:在计算损失后的每次迭代中,您可以对批次中属于“无DR”类的所有元素进行排序,只保留最差的k
。然后你只使用这些更差的 k 来估计梯度,然后丢弃所有其余的。
参见,例如:Abhinav Shrivastava、Abhinav Gupta 和 Ross Girshick Training Region-based Object Detectors with Online Hard Example Mining (CVPR 2016)
Focal loss: 对“vanilla”交叉熵损失的修改可用于解决类别不平衡问题。
相关帖子this 和this。
【讨论】:
哇,谢谢,我以前没听说过这些 @Divith 我想如果你用谷歌搜索“类不平衡”,你也会遇到其他选项。此列表绝不是全面的。 感谢您提供链接,很遗憾我不能投票两次!【参考方案2】:我正要评论:
更严格的方法是开始衡量您的数据集平衡:您拥有每个类别的多少张图片?这可能会回答您的问题。
但我忍不住看了你给的链接。 Kaggle 已经为您提供了数据集的概览:
快速计算:25,812 / 35,126 * 100 = 73%
。这很有趣,你说你的准确度是74%
。您的模型正在一个不平衡的数据集上学习,第一类被过度表示,25k/35k
是巨大的。我的假设是你的模型一直在预测第一类,这意味着平均而言你最终会得到74%
的准确度。
你应该做的是平衡你的数据集。例如,只允许来自第一类的35,126 - 25,810 = 9,316
示例在一个时期出现。更好的是,平衡所有类的数据集,使每个类在每个时期只出现 n 次。
【讨论】:
我试过这样做,我想我也可以在其他类中制作图像的旋转/翻转副本来扩展数据集,你认为这些可行吗? 这将被视为数据增强,因此我认为它不会弥补这一差距。增强图像并不完全是 new 图像。从这个意义上说,我认为第一堂课仍然会因为其多样化的内容而胜过其他班级。以上是关于为啥有时 CNN 模型只能预测所有其他类别中的一类?的主要内容,如果未能解决你的问题,请参考以下文章
分类预测 | MATLAB实现CNN-GRU-Attention多输入分类预测