faster-rcnn 之训练数据是如何准备的:imdb和roidb的产生

Posted sloanqin

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了faster-rcnn 之训练数据是如何准备的:imdb和roidb的产生相关的知识,希望对你有一定的参考价值。

【说明】:欢迎加入:faster-rcnn 交流群 238138700,关于imdb和roidb的生成都是在函数train_rpn的中,所以我们从这个函数开始,逐个跟进看如何执行得到我们需要的imdb和roidb:


def train_rpn(queue=None, imdb_name=None, init_model=None, solver=None,
              max_iters=None, cfg=None):
    """Train a Region Proposal Network in a separate training process.
    """

    # Not using any proposals, just ground-truth boxes
    cfg.TRAIN.HAS_RPN = True
    cfg.TRAIN.BBOX_REG = False  # applies only to Fast R-CNN bbox regression
    cfg.TRAIN.PROPOSAL_METHOD = 'gt'
    cfg.TRAIN.IMS_PER_BATCH = 1
    print 'Init model: '.format(init_model)
    print('Using config:')
    pprint.pprint(cfg)

    import caffe
    _init_caffe(cfg)

    roidb, imdb = get_roidb(imdb_name) # 调用函数,返回训练数据
    print 'roidb len: '.format(len(roidb))
    output_dir = get_output_dir(imdb)
    print 'Output will be saved to `:s`'.format(output_dir)

    model_paths = train_net(solver, roidb, output_dir, #传入数据roidb,供训练
                            pretrained_model=init_model,
                            max_iters=max_iters)
    # Cleanup all but the final model
    for i in model_paths[:-1]:
        os.remove(i)
    rpn_model_path = model_paths[-1]
    # Send final model path through the multiprocessing queue
    queue.put('model_path': rpn_model_path)

所以,进入get_roidb函数:

def get_roidb(imdb_name, rpn_file=None):
    imdb = get_imdb(imdb_name) # 调用该函数,返回imdb
    print 'Loaded dataset `:s` for training'.format(imdb.name)
    imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
    print 'Set proposal method: :s'.format(cfg.TRAIN.PROPOSAL_METHOD)
    if rpn_file is not None:
        imdb.config['rpn_file'] = rpn_file
    roidb = get_training_roidb(imdb) #利用imdb,产生roi_db
    return roidb, imdb

所以我们要先看imdb是如何产生的,然后看如何借助imdb产生roidb

def get_imdb(name):
    """Get an imdb (image database) by name."""
    if not __sets.has_key(name):
        raise KeyError('Unknown dataset: '.format(name))
    return __sets[name]()
从上面可见,get_imdb这个函数的实现原理:_sets是一个字典,字典的key是数据集的名称,字典的value是一个lambda表达式(即一个函数指针),
__sets[name]()
这句话实际上是调用函数,返回数据集imdb,下面看这个函数:
for year in ['2007', '2012']:
    for split in ['train', 'val', 'trainval', 'test']:
        name = 'voc__'.format(year, split)
        __sets[name] = (lambda split=split, year=year: pascal_voc(split, year))
所以可以看到,执行的实际上是pascal_voc函数,参数是split 和 year(ps:在train_vpn函数中,name是voc_2007_trainval,所以这里对应的split和year分别是trainval和2007);
很明显,pascal_voc是一个类,这是调用了该类的构造函数,返回的也是该类的一个实例,所以这下我们清楚了imdb实际上就是pascal_voc的一个实例;

那么我们来看这个类的构造函数是如何的,以及输入的图片数据在里面是如何组织的:

该类的构造函数如下:基本上就是设置了imdb的一些属性,比如图片的路径,图片名称的索引,并没有把真实的图片数据放进来

class pascal_voc(imdb):
    def __init__(self, image_set, year, devkit_path=None):
        imdb.__init__(self, 'voc_' + year + '_' + image_set)
        self._year = year # 设置年,2007
        self._image_set = image_set # trainval
        self._devkit_path = self._get_default_path() if devkit_path is None \\
                            else devkit_path # 数据集的路径'/home/sloan/py-faster-rcnn-master/data/VOCdevkit2007'
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year) # '/home/sloan/py-faster-rcnn-master/data/VOCdevkit2007/VOC2007'
        self._classes = ('__background__', # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor') # 21个类别
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes))) #给每个类别赋予一个对应的整数
        self._image_ext = '.jpg' # 图片的扩展名
        self._image_index = self._load_image_set_index() # 把所有图片的名称加载,放在list中,便于索引读取图片
        # Default to roidb handler
        self._roidb_handler = self.selective_search_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = 'cleanup'     : True,
                       'use_salt'    : True,
                       'use_diff'    : False,
                       'matlab_eval' : False,
                       'rpn_file'    : None,
                       'min_size'    : 2
        # 这两句就是检查前面的路径是否存在合法了,否则后面无法运行
        assert os.path.exists(self._devkit_path), \\
                'VOCdevkit path does not exist: '.format(self._devkit_path)
        assert os.path.exists(self._data_path), \\
                'Path does not exist: '.format(self._data_path)

那么有了imdb之后,roidb又有什么不同呢?为什么实际输入train_rpn的数据是roidb呢?

前面我们已经得到了imdb,但是imdb的成员roidb还是空白,啥都没有,那么roidb是如何生成的,其中又包含了哪些信息呢?

imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
上面调用的函数,为imdb添加了roidb的数据,我们看看如何添加的,见下面这个函数:

    def set_proposal_method(self, method):
        method = eval('self.' + method + '_roidb')
        self.roidb_handler = method
这里method传入的是一个str:gt,所以method=eval('self.gt_roidb')

那么关键就是eval函数做了什么操作???分析这个函数分析roidb中每个元素的具体hany
有了roidb后,后面的get_training_roidb(imdb)完成什么功能:将roidb中的元素由5011个,通过水平对称变成10022个;将index这个list的元素相应的也翻一番;

我们看看这个函数:

<span style="font-family: Arial, Helvetica, sans-serif; background-color: rgb(255, 255, 255);">函数如下:这个函数首先对imdb中涉及到的图像做了一个水平镜像,使得trainval中的5011张图片,变成了10022张图片;然后调用函数prepare_roidb函数准备数据(ps:我觉得作者这些函数的层层嵌套,又没做多大事情,实在是让结构不那么美观)</span>

def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training."""
    if cfg.TRAIN.USE_FLIPPED:
        print 'Appending horizontally-flipped training examples...'
        imdb.append_flipped_images()
        print 'done'

    print 'Preparing training data...'
    rdl_roidb.prepare_roidb(imdb)
    print 'done'
首先我们看看append_flipped_images函数:可以发现,roidb是imdb的一个成员变量,roidb是一个list(每个元素对应一张图片),list中的元素是一个字典,字典中存放了5个key,分别是boxes信息,每个box的class信息,是否是flipped的标志位,重叠信息gt_overlaps,以及seg_areas;分析该函数可知,将box的值按照水平对称,原先roidb中只有5011个元素,经过水平对称后通过append增加到5011*2=10022个;

    def append_flipped_images(self):
        num_images = self.num_images
        widths = self._get_widths()
        for i in xrange(num_images):
            boxes = self.roidb[i]['boxes'].copy()
            oldx1 = boxes[:, 0].copy()
            oldx2 = boxes[:, 2].copy()
            boxes[:, 0] = widths[i] - oldx2 - 1
            boxes[:, 2] = widths[i] - oldx1 - 1 # 新框的xmin和xmax都要更新
            assert (boxes[:, 2] >= boxes[:, 0]).all()
            entry = 'boxes' : boxes,
                     'gt_overlaps' : self.roidb[i]['gt_overlaps'],
                     'gt_classes' : self.roidb[i]['gt_classes'],
                     'flipped' : True
            self.roidb.append(entry) # 把这个新的框添加到roidb中
        self._image_index = self._image_index * 2 #将索引的list 复制拼接
然后就是prepare_roidb函数:

def prepare_roidb(imdb):
    """Enrich the imdb's roidb by adding some derived quantities that
    are useful for training. This function precomputes the maximum
    overlap, taken over ground-truth boxes, between each ROI and
    each ground-truth box. The class with maximum overlap is also
    recorded.
    """
    sizes = [PIL.Image.open(imdb.image_path_at(i)).size
             for i in xrange(imdb.num_images)]
    roidb = imdb.roidb
    for i in xrange(len(imdb.image_index)):
        roidb[i]['image'] = imdb.image_path_at(i)
        roidb[i]['width'] = sizes[i][0]
        roidb[i]['height'] = sizes[i][1]
        # need gt_overlaps as a dense array for argmax
        gt_overlaps = roidb[i]['gt_overlaps'].toarray()
        # max overlap with gt over classes (columns)
        max_overlaps = gt_overlaps.max(axis=1)
        # gt class that had the max overlap
        max_classes = gt_overlaps.argmax(axis=1)
        roidb[i]['max_classes'] = max_classes
        roidb[i]['max_overlaps'] = max_overlaps
        # sanity checks
        # max overlap of 0 => class should be zero (background)
        zero_inds = np.where(max_overlaps == 0)[0]
        assert all(max_classes[zero_inds] == 0)
        # max overlap > 0 => class should not be zero (must be a fg class)
        nonzero_inds = np.where(max_overlaps > 0)[0]
        assert all(max_classes[nonzero_inds] != 0)
============================================================================================================================

写到这里,我就想吐槽了,以为数据准备好了么,no,上面只是准备好了roidb的相关信息而已;

我表示这个作者搞的太麻烦了,结构不够扁平化,简单的事情用多个函数绕来绕去,受不了了;

真正的数据处理操作是在

class RoIDataLayer(caffe.Layer): 类的

    def forward(self, bottom, top):函数中开始的,这个类在faster-rcnn-root/lib/roi_data_layer/layer.py文件中

blobs = self._get_next_minibatch()这句话产生了我们需要的数据blobs;这个函数又调用了minibatch.py文件中的def get_minibatch(roidb, num_classes):函数;

然后又调用了def _get_image_blob(roidb, scale_inds):函数;在这个函数中,我们终于发现了cv2.imread函数,也就是最终的读取图片到内存的地方:

def _get_image_blob(roidb, scale_inds):
    """Builds an input blob from the images in the roidb at the specified
    scales.
    """
    num_images = len(roidb)
    processed_ims = []
    im_scales = []
    for i in xrange(num_images):
        im = cv2.imread(roidb[i]['image']) #终于在这里读取图片了
        if roidb[i]['flipped']:
            im = im[:, ::-1, :]
        target_size = cfg.TRAIN.SCALES[scale_inds[i]]
        im, im_scale = prep_im_for_blob(im, cfg.PIXEL_MEANS, target_size,
                                        cfg.TRAIN.MAX_SIZE)
        im_scales.append(im_scale)
        processed_ims.append(im)

    # Create a blob to hold the input images
    blob = im_list_to_blob(processed_ims)

    return blob, im_scales





作者:香蕉麦乐迪-sloanqin-覃元元








以上是关于faster-rcnn 之训练数据是如何准备的:imdb和roidb的产生的主要内容,如果未能解决你的问题,请参考以下文章

faster-rcnn训练自己数据+测试

faster-rcnn 之训练脚本解析:./tools/train_faster_rcnn_alt_opt.py

使用faster-rcnn.pytorch训练自己数据集

Faster-RCNN训练自己的数据集——备忘

如何在faster-rcnn上训练自己的数据集

faster-rcnn 之训练脚本解析:./tools/train_faster_rcnn_alt_opt.py