#pragma once #include #include #include #include #include #include "paddle/fluid/platform/place.h" #include "paddle/fluid/train/custom_trainer/feed/common/yaml_helper.h" #include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h" #include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h" namespace paddle { namespace custom_trainer { namespace feed { class PSlib; class Process; class Dataset; class FileSystem; class EpochAccessor; const uint32_t SecondsPerMin = 60; const uint32_t SecondsPerHour = 3600; const uint32_t SecondsPerDay = 24 * 3600; enum class ModelSaveWay { ModelSaveTrainCheckpoint = 0, ModelSaveInferenceDelta = 1, ModelSaveInferenceBase = 2, ModelSaveTrainCheckpointBase = 3, }; enum class TrainerStatus { Training = 0, // 训练状态 Saving = 1 // 模型存储状态 }; const uint32_t SignCacheMaxValueNum = 13; struct SignCacheData { SignCacheData() { memset(cache_value, 0, sizeof(float) * SignCacheMaxValueNum); } uint32_t idx; float cache_value[SignCacheMaxValueNum]; }; class SignCacheDict { public: inline int32_t sign2index(uint64_t sign) { auto itr = _sign2data_map.find(sign); if (itr == _sign2data_map.end()) { return -1; } return itr->second.idx; } inline uint64_t index2sign(int32_t index) { if (index >= _sign_list.size()) { return 0; } return _sign_list[index]; } inline void reserve(uint32_t size) { _sign_list.reserve(size); _sign2data_map.reserve(size); } inline void clear() { _sign_list.clear(); _sign2data_map.clear(); } inline void append(uint64_t sign) { if (_sign2data_map.find(sign) != _sign2data_map.end()) { return; } SignCacheData data; data.idx = _sign_list.size(); _sign_list.push_back(sign); _sign2data_map.emplace(sign, std::move(data)); } inline SignCacheData* data(uint64_t sign) { tsl::bhopscotch_pg_map::iterator itr = _sign2data_map.find(sign); if (itr == _sign2data_map.end()) { return nullptr; } return const_cast(&(itr->second)); } private: std::vector _sign_list; tsl::bhopscotch_pg_map _sign2data_map; }; class TrainerContext { public: TrainerContext() { trainer_status.resize(2, 0); } inline paddle::ps::PSClient* ps_client() { return pslib->ps_client(); } inline bool is_status(TrainerStatus status) { auto status_idx = static_cast(status); return trainer_status[status_idx] > 0; } // 非线程安全, 其实TrainerContext所有成员的线程安全性 取决于 成员本身的线程安全性 inline void set_status(TrainerStatus status, bool on) { auto status_idx = static_cast(status); trainer_status[status_idx] = on ? 1 : 0; } std::vector trainer_status; YAML::Node trainer_config; paddle::platform::CPUPlace cpu_place; std::shared_ptr pslib; std::stringstream monitor_ssm; //记录monitor信息 std::shared_ptr dataset; //训练样本 std::shared_ptr file_system; //文件操作辅助类 std::shared_ptr epoch_accessor; //训练轮次控制 std::shared_ptr environment; //运行环境 std::vector> process_list; //训练流程 std::shared_ptr cache_dict; //大模型cache词典 }; class ContextStatusGurad { public: ContextStatusGurad(TrainerContext* context, TrainerStatus status) : _context(context), _status(status) { _context->set_status(_status, true); } virtual ~ContextStatusGurad() { _context->set_status(_status, false); } private: TrainerStatus _status; TrainerContext* _context = nullptr; }; } // namespace feed } // namespace custom_trainer } // namespace paddle