// trainer_context.h
#pragma once
#include <cstdint>
#include <cstring>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include <tsl/bhopscotch_map.h>

#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"


namespace paddle {
namespace custom_trainer {
namespace feed {

// Forward declarations of the collaborator types that TrainerContext refers
// to through std::shared_ptr members.
class PSlib;
class Process;
class Dataset;
class FileSystem;
class EpochAccessor;

// Time-unit conversion constants, expressed in seconds. Each larger unit is
// derived from the next smaller one so the values cannot drift apart.
constexpr uint32_t SecondsPerMin = 60;
constexpr uint32_t SecondsPerHour = 60 * SecondsPerMin;
constexpr uint32_t SecondsPerDay = 24 * SecondsPerHour;

// Modes in which a model can be saved. The per-entry descriptions below are
// inferred from the enumerator names — TODO confirm against the save logic
// that consumes them. Values are explicit; confirm no caller relies on the
// numbers before renumbering.
enum class ModelSaveWay {
    ModelSaveTrainCheckpoint = 0,      // training checkpoint
    ModelSaveInferenceDelta = 1,       // incremental (delta) inference model
    ModelSaveInferenceBase = 2,        // full (base) inference model
    ModelSaveTrainCheckpointBase = 3,  // base training checkpoint
};

// Trainer states. TrainerContext::is_status uses each enumerator's numeric
// value as a bit index into TrainerContext::trainer_status, so several states
// can be active at the same time.
enum class TrainerStatus {
    Training = 0,  // training
    Saving = 1     // saving the model
};

// Number of float values cached per sign.
constexpr uint32_t SignCacheMaxValueNum = 13;

// Cache entry kept for each sign in SignCacheDict.
struct SignCacheData {
    // Value-initializes every member: idx becomes 0 and every cached value
    // becomes 0.0f, so no field is ever left indeterminate.
    SignCacheData() : idx(0), cache_value{} {}

    uint32_t idx;                             // index assigned to the sign by SignCacheDict::append
    float cache_value[SignCacheMaxValueNum];  // cached values for the sign
};

W
wangyihong01 已提交
47
class SignCacheDict {
X
xiexionghang 已提交
48
public:
X
xiexionghang 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
    inline int32_t sign2index(uint64_t sign) {
        auto itr = _sign2data_map.find(sign);
        if (itr == _sign2data_map.end()) {  
            return -1;
        } 
        return itr->second.idx;
    }

    inline uint64_t index2sign(int32_t index) {
        if (index >= _sign_list.size()) {
            return 0;
        } 
        return _sign_list[index];
    }

    inline void reserve(uint32_t size) {
        _sign_list.reserve(size);
        _sign2data_map.reserve(size);
    }

    inline void clear() {
        _sign_list.clear();
        _sign2data_map.clear();
    }

    inline void append(uint64_t sign) {
        if (_sign2data_map.find(sign) != _sign2data_map.end()) {
            return;
        }
        SignCacheData data;
        data.idx = _sign_list.size();
        _sign_list.push_back(sign);
        _sign2data_map.emplace(sign, std::move(data));
W
wangyihong01 已提交
82 83
    }

X
xiexionghang 已提交
84 85 86 87 88 89
    inline SignCacheData* data(uint64_t sign) {
        tsl::bhopscotch_pg_map<uint64_t, SignCacheData>::iterator itr = _sign2data_map.find(sign);
        if (itr == _sign2data_map.end()) {
            return nullptr;
        }
        return const_cast<SignCacheData*>(&(itr->second));
X
xiexionghang 已提交
90
    }
X
xiexionghang 已提交
91 92 93
private:
    std::vector<uint64_t> _sign_list;
    tsl::bhopscotch_pg_map<uint64_t, SignCacheData> _sign2data_map;
X
xiexionghang 已提交
94
};
X
xiexionghang 已提交
95 96 97

class TrainerContext {
public:
X
xiexionghang 已提交
98 99 100
    inline paddle::ps::PSClient* ps_client() {
        return pslib->ps_client();
    }
X
xiexionghang 已提交
101 102 103 104 105 106 107 108 109
    inline bool is_status(TrainerStatus status) {
        auto bit_idx = static_cast<uint32_t>(status);
        return ((trainer_status >> bit_idx) & 1) > 0;
    }
    // 非线程安全, 其实TrainerContext所有成员的线程安全性 取决于 成员本身的线程安全性
    inline void set_status(TrainerStatus status, bool on) {
        auto bit_idx = static_cast<uint32_t>(status);
        trainer_status = trainer_status & (1L << bit_idx);
    }
X
xiexionghang 已提交
110

X
xiexionghang 已提交
111
    uint32_t trainer_status;      // trainer当前,由于可同时处于多种状态,这里分bit存储状态
X
xiexionghang 已提交
112 113 114 115
    YAML::Node trainer_config;
    paddle::platform::CPUPlace cpu_place;

    std::shared_ptr<PSlib> pslib;
116
    std::stringstream monitor_ssm;                             //记录monitor信息
X
xiexionghang 已提交
117 118 119 120 121 122
    std::shared_ptr<Dataset> dataset;                          //训练样本
    std::shared_ptr<FileSystem> file_system;                   //文件操作辅助类
    std::shared_ptr<EpochAccessor> epoch_accessor;             //训练轮次控制
    std::shared_ptr<RuntimeEnvironment> environment;           //运行环境
    std::vector<std::shared_ptr<Process>> process_list;        //训练流程
    std::shared_ptr<SignCacheDict> cache_dict;                 //大模型cache词典
X
xiexionghang 已提交
123 124 125 126 127
};

}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle