TinyDnn 源码阅读

Posted

tags:

篇首语：本文由小常识网（cha138.com）小编为大家整理，主要介绍了 TinyDnn 源码阅读相关的知识，希望对你有一定的参考价值。

/**
 * Back-propagates through this convolutional layer for one worker slot.
 *
 * Steps, in order:
 *  1. Zero the previous-layer delta buffer, then scatter curr_delta through
 *     the kernel weights W_ into it for every connected (outc, inc) pair.
 *  2. Scale every entry of that buffer by prev_h.df(prev_out[i]), where
 *     prev_h is the activation function object of prev_.
 *  3. Accumulate (+=) the weight gradient into ws.dW_.
 *  4. Accumulate (+=) the bias gradient into ws.db_ when db_ is non-empty.
 *  5. For padding::same, copy the padded delta into ws.prev_delta_ via
 *     copy_and_unpad_delta.
 *
 * dW_ and db_ are only added to here; they are not reset by this method.
 *
 * @param curr_delta delta for this layer's output, addressed via out_.get_index
 * @param index      worker slot selecting the per-worker storage objects
 * @return the result of prev_->back_propagation(ws.prev_delta_, index)
 */
const vec_t& back_propagation(const vec_t& curr_delta, size_t index) override {
        auto& ws = this->get_worker_storage(index);
        conv_layer_worker_specific_storage& cws = conv_layer_worker_storage_[index];

        const vec_t& prev_out = *(cws.prev_out_padded_);
        const activation::function& prev_h = prev_->activation_function();
        // With padding::same the delta is first built in the padded buffer and
        // un-padded at the end; otherwise it is written to ws.prev_delta_ directly.
        vec_t* prev_delta = (pad_type_ == padding::same) ? &cws.prev_delta_padded_ : &ws.prev_delta_;
        vec_t& dW = ws.dW_;
        vec_t& db = ws.db_;

        std::fill(prev_delta->begin(), prev_delta->end(), float_t(0));

        // propagate delta to previous layer
        for_i(in_.depth_, [&](int inc) {
            for (cnn_size_t outc = 0; outc < out_.depth_; outc++) {
                if (!tbl_.is_connected(outc, inc)) continue;

                const float_t *pw = &this->W_[weight_.get_index(0, 0, in_.depth_ * outc + inc)];
                const float_t *pdelta_src = &curr_delta[out_.get_index(0, 0, outc)];
                float_t *pdelta_dst = &(*prev_delta)[in_padded_.get_index(0, 0, inc)];

                for (cnn_size_t y = 0; y < out_.height_; y++) {
                    for (cnn_size_t x = 0; x < out_.width_; x++) {
                        const float_t * ppw = pw;
                        const float_t ppdelta_src = pdelta_src[y * out_.width_ + x];
                        // Top-left corner of the input window that produced output (x, y).
                        float_t * ppdelta_dst = pdelta_dst + y * h_stride_ * in_padded_.width_ + x * w_stride_;

                        for (cnn_size_t wy = 0; wy < weight_.height_; wy++) {
                            for (cnn_size_t wx = 0; wx < weight_.width_; wx++) {
                                ppdelta_dst[wy * in_padded_.width_ + wx] += *ppw++ * ppdelta_src;
                            }
                        }
                    }
                }
            }
        });

        // multiply by the derivative of the previous layer's activation
        for_i(parallelize_, in_padded_.size(), [&](int i) {
            (*prev_delta)[i] *= prev_h.df(prev_out[i]);
        });

        // accumulate dw
        for_i(in_.depth_, [&](int inc) {
            for (cnn_size_t outc = 0; outc < out_.depth_; outc++) {

                if (!tbl_.is_connected(outc, inc)) continue;

                // Same for every kernel tap of this (outc, inc) pair.
                const float_t * delta = &curr_delta[out_.get_index(0, 0, outc)];

                for (cnn_size_t wy = 0; wy < weight_.height_; wy++) {
                    for (cnn_size_t wx = 0; wx < weight_.width_; wx++) {
                        float_t dst = float_t(0);
                        const float_t * prevo = &prev_out[in_padded_.get_index(wx, wy, inc)];

                        // Output (x, y) was computed from input
                        // (x * w_stride_ + wx, y * h_stride_ + wy), so the input must be
                        // sampled with the same strides used in the delta scatter above.
                        if (w_stride_ > 1) {
                            for (cnn_size_t y = 0; y < out_.height_; y++) {
                                const float_t * prevo_row = prevo + y * h_stride_ * in_padded_.width_;
                                const float_t * delta_row = delta + y * out_.width_;
                                for (cnn_size_t x = 0; x < out_.width_; x++) {
                                    dst += prevo_row[x * w_stride_] * delta_row[x];
                                }
                            }
                        } else {
                            // Unit column stride: input samples are contiguous, keep the dot product.
                            for (cnn_size_t y = 0; y < out_.height_; y++) {
                                dst += vectorize::dot(prevo + y * h_stride_ * in_padded_.width_, delta + y * out_.width_, out_.width_);
                            }
                        }
                        dW[weight_.get_index(wx, wy, in_.depth_ * outc + inc)] += dst;
                    }
                }
            }
        });

        // accumulate db
        if (!db.empty()) {
            for (cnn_size_t outc = 0; outc < out_.depth_; outc++) {
                const float_t *delta = &curr_delta[out_.get_index(0, 0, outc)];
                db[outc] += std::accumulate(delta, delta + out_.width_ * out_.height_, float_t(0));
            }
        }

        if (pad_type_ == padding::same)
            copy_and_unpad_delta(cws.prev_delta_padded_, ws.prev_delta_);

        CNN_LOG_VECTOR(curr_delta, "[pc]curr_delta");
        // Log the buffer that is actually handed to the previous layer.
        CNN_LOG_VECTOR(ws.prev_delta_, "[pc]prev_delta");
        CNN_LOG_VECTOR(dW, "[pc]dW");
        CNN_LOG_VECTOR(db, "[pc]db");

        return prev_->back_propagation(ws.prev_delta_, index);
    }

  

以上是关于 TinyDnn 源码阅读的主要内容，如果未能解决你的问题，请参考以下文章：

Python代码阅读(第26篇):将列表映射成字典

如何进行 Java 代码阅读分析?

Python代码阅读(第41篇):矩阵转置

Python代码阅读(第25篇):将多行字符串拆分成列表

JDK源码阅读之 HashMap

Python代码阅读(第13篇):检测列表中的元素是否都一样