diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
index 8f7c0436d16acd1a631f0328f48a6ba32031c3c4..a7e98b213fed677aa1516cd7f0675ae65dff8001 100644
--- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
+++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc
@@ -11,25 +11,31 @@ namespace feed {
     void HourlyEpochAccessor::next_epoch() {
         _current_epoch_id = next_epoch_id(_current_epoch_id);
     }
-    std::string HourlyEpochAccessor::text(int epoch_id) {
+    std::string HourlyEpochAccessor::text(uint64_t epoch_id) {
         return std::to_string(epoch_id);
     }
-    bool HourlyEpochAccessor::data_ready(int epoch_id) {
+    bool HourlyEpochAccessor::data_ready(uint64_t epoch_id) {
         return true;
     }
-    int HourlyEpochAccessor::next_epoch_id(int epoch_id) {
-        if (epoch_id <= 0) {
+    int HourlyEpochAccessor::next_epoch_id(uint64_t epoch_id) {
+        if (epoch_id == 0) {
             struct timeval now;
             gettimeofday(&now, NULL);
             return now.tv_sec / (24 * 3600) * (24 * 3600);
         }
         return epoch_id + 3600;
     }
-    bool HourlyEpochAccessor::is_last_epoch(int epoch_id) {
+    bool HourlyEpochAccessor::is_last_epoch(uint64_t epoch_id) {
         return ((epoch_id / 3600) % 24) == 23;
     }
-    bool HourlyEpochAccessor::need_save_model(int epoch_id, ModelSaveWay save_way) {
-        if (epoch_id <= 0) {
+    uint64_t HourlyEpochAccessor::epoch_time_interval() {
+        return 3600;
+    }
+    uint64_t HourlyEpochAccessor::epoch_timestamp(uint64_t epoch_id) {
+        return epoch_id;
+    }
+    bool HourlyEpochAccessor::need_save_model(uint64_t epoch_id, ModelSaveWay save_way) {
+        if (epoch_id == 0) {
             return false;
         }
         if (save_way == ModelSaveWay::ModelSaveInferenceDelta) {
@@ -41,7 +47,7 @@ namespace feed {
         }
         return false;
     }
-    std::string HourlyEpochAccessor::model_save_path(int epoch_id, ModelSaveWay save_way) {
+    std::string HourlyEpochAccessor::model_save_path(uint64_t epoch_id, ModelSaveWay save_way) {
        if (save_way == ModelSaveWay::ModelSaveInferenceDelta) {
            return _model_root_path + "/xbox/delta-" + std::to_string(epoch_id);
        } else if (save_way == ModelSaveWay::ModelSaveInferenceBase) {
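For orientation: HourlyEpochAccessor treats an epoch id as a plain unix timestamp aligned to an hour boundary. A worked example with assumed timestamps (the accessor instance and the concrete times are illustrative, not taken from the patch):

    // assuming "now" is 2019-08-01 10:30:00 UTC (unix time 1564655400)
    uint64_t first = accessor.next_epoch_id(0);      // 1564655400 / 86400 * 86400 = 1564617600, i.e. 2019-08-01 00:00:00
    uint64_t next  = accessor.next_epoch_id(first);  // first + 3600 = 1564621200, i.e. 01:00 of the same day
    bool last = accessor.is_last_epoch(1564700400);  // (1564700400 / 3600) % 24 == 23, so true: the 23:00 epoch
    // epoch_timestamp(id) == id and epoch_time_interval() == 3600, so one epoch covers one hour of data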
diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
index 9f5ff7aea88812f0fefcd829987041c1708208c5..5d47e9801a7795570c19ab2c9be17fd1f070dbae 100644
--- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
+++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h
@@ -13,18 +13,22 @@ public:
     virtual int initialize(YAML::Node config, std::shared_ptr<TrainerContext> context_ptr) = 0;
 
-    virtual int current_epoch_id() {
+    virtual uint64_t current_epoch_id() {
         return _current_epoch_id;
     }
     virtual void next_epoch() = 0;
 
-    virtual std::string text(int epoch_id) = 0;
-    virtual bool data_ready(int epoch_id) = 0;
-    virtual int next_epoch_id(int epoch_id) = 0;
-    virtual bool is_last_epoch(int epoch_id) = 0;
-    virtual bool need_save_model(int epoch_id, ModelSaveWay save_way) = 0;
-    virtual std::string model_save_path(int epoch_id, ModelSaveWay save_way) = 0;
+    virtual std::string text(uint64_t epoch_id) = 0;
+    virtual bool data_ready(uint64_t epoch_id) = 0;
+    virtual int next_epoch_id(uint64_t epoch_id) = 0;
+    virtual bool is_last_epoch(uint64_t epoch_id) = 0;
+    // data time interval between epochs (seconds)
+    virtual uint64_t epoch_time_interval() = 0;
+    // sample data timestamp of an epoch
+    virtual uint64_t epoch_timestamp(uint64_t epoch_id) = 0;
+    virtual bool need_save_model(uint64_t epoch_id, ModelSaveWay save_way) = 0;
+    virtual std::string model_save_path(uint64_t epoch_id, ModelSaveWay save_way) = 0;
 protected:
-    int _current_epoch_id;
+    uint64_t _current_epoch_id;
 };
 REGISTER_REGISTERER(EpochAccessor);
@@ -35,12 +39,14 @@ public:
     virtual int initialize(YAML::Node config, std::shared_ptr<TrainerContext> context_ptr);
     virtual void next_epoch();
 
-    virtual std::string text(int epoch_id);
-    virtual bool data_ready(int epoch_id);
-    virtual int next_epoch_id(int epoch_id);
-    virtual bool is_last_epoch(int epoch_id);
-    virtual bool need_save_model(int epoch_id, ModelSaveWay save_way);
-    virtual std::string model_save_path(int epoch_id, ModelSaveWay save_way);
+    virtual std::string text(uint64_t epoch_id);
+    virtual bool data_ready(uint64_t epoch_id);
+    virtual int next_epoch_id(uint64_t epoch_id);
+    virtual bool is_last_epoch(uint64_t epoch_id);
+    virtual uint64_t epoch_time_interval();
+    virtual uint64_t epoch_timestamp(uint64_t epoch_id);
+    virtual bool need_save_model(uint64_t epoch_id, ModelSaveWay save_way);
+    virtual std::string model_save_path(uint64_t epoch_id, ModelSaveWay save_way);
 private:
     std::string _model_root_path;
 };
diff --git a/paddle/fluid/train/custom_trainer/feed/common/pipeline.h b/paddle/fluid/train/custom_trainer/feed/common/pipeline.h
new file mode 100644
index 0000000000000000000000000000000000000000..2e0f7d42f46d555e1e3903d1a482028c9bcf13d1
--- /dev/null
+++ b/paddle/fluid/train/custom_trainer/feed/common/pipeline.h
@@ -0,0 +1,131 @@
+#pragma once
+#include <thread>
+#include "paddle/fluid/framework/archive.h"
+#include "paddle/fluid/framework/channel.h"
+
+namespace paddle {
+namespace custom_trainer {
+namespace feed {
+
+class PipelineOptions {
+public:
+    PipelineOptions() = default;
+    uint32_t buffer_data_num = 400;     // number of items buffered; must be larger than batch_size
+    uint32_t batch_size = 100;          // batch size used when reading from the pipe
+    bool need_hold_input_data = false;  // whether to keep the input stream data; otherwise it is released after being consumed
+};
+
+/*
+ * Data pipeline: data flowing in can be converted to another format inside
+ * the pipeline before flowing out.
+ *
+ *    |---------------Pipeline---------------|
+ *    Channel -> Converter -> Channel
+ * Multiple pipelines can be chained with connect_to.
+ *
+ * Use initialize or connect_to to initialize a pipeline.
+ */
+template <class TypeIn, class TypeOut>
+class Pipeline {
+public:
+    Pipeline() {}
+    Pipeline(Pipeline&&) = delete;
+    Pipeline(const Pipeline&) = delete;
+    typedef std::function<int(const TypeIn*, TypeOut*, size_t)> PipeDataConverter;
+
+    int initialize(const PipelineOptions& options,
+        ::paddle::framework::Channel<TypeIn> input_channel,
+        PipeDataConverter data_converter) {
+        CHECK(_inited == false);
+        CHECK(options.batch_size > 0);
+        _inited = true;
+        _options = options;
+        _is_read_end = false;
+        _converter = data_converter;
+        _input_channel = input_channel;
+        _output_channel = ::paddle::framework::MakeChannel<TypeOut>();
+
+        auto batch_size = options.batch_size;
+        auto buffer_data_num = options.buffer_data_num;
+        _input_channel->SetBlockSize(batch_size);
+        _output_channel->SetBlockSize(batch_size);
+        if (buffer_data_num / batch_size < 3) {
+            buffer_data_num = batch_size * 3;
+        }
+        buffer_data_num = (buffer_data_num / batch_size) * batch_size;
+        _input_data_buffer.resize(buffer_data_num);
+        _output_data_buffer.resize(buffer_data_num);
+        _output_channel->SetCapacity(buffer_data_num);
+        if (_options.need_hold_input_data) {
+            _input_channel_backup = ::paddle::framework::MakeChannel<TypeIn>();
+        }
+        CHECK(_input_channel != nullptr) << " Input Channel is null";
+        _convert_thread = std::make_shared<std::thread>([this](){
+            async_convert_data();
+        });
+        return 0;
+    }
+
+    template <class PreTypeIn>
+    int connect_to(Pipeline<PreTypeIn, TypeIn>& pre_pipeline,
+        PipeDataConverter data_converter) {
+        return initialize(pre_pipeline.options(), pre_pipeline.output_channel(), data_converter);
+    }
+
+    virtual ~Pipeline() {
+        _is_read_end = true;
+        if (_convert_thread != nullptr) {
+            _convert_thread->join();
+        }
+    }
+
+    inline size_t read(std::vector<TypeOut>& p) {
+        p.clear();
+        size_t num = _output_channel->Read(p);
+        return num;
+    }
+
+    inline const PipelineOptions& options() {
+        return _options;
+    }
+
+    inline ::paddle::framework::Channel<TypeOut> output_channel() {
+        return _output_channel;
+    }
+private:
+    void async_convert_data() {
+        size_t convert_batch_size = _input_data_buffer.size() / 4;
+        if (convert_batch_size < _options.batch_size * 3) {
+            convert_batch_size = 3 * _options.batch_size;
+        }
+        convert_batch_size = (convert_batch_size / _options.batch_size) * _options.batch_size;
+        while (!_is_read_end) {
+            while (_output_channel->Size() < _input_data_buffer.size()) {
+                size_t read_size = _input_channel->
+                    Read(convert_batch_size, &_input_data_buffer[0]);
+                if (read_size == 0) {
+                    _is_read_end = true;
+                    break;
+                }
+                CHECK(_converter(&_input_data_buffer[0], &_output_data_buffer[0],
+                    read_size) == 0) << "Data Converter Do Failed";
+                _output_channel->WriteMove(read_size, &_output_data_buffer[0]);
+                if (_options.need_hold_input_data) {
+                    _input_channel_backup->WriteMove(read_size, &_input_data_buffer[0]);
+                }
+            }
+            sleep(1);
+        }
+    }
+
+private:
+    bool _inited = false;                                          // initialization flag
+    bool _is_read_end = false;                                     // set once the input stream has been fully read
+    PipelineOptions _options;                                      // pipe options
+    PipeDataConverter _converter;                                  // converter
+    std::vector<TypeIn> _input_data_buffer;                        // input data buffer
+    std::vector<TypeOut> _output_data_buffer;                      // output data buffer
+    std::shared_ptr<std::thread> _convert_thread;                  // asynchronous convert thread
+    ::paddle::framework::Channel<TypeIn> _input_channel;           // input stream
+    ::paddle::framework::Channel<TypeIn> _input_channel_backup;    // backup of the raw input stream
+    ::paddle::framework::Channel<TypeOut> _output_channel;         // output stream
+};
+
+}  // namespace feed
+}  // namespace custom_trainer
+}  // namespace paddle
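A minimal usage sketch of the Pipeline template above; the ParsedRecord type and the trivial converter are placeholders for a real parser:

    #include <string>
    #include <vector>
    #include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"

    using paddle::custom_trainer::feed::Pipeline;
    using paddle::custom_trainer::feed::PipelineOptions;

    struct ParsedRecord { size_t length = 0; };  // illustrative output type

    void pipeline_example(::paddle::framework::Channel<std::string> raw_lines) {
        PipelineOptions opts;
        opts.batch_size = 100;
        opts.buffer_data_num = 400;  // must stay >= batch_size
        Pipeline<std::string, ParsedRecord> pipe;
        pipe.initialize(opts, raw_lines,
            [](const std::string* in, ParsedRecord* out, size_t num) -> int {
                for (size_t i = 0; i < num; ++i) {
                    out[i].length = in[i].size();  // stand-in for real parsing
                }
                return 0;  // a non-zero return fails the CHECK in the convert thread
            });
        std::vector<ParsedRecord> batch;
        while (pipe.read(batch) > 0) {
            // consume one converted batch of up to batch_size records
        }
    }

Chaining works the same way: a second Pipeline whose input type matches ParsedRecord can call connect_to(pipe, converter) to reuse this pipeline's options and output channel.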
diff --git a/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc b/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
index 39c386f4266834d3820b2a30ac6558ed481fcc95..9dbc44b105de67edda6f3712695642b020ffe0bb 100644
--- a/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
+++ b/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc
@@ -4,29 +4,89 @@ namespace paddle {
 namespace custom_trainer {
 namespace feed {
-    // configuration initialization
-    int MPIRuntimeEnvironment::initialize(YAML::Node config) {
-        return 0;
+RuntimeEnvironment::RuntimeEnvironment() {}
+RuntimeEnvironment::~RuntimeEnvironment() {}
+bool RuntimeEnvironment::is_master_node(EnvironmentRole role) {
+    return rank_id(role) == 0;
+}
+std::string format_timestamp(time_t time, const char* format) {
+    std::string result;
+    struct tm p = *localtime(&time);
+    char time_str_buffer[64];
+    int size = strftime(time_str_buffer, 64, format, &p);
+    if (size > 0) {
+        result.assign(time_str_buffer, size);
     }
-    // environment initialization; called after all dependent modules have been initialized
-    int MPIRuntimeEnvironment::wireup() {
+    return result;
+}
+
+struct MpiNodeInfo {
+    int rank_id = -1;
+    int node_num = 0;
+    MPI_Comm mpi_comm;
+};
+
+class MPIRuntimeEnvironment : public RuntimeEnvironment {
+public:
+    MPIRuntimeEnvironment() {}
+    virtual ~MPIRuntimeEnvironment() {}
+    virtual int initialize(YAML::Node config) {
         return 0;
     }
-    // rank_idx of the current node
-    uint32_t MPIRuntimeEnvironment::rank_idx() {
+    virtual int wireup() {
+        int hr = MPI_Init(NULL, NULL);
+        if (MPI_SUCCESS != hr) {
+            LOG(FATAL) << "MPI_Init failed with error code: " << hr;
+            return -1;
+        }
+        _roles_node_info.resize(static_cast<int>(EnvironmentRole::ALL) + 1);
+        set_role(EnvironmentRole::ALL);
         return 0;
     }
-    void MPIRuntimeEnvironment::barrier_all() {
-        return;
+
+    virtual uint32_t rank_id(EnvironmentRole role) {
+        return mpi_node_info(role).rank_id;
+    }
+    virtual uint32_t node_num(EnvironmentRole role) {
+        return mpi_node_info(role).node_num;
     }
-    void MPIRuntimeEnvironment::print_log(EnvironmentLogType type, EnvironmentLogLevel level, const std::string& log_str) {
-        if (type == EnvironmentLogType::MASTER_LOG && !is_master_node()) {
-            return;
+    virtual int set_role(EnvironmentRole role) {
+        auto& node_info = mpi_node_info(role);
+        if (node_info.rank_id < 0) {
+            if (role == EnvironmentRole::ALL) {
+                node_info.mpi_comm = MPI_COMM_WORLD;
+            } else {
+                MPI_Comm_split(MPI_COMM_WORLD, static_cast<int>(role),
+                    mpi_node_info(EnvironmentRole::ALL).rank_id, &(node_info.mpi_comm));
+            }
+            MPI_Comm_rank(node_info.mpi_comm, &(node_info.rank_id));
+            MPI_Comm_size(node_info.mpi_comm, &(node_info.node_num));
         }
-        VLOG(2) << log_str;
-        return;
+        return 0;
     }
-    REGISTER_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment);
+
+    virtual void barrier(EnvironmentRole role) {
+        MPI_Barrier(mpi_node_info(role).mpi_comm);
+    }
+    virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) {
+        auto& node_info = mpi_node_info(role);
+        int len = (int)ar.length();
+        MPI_Bcast(&len, 1, MPI_INT, root_id, node_info.mpi_comm);
+        ar.resize(len);
+        ar.set_cursor(ar.buffer());
+        MPI_Bcast(ar.buffer(), len, MPI_BYTE, root_id, node_info.mpi_comm);
+    }
+protected:
+    virtual void print_log(EnvironmentLogType type, EnvironmentLogLevel level, const std::string& log_str) {
+        if (type == EnvironmentLogType::MASTER_LOG && !is_master_node(EnvironmentRole::ALL)) {
+            return;
+        }
+        VLOG(2) << log_str;
+    }
+    inline MpiNodeInfo& mpi_node_info(EnvironmentRole role) {
+        return _roles_node_info[static_cast<int>(role)];
+    }
+private:
+    std::vector<MpiNodeInfo> _roles_node_info;
+};
+
+REGISTER_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment);
 
 }  // namespace feed
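The intended call order for the role-aware environment, sketched below; the registry name string and the YAML node are assumptions, and the real construction site lives elsewhere in the trainer:

    std::shared_ptr<RuntimeEnvironment> env(
        CREATE_CLASS(RuntimeEnvironment, "MPIRuntimeEnvironment"));  // assumed registry key
    env->initialize(config);                 // config: YAML::Node
    env->wireup();                           // MPI_Init + communicator/rank for EnvironmentRole::ALL
    env->set_role(EnvironmentRole::WORKER);  // MPI_Comm_split into a per-role communicator
    if (env->is_master_node(EnvironmentRole::WORKER)) {
        // rank 0 of the WORKER communicator, e.g. the node that lists data files
    }
    env->barrier(EnvironmentRole::WORKER);   // synchronize all WORKER nodes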
diff --git a/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h b/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
index 2654b721b77a50561898f9be366ea8a62280e9bd..49fc4629eafd4139b2ff0cc453956e7dbe7b8b12 100644
--- a/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
+++ b/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h
@@ -6,6 +6,7 @@
  */
 #pragma once
 #include
+#include "paddle/fluid/framework/archive.h"
 #include "paddle/fluid/string/string_helper.h"
 #include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
@@ -21,26 +22,37 @@ enum class EnvironmentLogLevel {
 };
 
 enum class EnvironmentLogType {
-    MASTER_LOG = 0, // printed only by the master node
-    ALL_LOG = 1     // printed by every node
+    MASTER_LOG = 0,     // printed only by the master node
+    ALL_LOG = 1         // printed by every node
+};
+
+// keep these enum values contiguous and increasing, with ALL at the end
+enum class EnvironmentRole {
+    WORKER = 0,         // training worker
+    PSERVER = 1,        // parameter server
+
+    ALL = 2             // all roles; keep this at the end of the enum
 };
 
 class RuntimeEnvironment {
 public:
-    RuntimeEnvironment() {}
-    virtual ~RuntimeEnvironment() {}
+    RuntimeEnvironment();
+    virtual ~RuntimeEnvironment();
     // configuration initialization
     virtual int initialize(YAML::Node config) = 0;
+    // set the role of this node
+    virtual int set_role(EnvironmentRole role) = 0;
    // environment initialization; called after all dependent modules have been initialized
     virtual int wireup() = 0;
 
     // interfaces callable from multiple threads: Start
     // rank id of the current node
-    virtual uint32_t rank_idx() = 0;
+    virtual uint32_t rank_id(EnvironmentRole role) = 0;
+    // number of nodes in the runtime environment
+    virtual uint32_t node_num(EnvironmentRole role) = 0;
     // whether this node is the master of the environment
-    virtual bool is_master_node() {
-        return rank_idx() == 0;
-    }
+    virtual bool is_master_node(EnvironmentRole role);
+
     // environment-specific logging
     template <class... ARGS>
     void log(EnvironmentLogType type, EnvironmentLogLevel level,
@@ -51,29 +63,22 @@ public:
     // interfaces that may only be called from the main thread: Start
-    // barrier
-    virtual void barrier_all() = 0;
+    // barrier across the nodes of the given role
+    virtual void barrier(EnvironmentRole role) = 0;
+    // broadcast an archive from root_id to the nodes of the given role
+    virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) = 0;
     // interfaces that may only be called from the main thread: End
 protected:
     virtual void print_log(EnvironmentLogType type, EnvironmentLogLevel level, const std::string& log_str) = 0;
 };
 REGISTER_REGISTERER(RuntimeEnvironment);
-class MPIRuntimeEnvironment : public RuntimeEnvironment {
-public:
-    MPIRuntimeEnvironment() {}
-    virtual ~MPIRuntimeEnvironment() {}
-    // configuration initialization
-    virtual int initialize(YAML::Node config);
-    // environment initialization; called after all dependent modules have been initialized
-    virtual int wireup();
-    // rank_idx of the current node
-    virtual uint32_t rank_idx();
-    virtual void barrier_all();
-protected:
-    virtual void print_log(EnvironmentLogType type, EnvironmentLogLevel level, const std::string& log_str);
-};
+
+std::string format_timestamp(time_t time, const char* format);
+inline std::string format_timestamp(time_t time, const std::string& format) {
+    return format_timestamp(time, format.c_str());
+}
 
 }  // namespace feed
 }  // namespace custom_trainer
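format_timestamp above is a thin localtime/strftime wrapper; DatasetContainer later combines it with the configured data_path_formater to build per-split directories. A usage sketch (the format string is an example, not a value from the patch):

    time_t ts = 1564617600;                                  // 2019-08-01 00:00:00 UTC
    std::string suffix = format_timestamp(ts, "%Y%m%d/%H");  // "20190801/00" on a UTC machine; localtime() makes the result timezone-dependent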
diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
index a10275dd04148f00a1319160e5a773b952ceb4a7..d2c2c9b4ec3333148bea975023d610b3e06c0bf9 100644
--- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
+++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h
@@ -8,6 +8,7 @@
 #include
 #include
 #include "paddle/fluid/framework/channel.h"
+#include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"
 #include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
 
 namespace paddle {
@@ -36,6 +37,11 @@ public:
     std::string data;  // sample data, possibly in a compressed format
 };
 
+typedef std::shared_ptr<Pipeline<DataItem, SampleInstance>> SampleInstancePipe;
+inline SampleInstancePipe make_sample_instance_channel() {
+    return std::make_shared<Pipeline<DataItem, SampleInstance>>();
+}
+
 class DataParser {
 public:
     DataParser() {}
@@ -56,8 +62,12 @@ public:
     virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
     // check whether the sample data is ready; ready means the download can start
     virtual bool is_data_ready(const std::string& data_dir) = 0;
-    // read the data into the sample stream
-    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;
+    // list the data files under a directory
+    virtual std::vector<std::string> data_file_list(const std::string& data_dir);
+    // read the data under a directory into the sample stream
+    virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem>& data_channel) = 0;
+    // read the data of the given file list into the sample stream
+    virtual int read_all(const std::vector<std::string>& data_list, ::paddle::framework::Channel<DataItem>& data_channel) = 0;
     virtual const DataParser* get_parser() {
         return _parser.get();
     }
diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc b/paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc
new file mode 100644
index 0000000000000000000000000000000000000000..125b412def25e84323f60b1d41daa5279804d78f
--- /dev/null
+++ b/paddle/fluid/train/custom_trainer/feed/dataset/dataset.cc
@@ -0,0 +1,66 @@
+#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
+
+namespace paddle {
+namespace custom_trainer {
+namespace feed {
+
+int Dataset::initialize(
+    const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
+    if (!config["data_list"]) {
+        VLOG(0) << "missing data_list config in dataset, please check";
+        return -1;
+    }
+    int data_num = config["data_list"].size();
+    for (int i = 0; i < data_num; ++i) {
+        std::string name = config["data_list"][i]["name"].as<std::string>();
+        auto data_ptr = std::make_shared<DatasetContainer>();
+        if (data_ptr->initialize(config["data_list"][i], context) != 0) {
+            VLOG(0) << "dataset initialize failed, name:" << name;
+            return -1;
+        }
+        _data_containers[name] = data_ptr;
+    }
+    return 0;
+}
+
+void Dataset::pre_detect_data(
+    const std::string& data_name, uint64_t epoch_id) {
+    _data_containers[data_name]->pre_detect_data(epoch_id);
+    return;
+}
+
+DatasetStatus Dataset::epoch_data_status(
+    const std::string& data_name, uint64_t epoch_id) {
+    return _data_containers[data_name]->epoch_data_status(epoch_id);
+}
+
+::paddle::framework::Channel<DataItem> Dataset::fetch_data(
+    const std::string& data_name, uint64_t epoch_id) {
+    return _data_containers[data_name]->fetch(epoch_id);
+}
+
+SampleInstancePipe Dataset::fetch_sample(
+    const std::string& data_name, uint32_t batch_size, uint64_t epoch_id) {
+    auto* data_container = _data_containers[data_name].get();
+    auto data_channel = data_container->fetch(epoch_id);
+    const auto* data_parser = data_container->data_parser();
+    PipelineOptions options;
+    options.batch_size = batch_size;
+    options.need_hold_input_data = true;
+    options.buffer_data_num = batch_size * 10;
+    SampleInstancePipe pipe = make_sample_instance_channel();
+    pipe->initialize(options, data_channel,
+        [data_parser] (const DataItem* data, SampleInstance* sample, size_t num) -> int {
+            int ret = 0;
+            for (size_t i = 0; i < num; ++i, ++data, ++sample) {
+                ret |= data_parser->parse_to_sample(*data, *sample);
+            }
+            return ret;
+        });
+    return pipe;
+}
+
+}  // namespace feed
+}  // namespace custom_trainer
+}  // namespace paddle
diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/dataset.h b/paddle/fluid/train/custom_trainer/feed/dataset/dataset.h
new file mode 100644
index 0000000000000000000000000000000000000000..248aa9477dd53ce57ad0cf764aabcde1ca63896e
--- /dev/null
+++ b/paddle/fluid/train/custom_trainer/feed/dataset/dataset.h
@@ -0,0 +1,44 @@
+#pragma once
+#include
+#include
+#include
+#include
+#include
+#include "paddle/fluid/framework/channel.h"
+#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
+#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
+#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"
+
+namespace paddle {
+namespace custom_trainer {
+namespace feed {
+
+class Dataset {
+public:
+    Dataset() {}
+    virtual ~Dataset() {}
+
+    virtual int initialize(
+        const YAML::Node& config, std::shared_ptr<TrainerContext> context);
+
+    // trigger detection of data that can be prefetched
+    virtual void pre_detect_data(const std::string& data_name, uint64_t epoch_id);
+
+    // get the data status of an epoch
+    virtual DatasetStatus epoch_data_status(const std::string& data_name, uint64_t epoch_id);
+
+    // return the raw data held by the DataContainer (possibly compressed)
+    virtual ::paddle::framework::Channel<DataItem> fetch_data(
+        const std::string& data_name, uint64_t epoch_id);
+
+    // return the standard sample stream as a pipeline; the pipeline converts the data asynchronously
+    virtual SampleInstancePipe fetch_sample(
+        const std::string& data_name, uint32_t batch_size, uint64_t epoch_id);
+
+private:
+    std::unordered_map<std::string, std::shared_ptr<DatasetContainer>> _data_containers;
+};
+
+}  // namespace feed
+}  // namespace custom_trainer
+}  // namespace paddle
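A sketch of how a consumer might drain the SampleInstancePipe returned by Dataset::fetch_sample; the data name and batch size are made up for the example:

    auto pipe = dataset->fetch_sample("train_data", 32, epoch_id);  // "train_data" must match a name under data_list in the config
    std::vector<SampleInstance> samples;
    while (pipe->read(samples) > 0) {
        // hand the converted batch to the executor; DataItem -> SampleInstance parsing runs on the pipeline's thread
    }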
diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc b/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
index ae0e65496753a8fef582893c252521010b403c1b..7b19e1cb0576371c770edfc0fd1435de81bee92e 100644
--- a/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
+++ b/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc
@@ -8,31 +8,148 @@
 #include
 #include "paddle/fluid/framework/io/shell.h"
 #include "paddle/fluid/string/string_helper.h"
+#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
+#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
 #include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"
 
 namespace paddle {
 namespace custom_trainer {
 namespace feed {
+
+int DatasetContainer::initialize(
+    const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
+    _dataset_config = config;
+    _trainer_context = context.get();
+    // prefetch n rounds of sample data
+    _prefetch_num = config["prefetch_num"].as<int>();
+    _dataset_list.resize(_prefetch_num);
+    for (auto& dataset_info : _dataset_list) {
+        dataset_info = std::make_shared<DatasetInfo>();
+    }
+
+    _data_root_paths = paddle::string::split_string(
+        config["root_path"].as<std::string>(), " ");
+    _data_split_interval = config["data_split_interval"].as<int>();
+    _data_path_formater = config["data_path_formater"].as<std::string>();
+    std::string data_reader_class = config["data_reader"].as<std::string>();
+    DataReader* data_reader = CREATE_CLASS(DataReader, data_reader_class);
+    _data_reader.reset(data_reader);
+    return _data_reader->initialize(config, context);
+}
+
+std::shared_ptr<DatasetInfo> DatasetContainer::dataset(uint64_t timestamp) {
+    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
+    auto data_idx = timestamp / epoch_accessor->epoch_time_interval();
+    return _dataset_list[data_idx % _prefetch_num];
+}
+
+void DatasetContainer::pre_detect_data(uint64_t epoch_id) {
+    int status = 0;
+    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
+    time_t timestamp = epoch_accessor->epoch_timestamp(epoch_id);
+    if (timestamp % epoch_accessor->epoch_time_interval() != 0) {
+        LOG(FATAL) << "timestamp:" << timestamp << " doesn't match interval:" << epoch_accessor->epoch_time_interval();
+        return;
+    }
+    size_t data_num = data_num_for_train(timestamp, epoch_accessor->epoch_time_interval(), _data_split_interval);
+    uint64_t data_timestamp = timestamp % _data_split_interval == 0 ? timestamp : (timestamp / _data_split_interval + 1) * _data_split_interval;
+    std::vector<std::string> data_path_list;
+    for (int i = 0; i < _data_root_paths.size() && status == 0; ++i) {
+        for (int j = 0; j < data_num && status == 0; ++j) {
+            std::string path_suffix = format_timestamp(data_timestamp + j * _data_split_interval, _data_path_formater);
+            std::string data_dir = _data_root_paths[i] + "/" + path_suffix;
+            status = read_data_list(data_dir, data_path_list);
+        }
+    }
+    if (status == 0) {
+        auto dataset_info = dataset(timestamp);
+        dataset_info->timestamp = timestamp;
+        dataset_info->file_path_list = std::move(data_path_list);
+        dataset_info->status = DatasetStatus::Detected;
+    }
+    return;
+}
+
+int DatasetContainer::read_data_list(const std::string& data_dir, std::vector<std::string>& data_list) {
+    auto* environment = _trainer_context->environment.get();
+
+    // check whether the data is ready; only the master worker probes the storage
+    int data_status = -1;
+    if (environment->is_master_node(EnvironmentRole::WORKER)) {
+        if (_data_reader->is_data_ready(data_dir)) {
+            data_status = 0;
+        }
+    }
+    paddle::framework::BinaryArchive ar;
+    ar << data_status;
+    environment->bcast(ar, 0, EnvironmentRole::WORKER);
+    ar >> data_status;
+    if (data_status != 0) {
+        return -1;
+    }
+
+    // read the file list on the master worker and broadcast it
+    ar.Clear();
+    std::vector<std::string> data_path_list;
+    if (environment->is_master_node(EnvironmentRole::WORKER)) {
+        data_path_list = _data_reader->data_file_list(data_dir);
+        ar << data_path_list;
+    }
+    environment->bcast(ar, 0, EnvironmentRole::WORKER);
+    ar >> data_path_list;
+    auto worker_id = environment->rank_id(EnvironmentRole::WORKER);
+    auto worker_num = environment->node_num(EnvironmentRole::WORKER);
+    for (int i = worker_id; i < data_path_list.size(); i += worker_num) {
+        data_list.push_back(data_path_list[i]);
+    }
+    environment->barrier(EnvironmentRole::WORKER);
+    return 0;
+}
+
+DatasetStatus DatasetContainer::epoch_data_status(uint64_t epoch_id) {
+    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
+    time_t timestamp = epoch_accessor->epoch_timestamp(epoch_id);
+    return data_status(timestamp);
+}
+
+DatasetStatus DatasetContainer::data_status(uint64_t timestamp) {
+    auto dataset_info = dataset(timestamp);
+    if (dataset_info->timestamp != timestamp) {
+        return DatasetStatus::Empty;
+    }
+    return dataset_info->status;
+}
 
-paddle::framework::Channel DatasetContainer::fetch(int epoch_id) {
+paddle::framework::Channel<DataItem> DatasetContainer::fetch(uint64_t epoch_id) {
     paddle::framework::Channel<DataItem> result;
-    if (_ready_epoch_id < epoch_id) {
+    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
+    time_t timestamp = epoch_accessor->epoch_timestamp(epoch_id);
+    if (data_status(timestamp) != DatasetStatus::Ready) {
         return result;
     }
-    _current_epoch_id = epoch_id;
-    _current_dataset_idx = epoch_id % _prefetch_num;
-    //result = _dataset_list[_current_dataset_idx].fetch();
-    //_dataset_list[_current_dataset_idx].reset((decltype(result.get())*)NULL);
-    return result;
+    auto dataset_info = dataset(timestamp);
+    return dataset_info->data_channel;
 }
 
-void DatasetContainer::async_download_data() {
+void DatasetContainer::async_download_data(uint64_t start_timestamp) {
+    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
+    if (start_timestamp % epoch_accessor->epoch_time_interval() != 0) {
+        LOG(FATAL) << "timestamp:" << start_timestamp << " doesn't match interval:" << epoch_accessor->epoch_time_interval();
+        return;
+    }
     while (true) {
-        //do download
-        sleep(30);
+        auto dataset_info = dataset(start_timestamp);
+        while (data_status(start_timestamp) != DatasetStatus::Detected) {
+            sleep(30);
+        }
+        const auto& file_list = dataset_info->file_path_list;
+        dataset_info->data_channel->Clear();
+        while (_data_reader->read_all(file_list, dataset_info->data_channel) != 0) {
+            dataset_info->data_channel->Clear();
+            VLOG(0) << "timestamp:" << start_timestamp << " data read failed, retry";
+            sleep(30);
+        }
+        dataset_info->status = DatasetStatus::Ready;
+        start_timestamp += epoch_accessor->epoch_time_interval();
     }
 }
 
-}//namespace feed
-}//namespace custom_trainer
-}//namespace paddle
+}  // namespace feed
+}  // namespace custom_trainer
+}  // namespace paddle
diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h b/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h
index a7404a82991f427dd2a53d119a7b8956a8d5aec4..c0a67f8f407b8bedf071a2a102baf97ad3afc8aa 100644
--- a/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h
+++ b/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h
@@ -16,38 +16,61 @@ namespace paddle {
 namespace custom_trainer {
 namespace feed {
 
+inline int data_num_for_train(uint64_t train_begin_timestamp, uint32_t train_time_interval, uint32_t data_time_interval) {
+    uint64_t data_begin_time = train_begin_timestamp;
+    uint64_t data_end_time = data_begin_time + train_time_interval;
+    uint64_t end_idx = (data_end_time + data_time_interval - 1) / data_time_interval;
+    uint64_t begin_idx = (data_begin_time + data_time_interval - 1) / data_time_interval;
+    return end_idx - begin_idx;
+}
+
+enum class DatasetStatus {
+    Empty = 0,
+    Detected = 1,
+    Downloading = 2,
+    Ready = 3
+};
+struct DatasetInfo {
+    uint64_t timestamp = 0;
+    std::vector<std::string> file_path_list;
+    DatasetStatus status = DatasetStatus::Empty;
+    ::paddle::framework::Channel<DataItem> data_channel = ::paddle::framework::MakeChannel<DataItem>();
+};
+
 class DatasetContainer {
 public:
     DatasetContainer() {}
     virtual ~DatasetContainer() {}
-    virtual int initialize(const YAML::Node& config) {
-        _dataset_config = config;
-        // prefetch n rounds of sample data
-        _prefetch_num = config["prefetch_num"].as<int>();
-        _data_root_path = config["root_path"].as<std::string>();
-        _data_path_generater = config["_data_path_generater"].as<std::string>();
-        return 0;
-    }
+    virtual int initialize(
+        const YAML::Node& config, std::shared_ptr<TrainerContext> context);
     virtual void run();
-    // fetch the samples of a given epoch; if the data is not ready the Channel holds a null pointer
-    virtual ::paddle::framework::Channel<DataItem> fetch(int epoch_id);
     // trigger detection of data that can be prefetched
-    virtual void pre_detect_data(RuntimeEnvironment* env);
-
+    virtual void pre_detect_data(uint64_t epoch_id);
+    // get the data status of an epoch
+    virtual DatasetStatus epoch_data_status(uint64_t epoch_id);
+    // fetch the samples of a given epoch; if the data is not ready the Channel holds a null pointer
+    virtual ::paddle::framework::Channel<DataItem> fetch(uint64_t epoch_id);
+    // get the DataItem parser
+    virtual const DataParser* data_parser() {
+        return _data_reader->get_parser();
+    }
 protected:
+    virtual DatasetStatus data_status(uint64_t timestamp);
+    virtual int read_data_list(const std::string& data_dir, std::vector<std::string>& data_list);
     // asynchronous sample download
-    virtual void async_download_data();
-    virtual void download(int epoch_id, const std::vector<std::string>& paths);
+    virtual void async_download_data(uint64_t start_timestamp);
+    virtual std::shared_ptr<DatasetInfo> dataset(uint64_t timestamp);
 
-    int _prefetch_num = 0;
+    int _prefetch_num = 0;
+    int _data_split_interval = 60;                  // sample split interval (seconds)
     YAML::Node _dataset_config;
-    std::string _data_root_path;
-    std::string _data_path_generater;
+    std::string _data_path_formater;
+    std::vector<std::string> _data_root_paths;      // multiple root directories may be read at the same time
 
-    uint32_t _current_dataset_idx;  // index of the current sample data
-    int _current_epoch_id = -1;
-    int _ready_epoch_id = -1;       // epoch_id whose download has completed
-    std::vector> _dataset_list;     // list of prefetched data
+    TrainerContext* _trainer_context;
+    std::shared_ptr<DataReader> _data_reader;
+    std::shared_ptr<std::thread> _downloader_thread;
+    std::vector<std::shared_ptr<DatasetInfo>> _dataset_list;  // list of prefetched data
 };
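A worked example of data_num_for_train above, assuming hourly epochs (epoch_time_interval() == 3600) and a 5-minute data split (_data_split_interval == 300), with t aligned to the hour so that t % 300 == 0:

    begin_idx = (t + 299) / 300        = t / 300
    end_idx   = (t + 3600 + 299) / 300 = t / 300 + 12
    data_num_for_train(t, 3600, 300)   = 12

pre_detect_data therefore probes the twelve split directories for t, t + 300, ..., t + 3300 under every configured root path.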
diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
index d8eba032162414e4da8513c9f00b4e6e491a7511..28b44cf1b3b6b01625ab41e8ee54ba22d9376a86 100644
--- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
+++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc
@@ -35,7 +35,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
     return 0;
 }
 
-std::future<int> LearnerProcess::save_model(int epoch_id, int table_id, ModelSaveWay way) {
+std::future<int> LearnerProcess::save_model(uint64_t epoch_id, int table_id, ModelSaveWay way) {
     std::promise<int> p;
     auto ret = p.get_future();
     if (_context_ptr->epoch_accessor->need_save_model(epoch_id, way)) {
@@ -47,7 +47,7 @@ std::future<int> LearnerProcess::save_model(int epoch_id, int table_id, ModelSav
     return ret;
 }
 
-int LearnerProcess::wait_save_model(int epoch_id, ModelSaveWay way) {
+int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
     auto* environment = _context_ptr->environment.get();
     if (!environment->is_master_node()) {
         return 0;
@@ -71,7 +71,7 @@ int LearnerProcess::wait_save_model(int epoch_id, ModelSaveWay way) {
 int LearnerProcess::run() {
     auto* environment = _context_ptr->environment.get();
     auto* epoch_accessor = _context_ptr->epoch_accessor.get();
-    int epoch_id = epoch_accessor->current_epoch_id();
+    uint64_t epoch_id = epoch_accessor->current_epoch_id();
     environment->log(EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
         "Resume training with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
 
diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.h b/paddle/fluid/train/custom_trainer/feed/process/learner_process.h
index 203e33790cec622ebca3109094f61bce6c34f6e8..7addb601e9cc2cee07194ae262fe622c00aab4bf 100644
--- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.h
+++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.h
@@ -21,9 +21,9 @@ public:
 
 protected:
 // synchronously save all models
-virtual int wait_save_model(int epoch_id, ModelSaveWay way);
+virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way);
 // asynchronously save the specified model
-virtual std::future<int> save_model(int epoch_id, int table_id, ModelSaveWay way);
+virtual std::future<int> save_model(uint64_t epoch_id, int table_id, ModelSaveWay way);
 
 // run the specified training network
 virtual int run_executor(Executor* executor);