caffe Solve函数
Posted Exploring...
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了caffe Solve函数相关的知识,希望对你有一定的参考价值。
下面来看Solver<Dtype>::Solve(const char* resume_file)
solver.cpp
/**
 * Top-level training entry point.
 *
 * Optionally restores solver state from `resume_file`, runs the remaining
 * iterations through Step(), writes a final snapshot when configured to, and
 * then performs one last display / test pass.
 *
 * @param resume_file Path of a saved solver state to resume from; when null,
 *                    no state is restored.
 */
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
  CHECK(Caffe::root_solver());
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();

  // The early-exit flag is reset at the start of every Solve() call.
  requested_early_exit_ = false;

  if (resume_file) {
    LOG(INFO) << "Restoring previous solver status from " << resume_file;
    // Resume training from the state saved when a previous run was
    // interrupted.
    Restore(resume_file);
  }

  // Original note: for a network that is trained by the solver, no bottom or
  // top vecs should be given, and we will just provide dummy vecs.
  int first_iter = iter_;

  // The main iteration loop lives in Step().
  Step(param_.max_iter() - iter_);

  // Save a snapshot after optimization unless snapshot_after_train is false,
  // skipping it when iter_ is a multiple of the snapshot interval.
  if (param_.snapshot_after_train()) {
    const bool on_snapshot_interval =
        param_.snapshot() && iter_ % param_.snapshot() == 0;
    if (!on_snapshot_interval) {
      Snapshot();
    }
  }

  if (requested_early_exit_) {
    LOG(INFO) << "Optimization stopped early.";
    return;
  }

  // After optimization, run one more train and test pass to report the
  // loss/outputs when the display and test_interval settings call for it.
  // For the train net only a forward pass is run here: this pass exists
  // purely to display the loss, which is computed in the forward pass.
  const bool show_final_loss =
      param_.display() && iter_ % param_.display() == 0;
  if (show_final_loss) {
    int avg_loss_setting = this->param_.average_loss();
    Dtype final_loss;
    net_->Forward(&final_loss);
    UpdateSmoothedLoss(final_loss, first_iter, avg_loss_setting);
    LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
  }

  const bool run_final_test =
      param_.test_interval() && iter_ % param_.test_interval() == 0;
  if (run_final_test) {
    TestAll();
  }

  LOG(INFO) << "Optimization Done.";
}
下面先看Solve中的Restore(resume_file)
solver.cpp
/**
 * Restores solver state from `state_file`, choosing the loader by file name.
 *
 * A name that ends in ".h5" is handed to RestoreSolverStateFromHDF5(); every
 * other name is handed to RestoreSolverStateFromBinaryProto().
 *
 * @param state_file Path of the saved solver state.
 */
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
  string path(state_file);
  const string hdf5_suffix(".h5");

  // The length guard keeps the compare() offset from underflowing on names
  // shorter than the suffix.
  const bool is_hdf5 =
      path.size() >= hdf5_suffix.size() &&
      path.compare(path.size() - hdf5_suffix.size(), hdf5_suffix.size(),
                   hdf5_suffix) == 0;

  if (!is_hdf5) {
    RestoreSolverStateFromBinaryProto(path);
    return;
  }
  RestoreSolverStateFromHDF5(path);
}
上面的RestoreSolverStateFromHDF5(state_filename)和RestoreSolverStateFromBinaryProto(state_filename)都是虚函数,实际调用的是派生类中的同名方法。例如,若使用SGD求解,则调用的是SGDSolver类的RestoreSolverStateFromBinaryProto方法,其实现如下:
sgd_solver.cpp
/**
 * Restores SGD solver state from a binary-proto SolverState file.
 *
 * Sets the iteration counter and current step from the saved state, copies
 * trained layers when the state names a learned net, and refills every
 * history blob from the saved history.
 *
 * @param state_file Path of the serialized SolverState.
 */
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
    const string& state_file) {
  SolverState saved;
  // NOTE(review): any status returned by ReadProtoFromBinaryFile is not
  // checked here -- confirm whether a failed read should be fatal.
  ReadProtoFromBinaryFile(state_file, &saved);

  // Pick up the iteration count at which the previous run was interrupted.
  this->iter_ = saved.iter();

  if (saved.has_learned_net()) {
    NetParameter learned_weights;
    ReadNetParamsFromBinaryFileOrDie(saved.learned_net().c_str(),
                                     &learned_weights);
    this->net_->CopyTrainedLayersFrom(learned_weights);
  }

  this->current_step_ = saved.current_step();

  // The saved history must line up one-to-one with this solver's blobs.
  CHECK_EQ(saved.history_size(), history_.size())
      << "Incorrect length of history blobs.";
  LOG(INFO) << "SGDSolver: restoring history";

  for (int blob_idx = 0; blob_idx < history_.size(); ++blob_idx) {
    history_[blob_idx]->FromProto(saved.history(blob_idx));
  }
}
下面主要分析Solve中的Step(param_.max_iter() - iter_)
solver.cpp
/**
 * Runs the training loop until iter_ has advanced by `iters`, or until a
 * stop is requested.
 *
 * Each pass clears the parameter diffs, optionally tests, accumulates loss
 * over iter_size forward/backward passes, optionally logs progress, applies
 * the update, increments iter_, and snapshots when due.
 *
 * @param iters Number of iterations to run from the current iter_.
 */
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  const int first_iter = iter_;
  const int end_iter = iter_ + iters;
  int avg_loss_setting = this->param_.average_loss();
  losses_.clear();
  smoothed_loss_ = 0;
  iteration_timer_.Start();

  while (iter_ < end_iter) {
    // Zero the gradients of the network parameters.
    net_->ClearParamDiffs();

    const bool test_due =
        param_.test_interval() && iter_ % param_.test_interval() == 0 &&
        (iter_ > 0 || param_.test_initialization());
    if (test_due) {
      if (Caffe::root_solver()) {
        TestAll();
      }
      if (requested_early_exit_) {
        // A stop was requested while testing: leave the training loop.
        break;
      }
    }

    for (int cb = 0; cb < callbacks_.size(); ++cb) {
      callbacks_[cb]->on_start();
    }

    const bool show_progress =
        param_.display() && iter_ % param_.display() == 0;
    net_->set_debug_info(show_progress && param_.debug_info());

    // Accumulate the loss and gradient. Per the original author's note,
    // iter_size defaults to 1, so normally this runs a single forward and
    // backward pass.
    Dtype iter_loss = 0;
    for (int pass = 0; pass < param_.iter_size(); ++pass) {
      iter_loss += net_->ForwardBackward();
    }
    iter_loss /= param_.iter_size();

    // Average the loss across iterations for smoothed reporting.
    UpdateSmoothedLoss(iter_loss, first_iter, avg_loss_setting);

    if (show_progress) {
      float elapsed = iteration_timer_.Seconds();
      // Substitute 1 for a zero interval so the division stays defined.
      float divisor = elapsed ? elapsed : 1;
      float iters_per_sec = (iter_ - iterations_last_) / divisor;
      LOG_IF(INFO, Caffe::root_solver())
          << "Iteration " << iter_ << " (" << iters_per_sec << " iter/s, "
          << elapsed << "s/" << param_.display()
          << " iters), loss = " << smoothed_loss_;
      iteration_timer_.Start();
      iterations_last_ = iter_;

      const vector<Blob<Dtype>*>& outputs = net_->output_blobs();
      int score_index = 0;
      for (int out = 0; out < outputs.size(); ++out) {
        const Dtype* values = outputs[out]->cpu_data();
        const string& output_name =
            net_->blob_names()[net_->output_blob_indices()[out]];
        const Dtype loss_weight =
            net_->blob_loss_weights()[net_->output_blob_indices()[out]];
        for (int elem = 0; elem < outputs[out]->count(); ++elem) {
          ostringstream loss_msg_stream;
          if (loss_weight) {
            loss_msg_stream << " (* " << loss_weight << " = "
                            << loss_weight * values[elem] << " loss)";
          }
          // score_index++ stays inside the logged expression, exactly as in
          // the original statement.
          LOG_IF(INFO, Caffe::root_solver())
              << " Train net output #" << score_index++ << ": "
              << output_name << " = " << values[elem]
              << loss_msg_stream.str();
        }
      }
    }

    for (int cb = 0; cb < callbacks_.size(); ++cb) {
      callbacks_[cb]->on_gradients_ready();
    }

    // The network parameters are updated here. Per the original author's
    // note, ApplyUpdate() is virtual and implemented by Solver subclasses.
    ApplyUpdate();

    // Increment the internal iter_ counter -- its value should always
    // indicate the number of times the weights have been updated.
    // Author's note: each iteration feeds one batch of batch_size samples
    // through the net, combines the parameter gradients they produce into
    // that iteration's gradient, and then updates the parameters from it
    // according to the regularization method and update policy.
    ++iter_;

    SolverAction::Enum request = GetRequestedAction();

    // Save a snapshot if needed.
    const bool periodic_snapshot = param_.snapshot() &&
                                   iter_ % param_.snapshot() == 0 &&
                                   Caffe::root_solver();
    if (periodic_snapshot || (request == SolverAction::SNAPSHOT)) {
      Snapshot();
    }

    if (SolverAction::STOP == request) {
      requested_early_exit_ = true;
      // Break out of training loop.
      break;
    }
  }
}
上面的loss += net_->ForwardBackward()是训练过程的核心。这行代码取一个batch(共batch_size个样本)的数据,先在网络中做一次前向传播,得到损失的均值;再做一次反向传播,得到网络参数的梯度(该梯度是这一个batch的样本所产生梯度的均值)。详细分析见下一章节。
以上是关于caffe Solve函数的主要内容,如果未能解决你的问题,请参考以下文章