// learner_process.h (committed by xiexionghang)
/*
 * Author: xiexionghang
 * Training on samples: declares LearnerProcess, the process that runs the
 * training networks (executors).
 */
#pragma once
#include <future>
#include <memory>
#include <vector>

#include "paddle/fluid/train/custom_trainer/feed/process/process.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"

namespace paddle {
namespace custom_trainer {
namespace feed {

// A list of shared-ownership executors. LearnerProcess keeps a std::vector of
// these in its _threads_executor member.
using MultiExecutor = std::vector<std::shared_ptr<Executor>>;
/**
 * @brief Process that runs model training.
 *
 * Derives from Process and exposes run()/initialize(). Training work is
 * carried out by Executor instances (one per training network), grouped per
 * training thread in _threads_executor.
 * NOTE(review): the per-thread grouping is inferred from member names;
 * confirm against the implementation file.
 */
class LearnerProcess : public Process {
public:
    LearnerProcess() = default;
    virtual ~LearnerProcess() = default;

    /// Runs the training process.
    /// @return int status code (NOTE(review): presumably 0 on success —
    ///         confirm in the implementation).
    virtual int run();

    /// Initializes the process from the shared trainer context.
    /// @param context_ptr shared trainer context.
    /// @return int status code (same convention as run()).
    virtual int initialize(std::shared_ptr<TrainerContext> context_ptr);

protected:
    /// Synchronously saves all models.
    /// @param epoch_id epoch whose models are saved.
    /// @param way      how the models are saved.
    /// @return int status code.
    virtual int wait_save_model(int epoch_id, ModelSaveWay way);

    /// Asynchronously saves the specified model.
    /// @param epoch_id epoch whose model is saved.
    /// @param table_id selects which model/table to save.
    /// @param way      how the model is saved.
    /// @return a future that yields the save's int result.
    virtual std::future<int> save_model(int epoch_id, int table_id, ModelSaveWay way);

    /// Runs the specified training network.
    /// @param executor executor to run (NOTE(review): raw pointer, presumably
    ///                 non-owning — confirm).
    /// @return int status code.
    virtual int run_executor(Executor* executor);

private:
    int _executor_num = 0;      // number of networks to be trained
    int _train_thread_num = 1;  // number of parallel training threads
    // Executor groups; presumably one MultiExecutor per training thread,
    // each holding _executor_num executors — verify against the .cc file.
    std::vector<MultiExecutor> _threads_executor;
};

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle