使用faster rcnn 跑vot2015的数据集

Posted 贵有恒何必三更眠五更起,最无益只怕一日曝十日寒。

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了使用faster rcnn 跑vot2015的数据集相关的知识,希望对你有一定的参考价值。

本周老师给的任务:

一是将VOT15数据集(世华已传到服务器上)上每个序列的第1,11,21,31,41帧分别运行Faster R-CNN检测器并保存在图片上显示的检测结果;

二是将这5帧的ground truth bounding box作为proposal得到其对应的检测器分类结果(比如网络要检测20类物体,那包括背景就是得到21类对应的检测分数值),并将每个序列的检测结果分别存成一个文本文档。

 注意,使用代码的时候,可能会有路径错误,还可能是,我贴上的代码,博客园的网站给在某些语句后加了 <br> ,调错的时候细看!!我在后台竟然看不到<br>,但是浏览的时候却有!!

第一个问题已经解决,现在整理一下思路。

先将py faster rcnn 装好之后,测试运行dome.py能成功展示之后,再进行接下来的工作。

我的想法是,

(1)将vot2015数据集上的所有数据的分类统计出来(就是把vot2015下的子文件夹的名称统计出来,方便之后操作),这里直接用了( http://www.cnblogs.com/flyhigh1860/p/3896111.html )的源码进行修改

 

#!/usr/bin/python
# -*- coding:utf8 -*-

import os
allFileNum = 0

def printPath(level, path):
    global allFileNum
    \'\'\'\'\' 
    打印一个目录下的所有文件夹和文件 
    \'\'\'
    # 所有文件夹,第一个字段是次目录的级别
    dirList = []
    # 所有文件
    fileList = []
    # 返回一个列表,其中包含在目录条目的名称(google翻译)
    files = os.listdir(path)
    # 先添加目录级别
    dirList.append(str(level))
    for f in files:
        if (os.path.isdir(path + \'/\' + f)):
            # 排除隐藏文件夹。因为隐藏文件夹过多
            if (f[0] == \'.\'):
                pass
            else:
                # 添加非隐藏文件夹
                dirList.append(f)
        if (os.path.isfile(path + \'/\' + f)):
            # 添加文件
            fileList.append(f)
            # 当一个标志使用,文件夹列表第一个级别不打印
    i_dl = 0
  #得到的文件夹名保存在 save_file.txt 中,使用python的追加操作 ‘a’ save_file = open(\'/home/user/Downloads/save_file.txt\',\'a\') for dl in dirList: if (i_dl == 0): i_dl = i_dl + 1 else: # 打印至控制台,不是第一个的目录 print \'-\' * (int(dirList[0])), dl
       #将文件名写入save_file.txt中 save_file.write(dl) save_file.write(\'\\n\') # 打印目录下的所有文件夹和文件,目录级别+1 #printPath((int(dirList[0]) + 1), path + \'/\' + dl) for fl in fileList: # 打印文件 print \'-\' * (int(dirList[0])), fl # 随便计算一下有多少个文件 allFileNum = allFileNum + 1 if __name__ == \'__main__\': printPath(1, \'/home/user/Downloads/vot2015\') print \'总文件数 =\', allFileNum

 这里再给出save_file.txt 文件内容

soldier
butterfly
hand
car2
sheep
birds1
motocross1
marching
book
road
graduate
fish3
fernando
bag
wiper
gymnastics2
leaves
ball1
birds2
crossing
soccer1
godfather
nature
racing
traffic
pedestrian2
handball2
ball2
gymnastics1
singer2
singer1
dinosaur
gymnastics3
bolt1
gymnastics4
pedestrian1
helicopter
singer3
matrix
octopus
iceskater1
fish4
sphere
car1
motocross2
girl
fish1
bolt2
basketball
blanket
bmx
shaking
tiger
handball1
rabbit
fish2
tunnel
glove
iceskater2
soccer2

 

 

 (2)从save_file.txt 中将分来读取出来,保存再一个list中,之后将这段代码加到 demo.py 中使用(参考了  http://www.cnblogs.com/xuxn/archive/2011/07/27/read-a-file-with-python.html     和    http://www.cnblogs.com/mxh1099/p/5680001.html)

l = []

file = open(\'/home/user/Downloads/save_file.txt\')

while 1:
    line = file.readline()
    if line != \'\\n\':
        print line.replace("\\n", "")
     #在list中 加入去掉换行符的文件名 l.append(line.replace("\\n","")) if not line: break print l

 

 (3)需要将文件名和要遍历的每个文件夹下的文件名配合,同样,这段代码之后会用在demo.py 中

lfile = []

file = open(\'/home/user/Downloads/save_file.txt\')

while 1:
    line = file.readline()
    if line != \'\\n\':
        lfile.append(line.replace("\\n", ""))
    if not line:
        break
im_names =[\'00000023.jpg\',\'00000011.jpg\',\'00000001.jpg\']
    # im_names = [\'00000001.jpg\', \'000000011.jpg\', \'00000021.jpg\',
    #             \'00000031.jpg\', \'00000041.jpg\']

for litme in lfile :
    for im_name in im_names:
        im_path = str(litme) + \'/\' + str(im_name)
        print \'~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\'
        #print \'Demo for data/demo/{}\'.format(im_name)
        print im_path

 (4)可以对文件遍历之后,需要将生成的图片结果保存下来,参考了《演示如何实现Matplotlib绘图并保存图像但不显示图形的方法》(http://blog.csdn.net/rumswell/article/details/7342479) 和Python创建目录文件夹 (http://www.cnblogs.com/monsteryang/p/6574550.html)

 

最后附上我修改之后的demo.py

#!/usr/bin/env python

# --------------------------------------------------------
# Faster R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""
Demo script showing detections in sample images.

See README.md for installation instructions before running.
"""

import _init_paths
from fast_rcnn.config import cfg
from fast_rcnn.test import im_detect
from fast_rcnn.nms_wrapper import nms
from utils.timer import Timer
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
import caffe, os, sys, cv2
import argparse

#add
matplotlib.use(\'Agg\')

CLASSES = (\'__background__\',
           \'aeroplane\', \'bicycle\', \'bird\', \'boat\',
           \'bottle\', \'bus\', \'car\', \'cat\', \'chair\',
           \'cow\', \'diningtable\', \'dog\', \'horse\',
           \'motorbike\', \'person\', \'pottedplant\',
           \'sheep\', \'sofa\', \'train\', \'tvmonitor\')

NETS = {\'vgg16\': (\'VGG16\',
                  \'VGG16_faster_rcnn_final.caffemodel\'),
        \'zf\': (\'ZF\',
                  \'ZF_faster_rcnn_final.caffemodel\')}

#add
def mkdir(path):
    import os

    path = path.strip()
    path = path.rstrip("\\\\")

    isExists = os.path.exists(path)
    if not isExists:
        os.makedirs(path)
        print path + \'ok\'
        return True
    else:

        print path + \'failed!\'
        return False

def vis_detections(image_name, im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return

    im = im[:, :, (2, 1, 0)]
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect=\'equal\')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]

        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor=\'red\', linewidth=3.5)
            )
        ax.text(bbox[0], bbox[1] - 2,
                \'{:s} {:.3f}\'.format(class_name, score),
                bbox=dict(facecolor=\'blue\', alpha=0.5),
                fontsize=14, color=\'white\')

    ax.set_title((\'{} detections with \'
                  \'p({} | box) >= {:.1f}\').format(class_name, class_name,
                                                  thresh),
                  fontsize=14)
    plt.axis(\'off\')
    plt.tight_layout()
    plt.draw()
    #add
    ll = []
    ll = str(image_name).split(\'/\')
    print ll[0]

    mkdir(\'/home/user/tmp/\' + str(ll[0]))
    plt.savefig(\'/home/user/tmp/\' + str(image_name))

def demo(net, image_name):
    """Detect object classes in an image using pre-computed object proposals."""

    # Load the demo image
    im_file = os.path.join(cfg.DATA_DIR, \'demo\',\'vot2015\', image_name)
    print("%s", im_file)
    im = cv2.imread(im_file)

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    #add try except
    try:
        scores, boxes = im_detect(net, im)
        timer.toc()
        print (\'Detection took {:.3f}s for \'
               \'{:d} object proposals\').format(timer.total_time, boxes.shape[0])
        # Visualize detections for each class
        CONF_THRESH = 0.8
        NMS_THRESH = 0.3
        for cls_ind, cls in enumerate(CLASSES[1:]):
            cls_ind += 1 # because we skipped background
            cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]
            cls_scores = scores[:, cls_ind]
            dets = np.hstack((cls_boxes,
                              cls_scores[:, np.newaxis])).astype(np.float32)
            keep = nms(dets, NMS_THRESH)
            dets = dets[keep, :]
            vis_detections(image_name,im, cls, dets, thresh=CONF_THRESH)
    except Exception:
        print \'Error\'

def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description=\'Faster R-CNN demo\')
    parser.add_argument(\'--gpu\', dest=\'gpu_id\', help=\'GPU device id to use [0]\',
                        default=0, type=int)
    parser.add_argument(\'--cpu\', dest=\'cpu_mode\',
                        help=\'Use CPU mode (overrides --gpu)\',
                        action=\'store_true\')
    parser.add_argument(\'--net\', dest=\'demo_net\', help=\'Network to use [vgg16]\',
                        choices=NETS.keys(), default=\'vgg16\')

    args = parser.parse_args()

    return args

if __name__ == \'__main__\':
    cfg.TEST.HAS_RPN = True  # Use RPN for proposals

    args = parse_args()

    prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],
                            \'faster_rcnn_alt_opt\', \'faster_rcnn_test.pt\')
    caffemodel = os.path.join(cfg.DATA_DIR, \'faster_rcnn_models\',
                              NETS[args.demo_net][1])

    if not os.path.isfile(caffemodel):
        raise IOError((\'{:s} not found.\\nDid you run ./data/script/\'
                       \'fetch_faster_rcnn_models.sh?\').format(caffemodel))

    if args.cpu_mode:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu_id)
        cfg.GPU_ID = args.gpu_id
    net = caffe.Net(prototxt, caffemodel, caffe.TEST)

    print \'\\n\\nLoaded network {:s}\'.format(caffemodel)

    # Warmup on a dummy image
    im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
    for i in xrange(2):
        _, _= im_detect(net, im)

    # im_names = [\'000456.jpg\', \'000542.jpg\', \'001150.jpg\',
    #             \'001763.jpg\', \'004545.jpg\',\'00000023.jpg\',\'00000011.jpg\',\'00000001.jpg\']

    # edit
    lfile = []

    file = open(\'/home/user/Downloads/save_file.txt\')

    while 1:
        line = file.readline()
        if line != \'\\n\':
            lfile.append(line.replace("\\n", ""))
        if not line:
            break

    print lfile

    im_names = [\'00000001.jpg\', \'00000011.jpg\', \'00000021.jpg\',
                \'00000031.jpg\', \'00000041.jpg\']

    for litme in lfile :
        for im_name in im_names:
            im_path = str(litme) + \'/\' + str(im_name)
            print \'~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\'
            print \'Demo for data/demo/{}\'.format(im_name)
            try:
                demo(net, im_path)
            except Exception:
                print \'ERROR\'
    #plt.show()

 

第二个问题先看着,没想法

 在图片上显示每个IOU大于0.5的proposal对应的最高检测值的类别、分数和回归后的框,在文本文档里则保存每个proposal对应的21个类别的检测分数和回归后的边界框坐标。

对于每个类别,总会生成300个proposals,

所以,在每个proposal,都会有4个坐标

对于每个proposal,都会有一个类别值。

因为要生成每个proposal对应的21个类别的分数,就需要将分数先保存起来,再输出

还要记录回归后的边间框。

 

对于图片,显示每个IOU大于0.5的proposal对应的最高检测值的类别、分数和回归后的框。

也是先要将最高检测分数对应的类别和回归框记录下来。

以上是关于使用faster rcnn 跑vot2015的数据集的主要内容,如果未能解决你的问题,请参考以下文章

折腾faster-rcnn(三)--训练篇

Ubuntu下跑通py-faster-rcnn详解demo运作流程

tensorflow faster rann

faster-rcnn训练自己数据+测试

[图像算法]-Faster RCNN详解

如何在faster-rcnn上训练自己的数据集