// trainer_context.h
#pragma once
#include <string>
#include <memory>
#include <vector>
5
#include <sstream>
X
xiexionghang 已提交
6
#include "paddle/fluid/platform/place.h"
X
xiexionghang 已提交
7
#include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h"
X
xiexionghang 已提交
8
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
X
xiexionghang 已提交
9 10 11 12 13 14 15
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"


namespace paddle {
namespace custom_trainer {
namespace feed {

X
xiexionghang 已提交
16
// Forward declarations of the types that TrainerContext (below) refers to
// through std::shared_ptr members.
class PSlib;
class Process;
class Dataset;
class FileSystem;
class EpochAccessor;

X
xiexionghang 已提交
22 23 24 25
// Time-unit conversion factors, all expressed in seconds.
const uint32_t SecondsPerMin = 60;                // one minute
const uint32_t SecondsPerHour = 60 * 60;          // one hour  (3600)
const uint32_t SecondsPerDay = 24 * 60 * 60;      // one day   (86400)

X
xiexionghang 已提交
26 27 28
// Selects which kind of model dump is being requested.
// The numeric values are spelled out explicitly; keep them stable, since code
// outside this header may depend on them (NOTE(review): confirm with callers).
// Judging by the names, the variants distinguish training checkpoints from
// inference models and incremental (delta) from full (base) dumps -- the exact
// semantics are defined by the code that consumes this enum, not here.
enum class ModelSaveWay {
    ModelSaveTrainCheckpoint = 0,
    ModelSaveInferenceDelta = 1,
    ModelSaveInferenceBase = 2,
    ModelSaveTrainCheckpointBase = 3,
};

X
xiexionghang 已提交
33 34 35 36 37
// States a trainer can be in. Each enumerator's value is used as a bit index
// into TrainerContext::trainer_status, so several states can be active at once.
enum class TrainerStatus {
    // Training in progress.
    Training = 0,
    // Model is being saved.
    Saving = 1
};

W
wangyihong01 已提交
38
// Dictionary mapping between feature signs and cache indices.
// Both lookups are currently placeholders: they ignore their argument and
// return a fixed value.
class SignCacheDict {
public:
    // Looks up the cache index for `sign`.
    // Placeholder implementation: always returns -1.
    int32_t sign2index(uint64_t sign) const {
        (void)sign;  // unused until a real lookup is implemented
        return -1;
    }

    // Looks up the sign stored at `index`.
    // Placeholder implementation: always returns 0.
    uint64_t index2sign(int32_t index) const {
        (void)index;  // unused until a real lookup is implemented
        return 0;
    }
};
X
xiexionghang 已提交
48 49 50

class TrainerContext {
public:
X
xiexionghang 已提交
51 52 53
    inline paddle::ps::PSClient* ps_client() {
        return pslib->ps_client();
    }
X
xiexionghang 已提交
54 55 56 57 58 59 60 61 62
    inline bool is_status(TrainerStatus status) {
        auto bit_idx = static_cast<uint32_t>(status);
        return ((trainer_status >> bit_idx) & 1) > 0;
    }
    // 非线程安全, 其实TrainerContext所有成员的线程安全性 取决于 成员本身的线程安全性
    inline void set_status(TrainerStatus status, bool on) {
        auto bit_idx = static_cast<uint32_t>(status);
        trainer_status = trainer_status & (1L << bit_idx);
    }
X
xiexionghang 已提交
63

X
xiexionghang 已提交
64
    uint32_t trainer_status;      // trainer当前,由于可同时处于多种状态,这里分bit存储状态
X
xiexionghang 已提交
65 66 67 68
    YAML::Node trainer_config;
    paddle::platform::CPUPlace cpu_place;

    std::shared_ptr<PSlib> pslib;
69
    std::stringstream monitor_ssm;                             //记录monitor信息
X
xiexionghang 已提交
70 71 72 73 74 75
    std::shared_ptr<Dataset> dataset;                          //训练样本
    std::shared_ptr<FileSystem> file_system;                   //文件操作辅助类
    std::shared_ptr<EpochAccessor> epoch_accessor;             //训练轮次控制
    std::shared_ptr<RuntimeEnvironment> environment;           //运行环境
    std::vector<std::shared_ptr<Process>> process_list;        //训练流程
    std::shared_ptr<SignCacheDict> cache_dict;                 //大模型cache词典
X
xiexionghang 已提交
76 77 78 79 80
};

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle