Caffe中im2col的实现解析

Posted jourluohua

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了Caffe中im2col的实现解析相关的知识,希望对你有一定的参考价值。

这里,我是将Caffe中im2col的解析过程直接拉了出来,使用C++进行了输出,方便理解。代码如下:

  1 #include<iostream>
  2 
  3 using namespace std;
  4 
  5 bool is_a_ge_zero_and_a_lt_b(int a,int b)
  6 {
  7     if(a>=0 && a <b) return true;
  8     return false;
  9 }
 10 
 11 void im2col_cpu(const float* data_im, const int channels,
 12     const int height, const int width, const int kernel_h, const int kernel_w,
 13     const int pad_h, const int pad_w,
 14     const int stride_h, const int stride_w,
 15     const int dilation_h, const int dilation_w,
 16     float* data_col) {
 17   const int output_h = (height + 2 * pad_h -
 18     (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
 19   const int output_w = (width + 2 * pad_w -
 20     (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
 21   const int channel_size = height * width;
 22   for (int channel = channels; channel--; data_im += channel_size) {
 23     for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
 24       for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
 25         int input_row = -pad_h + kernel_row * dilation_h;
 26         for (int output_rows = output_h; output_rows; output_rows--) {
 27           if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
 28             for (int output_cols = output_w; output_cols; output_cols--) {
 29               *(data_col++) = 0;
 30             }
 31           } else {
 32             int input_col = -pad_w + kernel_col * dilation_w;
 33             for (int output_col = output_w; output_col; output_col--) {
 34               if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
 35                 *(data_col++) = data_im[input_row * width + input_col];
 36               } else {
 37                 *(data_col++) = 0;
 38               }
 39               input_col += stride_w;
 40             }
 41           }
 42           input_row += stride_h;
 43         }
 44       }
 45     }
 46   }
 47 }
 48 
 49 
 50 int main()
 51 {
 52      float* data_im;
 53     int height=5;
 54     int width=5;   
 55     int kernel_h=3;   
 56     int kernel_w=3;
 57     int pad_h=1;   
 58     int pad_w=1;
 59     int stride_h=1;   
 60     int stride_w=1;
 61     int dilation_h=1;   
 62     int dilation_w=1;
 63     float* data_col;
 64     int channels =3;
 65     const int output_h = (height + 2 * pad_h -
 66     (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
 67       const int output_w = (width + 2 * pad_w -
 68     (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
 69     data_im = new float[channels*height*width];
 70     data_col = new float[channels*output_h*output_w*kernel_h*kernel_w];
 71     
 72     //init input image data
 73     for(int m=0;m<channels;++m)
 74     {
 75       for(int i=0;i<height;++i)
 76       {
 77         for(int j=0;j<width;++j)
 78         {
 79           data_im[m*width*height+i*width+j] = m*width*height+ i*width +j;
 80           cout <<data_im[m*width*height+i*width+j] << ;
 81         }
 82         cout <<endl;
 83       }
 84     }
 85     
 86     im2col_cpu(data_im, channels,
 87      height,width, kernel_h, kernel_w,
 88     pad_h, pad_w,
 89     stride_h, stride_w,
 90     dilation_h, dilation_w,
 91      data_col);
 92     cout <<channels<<endl;
 93     cout <<output_h<<endl;
 94     cout <<output_w<<endl;
 95     cout <<kernel_h<<endl;
 96     cout <<kernel_w<<endl;
 97    // cout <<"error"<<endl;
 98     for(int i=0;i<kernel_w*kernel_h*channels;++i)
 99     {    
100         for(int j=0;j<output_w*output_h;++j)
101         {
102             cout <<data_col[i*output_w*output_h+j]<< ;
103         }
104         cout <<endl;
105     }
106 
107     return 0;
108 }

多通道卷积的图像别人已经给过很多了,大家可以搜到的基本都来自于一篇。这里附上一个我自己的理解过程,和程序的输出是完全一致的

技术分享图片

 

以上是关于Caffe中im2col的实现解析的主要内容,如果未能解决你的问题,请参考以下文章

caffe运行错误: im2col.cu:61] Check failed: error == cudaSuccess (8 vs. 0) invalid device function

caffe源码 卷积层

深度学习框架之Caffe源码解析

使用Caffe进行手写数字识别执行流程解析

[转] caffe视觉层Vision Layers 及参数

如何使用 OpenCV 在 C++ 中实现高效的 im2col 函数?