learner_process.h 804 字节
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6
/*
 * Author: xiexionghang
 * Training samples
 */
#pragma once
#include "paddle/fluid/train/custom_trainer/feed/process/process.h"
X
xiexionghang 已提交
7
#include "paddle/fluid/train/custom_trainer/feed/executor/multi_thread_executor.h"
X
xiexionghang 已提交
8 9 10 11 12 13 14 15 16 17 18 19 20

namespace paddle {
namespace custom_trainer {
namespace feed {
class LearnerProcess : public Process {
public:
    LearnerProcess() {}
    virtual ~LearnerProcess() {}
    
    virtual int run();
    virtual int initialize(std::shared_ptr<TrainerContext> context_ptr);

protected:
X
xiexionghang 已提交
21 22
// 加载所有模型
virtual int load_model(uint64_t epoch_id);
X
xiexionghang 已提交
23
// 同步保存所有模型
X
xiexionghang 已提交
24
virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way);
X
xiexionghang 已提交
25 26

private:
X
xiexionghang 已提交
27
    std::vector<std::shared_ptr<MultiThreadExecutor>> _executors;
X
xiexionghang 已提交
28 29 30 31 32
};

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle