trainer_context.h 1.3 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6
#pragma once
#include <string>
#include <memory>
#include <vector>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/platform/place.h"
X
xiexionghang 已提交
7
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
X
xiexionghang 已提交
8 9 10 11 12 13 14 15
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"


namespace paddle {
namespace custom_trainer {
namespace feed {

// Forward declarations. In this header these types are referenced only through
// std::shared_ptr members of TrainerContext, so their full definitions are not
// included here.
class Dataset;
class EpochAccessor;
class FileSystem;
class Process;

// Selects which kind of model snapshot a save operation produces.
// NOTE(review): the per-value descriptions below are inferred from the
// enumerator names only; the save logic itself is not in this header.
enum class ModelSaveWay {
    ModelSaveTrainCheckpoint = 0,  // presumably a checkpoint for resuming training
    ModelSaveInferenceDelta  = 1,  // presumably an incremental inference model
    ModelSaveInferenceBase   = 2,  // presumably a full (base) inference model
};

// Metadata describing one parameter table.
class TableMeta {
public:
    TableMeta() = default;

    // Creates metadata for the table identified by `id`.
    explicit TableMeta(int id) : _id(id) {}

    // Returns the table id, or -1 if no id was assigned.
    int table_id() const {
        return _id;
    }

private:
    // Given a default value so a default-constructed TableMeta never exposes an
    // indeterminate int (reading one is undefined behavior). -1 marks
    // "no table id assigned".
    int _id = -1;
};
X
xiexionghang 已提交
36 37 38 39 40

class TrainerContext {
public:
YAML::Node trainer_config;
paddle::platform::CPUPlace cpu_place;
X
xiexionghang 已提交
41

X
xiexionghang 已提交
42
std::shared_ptr<PSlib> pslib;
X
xiexionghang 已提交
43 44 45 46 47 48
std::shared_ptr<Dataset> dataset;                          //训练样本
std::shared_ptr<FileSystem> file_system;                   //文件操作辅助类
std::vector<TableMeta> params_table_list;                  //参数表
std::shared_ptr<EpochAccessor> epoch_accessor;             //训练轮次控制
std::shared_ptr<RuntimeEnvironment> environment;           //运行环境
std::vector<std::shared_ptr<Process>> process_list;        //训练流程
X
xiexionghang 已提交
49 50 51 52 53
};

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle