TensorFlow Object Detection API 中用于平衡数据的类权重
Posted
技术标签:
【中文标题】TensorFlow Object Detection API 中用于平衡数据的类权重【英文标题】:Class weights for balancing data in TensorFlow Object Detection API 【发布时间】:2019-01-22 14:01:33 【问题描述】:我正在使用Open Images Dataset 上的TensorFlow object detection API 微调SSD 对象检测器。我的训练数据包含不平衡的类,例如
-
顶部(5K 图像)
连衣裙(50K 图像)
等等……
我想在分类损失中添加类权重以提高性能。我怎么做?配置文件的以下部分似乎相关:
loss
classification_loss
weighted_sigmoid
localization_loss
weighted_smooth_l1
...
classification_weight: 1.0
localization_weight: 1.0
如何更改配置文件以添加每个类别的分类损失权重?如果不是通过配置文件,那么推荐的方法是什么?
【问题讨论】:
【参考方案1】:API 要求直接在 注释文件 中为每个对象 (bbox) 分配一个权重。由于这个要求,使用类权重的解决方案似乎是:
1) 如果您有自定义数据集,您可以修改每个对象 (bbox) 的注释以将权重字段包含为“对象/权重”。
2) 如果您不想修改注释,您可以重新创建仅 tf_records 文件以包含 bbox 的权重。
3)修改API的代码(在我看来相当棘手)
我决定选择#2,所以我将代码放在这里为具有两个类(“top”, "dress") 与权重 (1.0, 0.1) 给定的 xml 注释文件夹为:
import os
import io
import glob
import hashlib
import pandas as pd
import xml.etree.ElementTree as ET
import tensorflow as tf
import random
from PIL import Image
from object_detection.utils import dataset_util
# Define the class names and their weight
class_names = ['top', 'dress', ...]
class_weights = [1.0, 0.1, ...]
def create_example(xml_file):
tree = ET.parse(xml_file)
root = tree.getroot()
image_name = root.find('filename').text
image_path = root.find('path').text
file_name = image_name.encode('utf8')
size=root.find('size')
width = int(size[0].text)
height = int(size[1].text)
xmin = []
ymin = []
xmax = []
ymax = []
classes = []
classes_text = []
truncated = []
poses = []
difficult_obj = []
weights = [] # Important line
for member in root.findall('object'):
xmin.append(float(member[4][0].text) / width)
ymin.append(float(member[4][1].text) / height)
xmax.append(float(member[4][2].text) / width)
ymax.append(float(member[4][3].text) / height)
difficult_obj.append(0)
class_name = member[0].text
class_id = class_names.index(class_name)
weights.append(class_weights[class_id])
if class_name == 'top':
classes_text.append('top'.encode('utf8'))
classes.append(1)
elif class_name == 'dress':
classes_text.append('dress'.encode('utf8'))
classes.append(2)
else:
print('E: class not recognized!')
truncated.append(0)
poses.append('Unspecified'.encode('utf8'))
full_path = image_path
with tf.gfile.GFile(full_path, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = Image.open(encoded_jpg_io)
if image.format != 'JPEG':
raise ValueError('Image format not JPEG')
key = hashlib.sha256(encoded_jpg).hexdigest()
#create TFRecord Example
example = tf.train.Example(features=tf.train.Features(feature=
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/filename': dataset_util.bytes_feature(file_name),
'image/source_id': dataset_util.bytes_feature(file_name),
'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
'image/object/class/label': dataset_util.int64_list_feature(classes),
'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
'image/object/truncated': dataset_util.int64_list_feature(truncated),
'image/object/view': dataset_util.bytes_list_feature(poses),
'image/object/weight': dataset_util.float_list_feature(weights) # Important line
))
return example
def main(_):
weighted_tf_records_output = 'name_of_records_file.record' # output file
annotations_path = '/path/to/annotations/folder/*.xml' # input annotations
writer_train = tf.python_io.TFRecordWriter(weighted_tf_records_output)
filename_list=tf.train.match_filenames_once(annotations_path)
init = (tf.global_variables_initializer(), tf.local_variables_initializer())
sess=tf.Session()
sess.run(init)
list = sess.run(filename_list)
random.shuffle(list)
for xml_file in list:
print('-> Processing '.format(xml_file))
example = create_example(xml_file)
writer_train.write(example.SerializeToString())
writer_train.close()
print('-> Successfully converted dataset to TFRecord.')
if __name__ == '__main__':
tf.app.run()
如果您有其他类型的注释,代码将非常相似,但不幸的是,这个代码不起作用。
【讨论】:
这很有趣,我会考虑修改注释。我也在研究上采样/下采样少数/多数类以平衡数据。谢谢! 无论如何我们可以在 labelimg github.com/tzutalin/labelImg 中做第 1 点吗?【参考方案2】:对象检测 API 损失定义在:https://github.com/tensorflow/models/blob/master/research/object_detection/core/losses.py
特别是,已经实现了以下损失类:
分类损失:
-
WeightedSigmoidClassificationLoss
SigmoidFocalClassificationLoss
WeightedSoftmaxClassificationLoss
WeightedSoftmaxClassificationAgainstLogitsLoss
BootstrappedSigmoidClassificationLoss
本地化损失:
-
WeightedL2LocalizationLoss
WeightedSmoothL1LocalizationLoss
加权IOULocalizationLoss
权重参数用于平衡锚点(先前的框),大小为[batch_size, num_anchors]
,除了硬负挖掘。或者,focal loss 向下权衡分类良好的示例并专注于困难示例。
主要类别不平衡是由于与极少数正面示例(具有对象类的边界框)相比,更多负面示例(没有感兴趣对象的边界框)。这似乎是为什么正样本中的类不平衡(即正类标签的不均匀分布)没有作为对象检测损失的一部分来实现的原因。
【讨论】:
谢谢瓦迪姆。您的意思是说,如果我们使用模型/研究中提供的框架,对 TFRecord 示例设置权重将无济于事吗?在我对 TF 记录示例中的权重进行测试时,我没有看到任何改进。以上是关于TensorFlow Object Detection API 中用于平衡数据的类权重的主要内容,如果未能解决你的问题,请参考以下文章
如何安装 TensorFlow 2 和 object_detection 模块?
TensorFlow Object Detection API
TensorFlow object_detection 使用
TensorFlow object detection API