博主博客中的部分内容参考了 caffe caffe.cpp程序入口分析 这篇博客,特此链接以示感谢!
caffe.cpp 中运用宏注册了四个函数,根据传入的参数确定调用 train()、test()、device_query()、time() 中的一个,它们分别独立完成各自的功能,具体细节请参阅《caffe caffe.cpp程序入口分析》,博主就不搬运了。博主关心的是 train() 函数,下面就来看看 train() 函数的核心代码。
shared_ptr<caffe::Solver<float> > solver(caffe::SolverRegistry<float>::CreateSolver(solver_param)); solver->SetActionFunction(signal_handler.GetActionFunction()); if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot; solver->Restore(FLAGS_snapshot.c_str()); } else if (FLAGS_weights.size()) { CopyLayers(solver.get(), FLAGS_weights); } if (gpus.size() > 1) { caffe::P2PSync<float> sync(solver, NULL, solver->param()); sync.run(gpus); } else { LOG(INFO) << "Starting Optimization"; solver->Solve(); } LOG(INFO) << "Optimization Done."; return 0; } RegisterBrewFunction(train);
1. SolverAction
/** * @brief Enumeration of actions that a client of the Solver may request by * implementing the Solver's action request function, which a * a client may optionally provide in order to request early termination * or saving a snapshot without exiting. In the executable caffe, this * mechanism is used to allow the snapshot to be saved when stopping * execution with a SIGINT (Ctrl-C). */ namespace SolverAction { enum Enum { NONE = 0, // Take no special action. STOP = 1, // Stop training. snapshot_after_train controls whether a // snapshot is created. SNAPSHOT = 2 // Take a snapshot, and keep training. }; } /** * @brief Type of a function that returns a Solver Action enumeration. */ typedef boost::function<SolverAction::Enum()> ActionCallback;
/** * @brief An interface for classes that perform optimization on Net%s. * * Requires implementation of ApplyUpdate to compute a parameter update * given the current state of the Net parameters. */ template <typename Dtype> class Solver { public: explicit Solver(const SolverParameter& param, const Solver* root_solver = NULL); explicit Solver(const string& param_file, const Solver* root_solver = NULL); void Init(const SolverParameter& param); void InitTrainNet(); void InitTestNets();
// Client of the Solver optionally may call this in order to set the function // that the solver uses to see what action it should take (e.g. snapshot or // exit training early). void SetActionFunction(ActionCallback func); SolverAction::Enum GetRequestedAction();
// The main entry of the solver function. In default, iter will be zero. Pass // in a non-zero iter number to resume training for a pre-trained net. virtual void Solve(const char* resume_file = NULL); inline void Solve(const string resume_file) { Solve(resume_file.c_str()); } void Step(int iters);
5. 网络状态的数据存储
// The Restore method simply dispatches to one of the // RestoreSolverStateFrom___ protected methods. You should implement these // methods to restore the state from the appropriate snapshot type. void Restore(const char* resume_file); // The Solver::Snapshot function implements the basic snapshotting utility // that stores the learned net. You should implement the SnapshotSolverState() // function that produces a SolverState protocol buffer that needs to be // written to disk together with the learned net. void Snapshot();
Ok, See You Next Chapter!