/*
 * Author: xiexionghang
 * Train samples: declares LearnerProcess, the training (learner) process.
 */
#pragma once
#include <cstdint>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/train/custom_trainer/feed/process/process.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h"

namespace paddle {
namespace custom_trainer {
namespace feed {
class LearnerProcess : public Process {
public:
    LearnerProcess() {}
    virtual ~LearnerProcess() {}
    
    virtual int run();
    virtual int initialize(std::shared_ptr<TrainerContext> context_ptr);

protected:
X
xiexionghang 已提交
21 22
// 加载所有模型
virtual int load_model(uint64_t epoch_id);
X
xiexionghang 已提交
23 24
// 同步保存所有模型, is_force_dump:不判断dump条件,强制dump出模型
virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way, bool is_force_dump = false);
X
xiexionghang 已提交
25
virtual int update_cache_model(uint64_t epoch_id, ModelSaveWay way);
X
xiexionghang 已提交
26 27

private:
X
xiexionghang 已提交
28 29 30 31
    bool _is_dump_cache_model;          // 是否进行cache dump
    uint32_t _cache_sign_max_num = 0;   // cache sign最大个数
    std::string _cache_load_converter;  // cache加载的前置转换脚本
    bool _startup_dump_inference_base;  // 启动立即dump base
X
xiexionghang 已提交
32
    std::vector<std::shared_ptr<MultiThreadExecutor>> _executors;
X
xiexionghang 已提交
33 34 35 36 37
};

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle