TinyDnn 源码阅读
Posted
tags:
篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了 TinyDnn 源码阅读相关的知识，希望对你有一定的参考价值。下文摘录的是 TinyDnn 卷积层（convolutional layer）反向传播函数 back_propagation 的实现。
/// Back-propagates the error of this convolutional layer for worker `index`.
///
/// Steps performed:
///   1. Propagate `curr_delta` through the kernels into the (possibly padded)
///      delta buffer of the previous layer, then scale it element-wise by the
///      derivative of the previous layer's activation evaluated at its
///      (padded) output.
///   2. Accumulate kernel gradients into this worker's `dW_`.
///   3. Accumulate bias gradients into this worker's `db_` (skipped when the
///      bias vector is empty).
///   4. With `padding::same`, copy the padded delta back into the unpadded
///      `ws.prev_delta_`.
/// The result is then handed to `prev_->back_propagation`.
///
/// @param curr_delta  delta w.r.t. this layer's output, addressed through
///                    `out_.get_index(x, y, channel)`.
/// @param index       worker index selecting the per-worker storage.
/// @return the reference returned by `prev_->back_propagation(ws.prev_delta_, index)`.
const vec_t& back_propagation(const vec_t& curr_delta, size_t index) override {
    auto& ws = this->get_worker_storage(index);
    conv_layer_worker_specific_storage& cws = conv_layer_worker_storage_[index];

    const vec_t& prev_out = *(cws.prev_out_padded_);
    const activation::function& prev_h = prev_->activation_function();
    // With same-padding the delta is built on the padded grid and unpadded at the end.
    vec_t* prev_delta = (pad_type_ == padding::same) ? &cws.prev_delta_padded_ : &ws.prev_delta_;
    vec_t& dW = ws.dW_;
    vec_t& db = ws.db_;

    std::fill(prev_delta->begin(), prev_delta->end(), float_t(0));

    // propagate delta to previous layer
    // Each task owns one input channel `inc`, so writes into prev_delta never overlap.
    for_i(parallelize_, in_.depth_, [&](int inc) {
        for (cnn_size_t outc = 0; outc < out_.depth_; outc++) {
            if (!tbl_.is_connected(outc, inc)) continue;

            const float_t *pw = &this->W_[weight_.get_index(0, 0, in_.depth_ * outc + inc)];
            const float_t *pdelta_src = &curr_delta[out_.get_index(0, 0, outc)];
            float_t *pdelta_dst = &(*prev_delta)[in_padded_.get_index(0, 0, inc)];

            for (cnn_size_t y = 0; y < out_.height_; y++) {
                for (cnn_size_t x = 0; x < out_.width_; x++) {
                    const float_t * ppw = pw;
                    const float_t ppdelta_src = pdelta_src[y * out_.width_ + x];
                    // Top-left input pixel of the receptive field of output (x, y).
                    float_t * ppdelta_dst = pdelta_dst + y * h_stride_ * in_padded_.width_ + x * w_stride_;

                    for (cnn_size_t wy = 0; wy < weight_.height_; wy++) {
                        for (cnn_size_t wx = 0; wx < weight_.width_; wx++) {
                            ppdelta_dst[wy * in_padded_.width_ + wx] += *ppw++ * ppdelta_src;
                        }
                    }
                }
            }
        }
    });

    // Chain rule through the previous layer's activation.
    for_i(parallelize_, in_padded_.size(), [&](int i) {
        (*prev_delta)[i] *= prev_h.df(prev_out[i]);
    });

    // accumulate dw
    // Each task owns one input channel `inc`, so each dW element is written by one task.
    for_i(parallelize_, in_.depth_, [&](int inc) {
        for (cnn_size_t outc = 0; outc < out_.depth_; outc++) {
            if (!tbl_.is_connected(outc, inc)) continue;

            for (cnn_size_t wy = 0; wy < weight_.height_; wy++) {
                for (cnn_size_t wx = 0; wx < weight_.width_; wx++) {
                    float_t dst = float_t(0);
                    // Input pixel seen by kernel tap (wx, wy) for output (0, 0).
                    const float_t * prevo = &prev_out[in_padded_.get_index(wx, wy, inc)];
                    const float_t * delta = &curr_delta[out_.get_index(0, 0, outc)];

                    for (cnn_size_t y = 0; y < out_.height_; y++) {
                        // Output row y reads input row y * h_stride_ + wy.
                        const float_t * prevo_row = prevo + y * h_stride_ * in_padded_.width_;
                        const float_t * delta_row = delta + y * out_.width_;
                        if (w_stride_ == 1) {
                            // Input pixels are contiguous: use the vectorized dot product.
                            dst += vectorize::dot(prevo_row, delta_row, out_.width_);
                        } else {
                            // Output column x reads input column x * w_stride_ + wx.
                            for (cnn_size_t x = 0; x < out_.width_; x++) {
                                dst += prevo_row[x * w_stride_] * delta_row[x];
                            }
                        }
                    }
                    dW[weight_.get_index(wx, wy, in_.depth_ * outc + inc)] += dst;
                }
            }
        }
    });

    // accumulate db
    if (!db.empty()) {
        for (cnn_size_t outc = 0; outc < out_.depth_; outc++) {
            const float_t *delta = &curr_delta[out_.get_index(0, 0, outc)];
            db[outc] += std::accumulate(delta, delta + out_.width_ * out_.height_, float_t(0));
        }
    }

    if (pad_type_ == padding::same)
        copy_and_unpad_delta(cws.prev_delta_padded_, ws.prev_delta_);

    CNN_LOG_VECTOR(curr_delta, "[pc]curr_delta");
    CNN_LOG_VECTOR(ws.prev_delta_, "[pc]prev_delta");
    CNN_LOG_VECTOR(dW, "[pc]dW");
    CNN_LOG_VECTOR(db, "[pc]db");

    return prev_->back_propagation(ws.prev_delta_, index);
}
以上是关于 TinyDnn 源码阅读的主要内容。如果未能解决你的问题，请参考以下文章。