// learner_process.h
/*
 * Author: xiexionghang
 * Train samples
 */
#pragma once
#include <cstdint>
#include <memory>
#include <vector>

#include "paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h"
#include "paddle/fluid/train/custom_trainer/feed/process/process.h"

namespace paddle {
namespace custom_trainer {
namespace feed {
class LearnerProcess : public Process {
public:
    LearnerProcess() {}
    virtual ~LearnerProcess() {}
    
    virtual int run();
    virtual int initialize(std::shared_ptr<TrainerContext> context_ptr);

protected:
X
xiexionghang 已提交
21 22
// 加载所有模型
virtual int load_model(uint64_t epoch_id);
X
xiexionghang 已提交
23 24
// 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型
virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump = false);
X
xiexionghang 已提交
25 26

private:
X
xiexionghang 已提交
27
    bool _startup_dump_inference_base;  //启动立即dump base
X
xiexionghang 已提交
28
    std::vector<std::shared_ptr<MultiThreadExecutor>> _executors;
X
xiexionghang 已提交
29 30 31 32 33
};

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle