Caffe源码精读笔记(一)之caffe.cpp

        博客正式进入caffe源码精读系列,博主会在阅读源码的过程中将自己认为重要的部分记录下来,整理成笔记。由于博主之前粗略浏览过一遍源码,对caffe的架构和caffe中使用的库有了了解,这些笔记介绍整体工作流程和框架的同时,还将偏重于源码的细节。第一篇博客主要由程序入口开始介绍caffe训练的流程。

        博主博客中的部分内容参考了 caffe caffe.cpp程序入口分析 这篇博客,特此链接以示感谢!

        caffe.cpp中运用宏注册了四个函数,根据传入的参数确定调用train(),test(),device_query(),time()中的一个,它们分别独立完成各自的功能,具体细节请链接至caffe caffe.cpp程序入口分析 ,博主就不搬运了。博主关系的是train()函数,下面就来看看train函数的核心代码。

        下面的代码中省略了解析参数,设置工作gpu等等的代码,我们真正关注的训练部分就是从这里开始的。

  shared_ptr<caffe::Solver<float> >
      solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

  solver->SetActionFunction(signal_handler.GetActionFunction());

  if (FLAGS_snapshot.size()) {
    LOG(INFO) << "Resuming from " << FLAGS_snapshot;
    solver->Restore(FLAGS_snapshot.c_str());
  } else if (FLAGS_weights.size()) {
    CopyLayers(solver.get(), FLAGS_weights);
  }

  if (gpus.size() > 1) {
    caffe::P2PSync<float> sync(solver, NULL, solver->param());
    sync.run(gpus);
  } else {
    LOG(INFO) << "Starting Optimization";
    solver->Solve();
  }
  LOG(INFO) << "Optimization Done.";
  return 0;
}
RegisterBrewFunction(train);

        贯穿Caffe的四个类Blob,Layer,Net,Solver,这几个类的复杂性从低到高,我们自顶向下浏览源码,会依次接触Solver、Net、Layer、Blob这四个类,上面的代码就是从solver开始的,于是我们直接切入Solver类中:

1. SolverAction

        用户通过solver类可以执行一些操作,除了正常的训练和检测之外,还提供暂停和快照功能(下面这部分代码在solver.hpp中,却不属于Solver类)

/**
  * @brief Enumeration of actions that a client of the Solver may request by
  * implementing the Solver's action request function, which a
  * a client may optionally provide in order to request early termination
  * or saving a snapshot without exiting. In the executable caffe, this
  * mechanism is used to allow the snapshot to be saved when stopping
  * execution with a SIGINT (Ctrl-C).
  */
  namespace SolverAction {
    enum Enum {
      NONE = 0,  // Take no special action.
      STOP = 1,  // Stop training. snapshot_after_train controls whether a
                 // snapshot is created.
      SNAPSHOT = 2  // Take a snapshot, and keep training.
    };
  }

/**
 * @brief Type of a function that returns a Solver Action enumeration.
 */
typedef boost::function<SolverAction::Enum()> ActionCallback;

2.Solver类的构造与初始化

        这样我们就可以看懂了,程序是直接从protobuf配置文件中读取solver参数,然后送入solver类的构造函数中进行初始化的。这个传入的配置文件protobuf就是我们自己定义的网络结构参数,至于网络结构参数的意义,参见caffe.proto文件,里面有详细的注释,都进来的参数都被填入caffe.proto中了

        Caffe中最重要的文件就是caffe.proto,这么说一点也不为过,所有的配置参数都是流向caffe.proto,如果你已经对google的protobuf有一定了解,但是还不知道网络结构参数的具体含义,那么这里有一篇介绍caffe.proto的文档

        特别提醒,下方源码中的关键字explicit是为了防止编译器做隐式类型转换:)

/**
 * @brief An interface for classes that perform optimization on Net%s.
 *
 * Requires implementation of ApplyUpdate to compute a parameter update
 * given the current state of the Net parameters.
 */
template <typename Dtype>
class Solver {
 public:
  explicit Solver(const SolverParameter& param,
      const Solver* root_solver = NULL);
  explicit Solver(const string& param_file, const Solver* root_solver = NULL);
  void Init(const SolverParameter& param);
  void InitTrainNet();
  void InitTestNets();

3.Solver操作的回调函数

        我们最初给出的caffe.cpp中solver类的第二个操作,就是设置执行Action的回调函数,包括停止和快照(就是存储当前整个网络)

  // Client of the Solver optionally may call this in order to set the function
  // that the solver uses to see what action it should take (e.g. snapshot or
  // exit training early).
  void SetActionFunction(ActionCallback func);
  SolverAction::Enum GetRequestedAction();

 4.Solver类的核心函数

        Solve,注意这不是构造函数,这是训练过程中求解网络参数的函数,自然也是Solver类中最重要的函数。这节笔记主要讲caffe.cpp,Solver类的细节参见下节博客(抱歉,为了思路连贯先卖个关子),其中Step函数会被Solve函数调用,执行单步操作

        于是我们就了解caffe.cpp中的train函数真正开始网络的训练是在什么时候了,快回头看看源码吧,是不是在最后,一切设置妥当之后才开始的。

  // The main entry of the solver function. In default, iter will be zero. Pass
  // in a non-zero iter number to resume training for a pre-trained net.
  virtual void Solve(const char* resume_file = NULL);
  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
  void Step(int iters);

5. 网络状态的数据存储

        这两个函数,一个用来保存当前的网络状态,另一个用来保存当前整个网络学习到的参数(快照功能)

  // The Restore method simply dispatches to one of the
  // RestoreSolverStateFrom___ protected methods. You should implement these
  // methods to restore the state from the appropriate snapshot type.
  void Restore(const char* resume_file);
  // The Solver::Snapshot function implements the basic snapshotting utility
  // that stores the learned net. You should implement the SnapshotSolverState()
  // function that produces a SolverState protocol buffer that needs to be
  // written to disk together with the learned net.
  void Snapshot();

        solver类中的public函数基本介绍完了,其中的一些protected函数(内部函数)都是供这些函数调用的底层,我们会在介绍solver.cpp的章节进行细致讲解。

        现在我们再回头看代码,caffe.cpp的训练函数做了哪些操作也就大致清楚了,我们可能会疑惑,其中具体的操作是怎样完成的呢,一开始就介绍过,自顶向下Solver、Net、Layer、Blob来看,不得不感叹Caffe庞大而精致的设计,同时也不得不承认,我们只有一步一步走下来才能读懂caffe宏大的设计,将来修改代码为我所用。

        Ok, See You Next Chapter!

发表评论