trainer_context.h 4.2 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4
#pragma once
#include <string>
#include <memory>
#include <vector>
5
#include <sstream>
X
xiexionghang 已提交
6
#include <tsl/bhopscotch_map.h>
X
xiexionghang 已提交
7
#include "paddle/fluid/platform/place.h"
X
xiexionghang 已提交
8
#include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h"
X
xiexionghang 已提交
9
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
X
xiexionghang 已提交
10 11 12 13 14 15 16
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"


namespace paddle {
namespace custom_trainer {
namespace feed {

X
xiexionghang 已提交
17
// Forward declarations of class types used by the declarations in this header.
class Dataset;
class EpochAccessor;
class FileSystem;
class Process;
class PSlib;

X
xiexionghang 已提交
23 24 25 26
// Number of seconds in one minute, one hour and one day.
// Each constant is derived from the previous one so the relationship is explicit.
const uint32_t SecondsPerMin = 60;
const uint32_t SecondsPerHour = 60 * SecondsPerMin;
const uint32_t SecondsPerDay = 24 * SecondsPerHour;

X
xiexionghang 已提交
27 28 29
// Enumerates the ways a model can be saved. Going by the enumerator names:
// training checkpoint, inference delta, inference base, training checkpoint base.
// The numeric values are assigned explicitly.
// NOTE(review): presumably the values are relied on outside this header —
// confirm before renumbering.
enum class ModelSaveWay {
    ModelSaveTrainCheckpoint     = 0,
    ModelSaveInferenceDelta      = 1,
    ModelSaveInferenceBase       = 2,
    ModelSaveTrainCheckpointBase = 3
};

X
xiexionghang 已提交
34 35 36 37 38
// Trainer status flags. TrainerContext converts an enumerator to its numeric
// value and uses it as an index into its trainer_status vector.
enum class TrainerStatus {
    // Training state.
    Training = 0,
    // Model saving state.
    Saving = 1
};

X
xiexionghang 已提交
39 40 41 42 43 44 45 46
// Number of float slots held by every SignCacheData record.
const uint32_t SignCacheMaxValueNum = 13;

// Record stored per sign by SignCacheDict.
struct SignCacheData {
    SignCacheData() = default;

    // Index associated with the sign. Zero-initialized so that a freshly
    // constructed record never exposes an indeterminate value.
    uint32_t idx = 0;
    // Cached values; every slot is 0.0f after construction.
    float cache_value[SignCacheMaxValueNum] = {};
};
W
wangyihong01 已提交
47
class SignCacheDict {
X
xiexionghang 已提交
48
public:
X
xiexionghang 已提交
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
    inline int32_t sign2index(uint64_t sign) {
        auto itr = _sign2data_map.find(sign);
        if (itr == _sign2data_map.end()) {  
            return -1;
        } 
        return itr->second.idx;
    }

    inline uint64_t index2sign(int32_t index) {
        if (index >= _sign_list.size()) {
            return 0;
        } 
        return _sign_list[index];
    }

    inline void reserve(uint32_t size) {
        _sign_list.reserve(size);
        _sign2data_map.reserve(size);
    }

    inline void clear() {
        _sign_list.clear();
        _sign2data_map.clear();
    }

    inline void append(uint64_t sign) {
        if (_sign2data_map.find(sign) != _sign2data_map.end()) {
            return;
        }
        SignCacheData data;
        data.idx = _sign_list.size();
        _sign_list.push_back(sign);
        _sign2data_map.emplace(sign, std::move(data));
W
wangyihong01 已提交
82 83
    }

X
xiexionghang 已提交
84 85 86 87 88 89
    inline SignCacheData* data(uint64_t sign) {
        tsl::bhopscotch_pg_map<uint64_t, SignCacheData>::iterator itr = _sign2data_map.find(sign);
        if (itr == _sign2data_map.end()) {
            return nullptr;
        }
        return const_cast<SignCacheData*>(&(itr->second));
X
xiexionghang 已提交
90
    }
X
xiexionghang 已提交
91 92 93
private:
    std::vector<uint64_t> _sign_list;
    tsl::bhopscotch_pg_map<uint64_t, SignCacheData> _sign2data_map;
X
xiexionghang 已提交
94
};
X
xiexionghang 已提交
95 96 97

class TrainerContext {
public:
X
xiexionghang 已提交
98 99 100
    TrainerContext() {
        trainer_status.resize(2, 0);
    }
X
xiexionghang 已提交
101 102 103
    inline paddle::ps::PSClient* ps_client() {
        return pslib->ps_client();
    }
X
xiexionghang 已提交
104
    inline bool is_status(TrainerStatus status) {
X
xiexionghang 已提交
105 106
        auto status_idx = static_cast<uint32_t>(status);
        return trainer_status[status_idx] > 0;
X
xiexionghang 已提交
107 108 109
    }
    // 非线程安全, 其实TrainerContext所有成员的线程安全性 取决于 成员本身的线程安全性
    inline void set_status(TrainerStatus status, bool on) {
X
xiexionghang 已提交
110 111
        auto status_idx = static_cast<uint32_t>(status);
        trainer_status[status_idx] = on ? 1 : 0;
X
xiexionghang 已提交
112
    }
X
xiexionghang 已提交
113

X
xiexionghang 已提交
114
    std::vector<uint32_t> trainer_status;
X
xiexionghang 已提交
115 116 117 118
    YAML::Node trainer_config;
    paddle::platform::CPUPlace cpu_place;

    std::shared_ptr<PSlib> pslib;
119
    std::stringstream monitor_ssm;                             //记录monitor信息
X
xiexionghang 已提交
120 121 122 123 124 125
    std::shared_ptr<Dataset> dataset;                          //训练样本
    std::shared_ptr<FileSystem> file_system;                   //文件操作辅助类
    std::shared_ptr<EpochAccessor> epoch_accessor;             //训练轮次控制
    std::shared_ptr<RuntimeEnvironment> environment;           //运行环境
    std::vector<std::shared_ptr<Process>> process_list;        //训练流程
    std::shared_ptr<SignCacheDict> cache_dict;                 //大模型cache词典
X
xiexionghang 已提交
126 127
};

X
xiexionghang 已提交
128 129 130 131 132 133 134 135 136 137 138 139 140 141
// Scope guard for a TrainerContext status flag: the constructor switches the
// flag on and the destructor switches it off again.
// The spelling of the class name is kept as-is because it is the public name.
class ContextStatusGurad {
public:
    // `context` is dereferenced without a null check here and in the
    // destructor, so it must be non-null and must outlive the guard.
    ContextStatusGurad(TrainerContext* context, TrainerStatus status) :
        _status(status), _context(context) {  // same order as the member declarations
        _context->set_status(_status, true);
    }
    virtual ~ContextStatusGurad() {
        _context->set_status(_status, false);
    }

    // Non-copyable: a copy would switch the flag off a second time, possibly
    // while the original guard is still in scope.
    ContextStatusGurad(const ContextStatusGurad&) = delete;
    ContextStatusGurad& operator=(const ContextStatusGurad&) = delete;

private:
    TrainerStatus _status;
    TrainerContext* _context = nullptr;
};

X
xiexionghang 已提交
142 143 144
}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle