smooth_L1_loss_layer.cu解读 caffe源码初认识
Posted 去做点事情
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了smooth_L1_loss_layer.cu解读 caffe源码初认识相关的知识,希望对你有一定的参考价值。
这是smooth_L1_loss_layer.cu的前向传播部分
#include "caffe/fast_rcnn_layers.hpp" namespace caffe { template <typename Dtype> __global__ void SmoothL1Forward(const int n, const Dtype* in, Dtype* out) { // f(x) = 0.5 * x^2 if |x| < 1 // |x| - 0.5 otherwise CUDA_KERNEL_LOOP(index, n) { Dtype val = in[index]; Dtype abs_val = abs(val); if (abs_val < 1) { out[index] = 0.5 * val * val; } else { out[index] = abs_val - 0.5; } } } template <typename Dtype> void SmoothL1LossLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) { int count = bottom[0]->count(); caffe_gpu_sub( count, bottom[0]->gpu_data(), bottom[1]->gpu_data(), diff_.mutable_gpu_data()); // d := b0 - b1 if (has_weights_) { caffe_gpu_mul( count, bottom[2]->gpu_data(), diff_.gpu_data(), diff_.mutable_gpu_data()); // d := w * (b0 - b1) } SmoothL1Forward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>( count, diff_.gpu_data(), errors_.mutable_gpu_data()); CUDA_POST_KERNEL_CHECK; Dtype loss; caffe_gpu_asum(count, errors_.gpu_data(), &loss); top[0]->mutable_cpu_data()[0] = loss / bottom[0]->num(); }
blob的主要变量:
shared_ptr<SyncedMemory> data_; shared_ptr<SyncedMemory> diff_; vector<int> shape_; int count_; int capacity_;
blob只是一个基本的数据结构,因此内部的变量相对较少,首先是data_
指针,指针类型是shared_ptr,属于boost库的一个智能指针,这一部分主要用来申请内存存储data,data主要是正向传播的时候用的。同理,diff_
主要用来存储偏差,shape_
都是存储Blob的形状,count
表示Blob中的元素个数,也就是个数*通道数*高度*宽度
,capacity
表示当前的元素个数,因为Blob可能会reshape。
blob类里面有重载很多个count()
函数,主要还是为了统计blob的容量(volume),或者是某一片(slice),从某个axis到具体某个axis的shape乘积。
inline int count(int start_axis, int end_axis)
int count = bottom[0]->count(); count()没带参数,计算的是bottom[0]这个输入blob所有的元素个数。这里就是计算一个迭代期的所有图片的所有通道的所有坐标点形成的blob数据结构元素的个数。
top[0]->mutable_cpu_data()[0] = loss / bottom[0]->num(); num()是计算一个迭代期参与的所有图片的个数。这里就是求一个迭代期所有几张图片的平均loss。
caffe_gpu_asum(count, errors_.gpu_data(), &loss); caffe_gpu_asum是对
以上是关于smooth_L1_loss_layer.cu解读 caffe源码初认识的主要内容,如果未能解决你的问题,请参考以下文章
Java中系统时间的获取_currentTimeMillis()函数应用解读