提交 c1d64d7e 编写于 作者: X xiexionghang

for dataset pipeline

上级 8699c196
......@@ -11,25 +11,31 @@ namespace feed {
// Advance the accessor to the epoch following the current one.
// Delegates the actual id arithmetic to next_epoch_id().
void HourlyEpochAccessor::next_epoch() {
    _current_epoch_id = next_epoch_id(_current_epoch_id);
}
// Human-readable label for an epoch (the raw id as a decimal string).
std::string HourlyEpochAccessor::text(uint64_t epoch_id) {
    return std::to_string(epoch_id);
}
// Whether sample data for the epoch is ready.
// Hourly epochs are always considered ready; readiness is actually
// checked at the DataReader level.
bool HourlyEpochAccessor::data_ready(uint64_t epoch_id) {
    return true;
}
int HourlyEpochAccessor::next_epoch_id(int epoch_id) {
if (epoch_id <= 0) {
int HourlyEpochAccessor::next_epoch_id(uint64_t epoch_id) {
if (epoch_id == 0) {
struct timeval now;
gettimeofday(&now, NULL);
return now.tv_sec / (24 * 3600) * (24 * 3600);
}
return epoch_id + 3600;
}
// True for the last hourly epoch of a day (hour-of-day == 23).
bool HourlyEpochAccessor::is_last_epoch(uint64_t epoch_id) {
    return ((epoch_id / 3600) % 24) == 23;
}
// Data time span covered by one epoch, in seconds (one hour).
uint64_t HourlyEpochAccessor::epoch_time_interval() {
    return 3600;
}
// Sample-data timestamp of an epoch. Hourly epoch ids ARE timestamps,
// so this is the identity mapping.
uint64_t HourlyEpochAccessor::epoch_timestamp(uint64_t epoch_id) {
    return epoch_id;
}
bool HourlyEpochAccessor::need_save_model(uint64_t epoch_id, ModelSaveWay save_way) {
if (epoch_id == 0) {
return false;
}
if (save_way == ModelSaveWay::ModelSaveInferenceDelta) {
......@@ -41,7 +47,7 @@ namespace feed {
}
return false;
}
std::string HourlyEpochAccessor::model_save_path(int epoch_id, ModelSaveWay save_way) {
std::string HourlyEpochAccessor::model_save_path(uint64_t epoch_id, ModelSaveWay save_way) {
if (save_way == ModelSaveWay::ModelSaveInferenceDelta) {
return _model_root_path + "/xbox/delta-" + std::to_string(epoch_id);
} else if (save_way == ModelSaveWay::ModelSaveInferenceBase) {
......
......@@ -13,18 +13,22 @@ public:
virtual int initialize(YAML::Node config,
std::shared_ptr<TrainerContext> context_ptr) = 0;
virtual int current_epoch_id() {
virtual uint64_t current_epoch_id() {
return _current_epoch_id;
}
virtual void next_epoch() = 0;
virtual std::string text(int epoch_id) = 0;
virtual bool data_ready(int epoch_id) = 0;
virtual int next_epoch_id(int epoch_id) = 0;
virtual bool is_last_epoch(int epoch_id) = 0;
virtual bool need_save_model(int epoch_id, ModelSaveWay save_way) = 0;
virtual std::string model_save_path(int epoch_id, ModelSaveWay save_way) = 0;
virtual std::string text(uint64_t epoch_id) = 0;
virtual bool data_ready(uint64_t epoch_id) = 0;
virtual int next_epoch_id(uint64_t epoch_id) = 0;
virtual bool is_last_epoch(uint64_t epoch_id) = 0;
//epoch间的数据时间间隔(秒)
virtual uint64_t epoch_time_interval() = 0;
//获取epoch的样本数据时间
virtual uint64_t epoch_timestamp(uint64_t epoch_id) = 0;
virtual bool need_save_model(uint64_t epoch_id, ModelSaveWay save_way) = 0;
virtual std::string model_save_path(uint64_t epoch_id, ModelSaveWay save_way) = 0;
protected:
    // current epoch id; uint64_t so it can carry unix timestamps
    uint64_t _current_epoch_id;
};
REGISTER_REGISTERER(EpochAccessor);
......@@ -35,12 +39,14 @@ public:
virtual int initialize(YAML::Node config,
std::shared_ptr<TrainerContext> context_ptr);
virtual void next_epoch();
virtual std::string text(int epoch_id);
virtual bool data_ready(int epoch_id);
virtual int next_epoch_id(int epoch_id);
virtual bool is_last_epoch(int epoch_id);
virtual bool need_save_model(int epoch_id, ModelSaveWay save_way);
virtual std::string model_save_path(int epoch_id, ModelSaveWay save_way);
virtual std::string text(uint64_t epoch_id);
virtual bool data_ready(uint64_t epoch_id);
virtual int next_epoch_id(uint64_t epoch_id);
virtual bool is_last_epoch(uint64_t epoch_id);
virtual uint64_t epoch_time_interval();
virtual uint64_t epoch_timestamp(uint64_t epoch_id);
virtual bool need_save_model(uint64_t epoch_id, ModelSaveWay save_way);
virtual std::string model_save_path(uint64_t epoch_id, ModelSaveWay save_way);
private:
std::string _model_root_path;
};
......
#pragma once
#include <unistd.h>
#include <atomic>
#include <functional>
#include <memory>
#include <thread>
#include <vector>
#include "paddle/fluid/framework/archive.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
/*
 * Tunables for a Pipeline instance.
 */
class PipelineOptions {
public:
    PipelineOptions() = default;
    uint32_t buffer_data_num = 400;    // number of buffered items; must exceed batch_size
    uint32_t batch_size = 100;         // batch size used when reading from the pipe
    bool need_hold_input_data = false; // keep the consumed input data; otherwise release it
};
/*
* 数据流管道,管道内可对流入数据进行格式转换,再流出
*
* |---------------Pipeline---------------|
* Channel<IN> -> Converter -> Channel<OUT>
* 多个管道可通过connect_to方法进行级联
*
* 使用initialize 或 connect_to 初始化管道
*/
template <class TypeIn, class TypeOut>
class Pipeline {
public:
    Pipeline() {}
    Pipeline(Pipeline&&) = delete;
    Pipeline(const Pipeline&) = delete;
    typedef std::function<int(const TypeIn*, TypeOut*, size_t num)> PipeDataConverter;

    // Bind the pipeline to an input channel and start the background
    // conversion thread. Returns 0 on success.
    int initialize(const PipelineOptions& options,
        ::paddle::framework::Channel<TypeIn> input_channel,
        PipeDataConverter data_converter) {
        CHECK(_inited == false);
        CHECK(options.batch_size > 0);
        // fix: validate the input channel BEFORE dereferencing it below
        CHECK(input_channel != nullptr) << " Input Channel is null";
        _inited = true;
        _options = options;
        _is_read_end = false;
        _converter = data_converter;
        _input_channel = input_channel;
        _output_channel = ::paddle::framework::MakeChannel<TypeOut>();
        auto batch_size = options.batch_size;
        auto buffer_data_num = options.buffer_data_num;
        // fix: compute the effective buffer size BEFORE sizing the buffers.
        // Previously the buffers were resized to the raw options value while
        // async_convert_data() reads up to max(size/4, 3*batch_size) items,
        // which could overrun the buffers for small buffer_data_num.
        if (buffer_data_num / batch_size < 3) {
            buffer_data_num = batch_size * 3;
        }
        buffer_data_num = (buffer_data_num / batch_size) * batch_size;
        _input_channel->SetBlockSize(batch_size);
        _output_channel->SetBlockSize(batch_size);
        _output_channel->SetCapacity(buffer_data_num);
        _input_data_buffer.resize(buffer_data_num);
        _output_data_buffer.resize(buffer_data_num);
        // fix: the backup channel was never created, so the WriteMove in
        // async_convert_data() dereferenced a null channel whenever
        // need_hold_input_data was set (as Dataset::fetch_sample does).
        if (_options.need_hold_input_data) {
            _input_channel_backup = ::paddle::framework::MakeChannel<TypeIn>();
            _input_channel_backup->SetBlockSize(batch_size);
        }
        _convert_thread = std::make_shared<std::thread>([this](){
            async_convert_data();
        });
        return 0;
    }

    // Chain this pipeline after pre_pipeline: its output becomes our input.
    template <class PreTypeIn>
    int connect_to(Pipeline<PreTypeIn, TypeIn>& pre_pipeline,
        PipeDataConverter data_converter) {
        return initialize(pre_pipeline.options(), pre_pipeline.output_chnnel(), data_converter);
    }

    virtual ~Pipeline() {
        _is_read_end = true;
        if (_convert_thread != nullptr) {
            _convert_thread->join();
        }
    }

    // Pull converted items from the output channel into p.
    // Returns the number of items read (0 once the pipeline is drained).
    inline size_t read(std::vector<TypeOut>& p) {
        p.clear();
        size_t num = _output_channel->Read(p);
        return num;
    }
    inline const PipelineOptions& options() {
        return _options;
    }
    inline ::paddle::framework::Channel<TypeOut> output_chnnel() {
        return _output_channel;
    }
private:
    // Background loop: read input batches, convert, push to the output
    // channel (and optionally back up the raw input).
    void async_convert_data() {
        size_t convert_batch_size = _input_data_buffer.size() / 4;
        if (convert_batch_size < _options.batch_size * 3) {
            convert_batch_size = 3 * _options.batch_size;
        }
        // keep it a whole multiple of batch_size
        convert_batch_size = (convert_batch_size / _options.batch_size) * _options.batch_size;
        while (!_is_read_end) {
            while (_output_channel->Size() < _input_data_buffer.size()) {
                size_t read_size = _input_channel->
                    Read(convert_batch_size, &_input_data_buffer[0]);
                if (read_size == 0) {
                    _is_read_end = true;
                    break;
                }
                CHECK(_converter(&_input_data_buffer[0], &_output_data_buffer[0],
                    read_size) == 0) << "Data Converter Do Failed";
                _output_channel->WriteMove(read_size, &_output_data_buffer[0]);
                if (_options.need_hold_input_data) {
                    _input_channel_backup->WriteMove(read_size, &_input_data_buffer[0]);
                }
            }
            sleep(1);
        }
        // fix: close the channels once the input is drained, so consumers
        // blocked in read() wake up instead of waiting forever
        _output_channel->Close();
        if (_input_channel_backup != nullptr) {
            _input_channel_backup->Close();
        }
    }
private:
    bool _inited = false;                     // initialization flag
    std::atomic<bool> _is_read_end{false};    // input fully consumed (fix: shared between converter thread and dtor, so atomic)
    PipelineOptions _options;                 // pipe options
    PipeDataConverter _converter;             // converter callback
    std::vector<TypeIn> _input_data_buffer;   // input staging buffer
    std::vector<TypeOut> _output_data_buffer; // output staging buffer
    std::shared_ptr<std::thread> _convert_thread; // async converter thread
    ::paddle::framework::Channel<TypeIn> _input_channel;        // input stream
    ::paddle::framework::Channel<TypeIn> _input_channel_backup; // backup of the raw input stream
    ::paddle::framework::Channel<TypeOut> _output_channel;      // output stream
};
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
......@@ -4,29 +4,89 @@ namespace paddle {
namespace custom_trainer {
namespace feed {
//配置初始化
int MPIRuntimeEnvironment::initialize(YAML::Node config) {
return 0;
RuntimeEnvironment::RuntimeEnvironment() {}
RuntimeEnvironment::~RuntimeEnvironment() {}
bool RuntimeEnvironment::is_master_node(EnvironmentRole role) {
return rank_id(role) == 0;
}
// Format a unix timestamp with a strftime() pattern in local time.
// Returns an empty string if strftime produces no output (e.g. empty
// format, or the result would exceed the 64-byte buffer).
std::string format_timestamp(time_t time, const char* format) {
    std::string result;
    struct tm p = *localtime(&time);
    char time_str_buffer[64];
    size_t size = strftime(time_str_buffer, 64, format, &p);
    if (size > 0) {
        result.assign(time_str_buffer, size);
    }
    return result;
}
// Per-role MPI bookkeeping: this node's rank and the group size within
// the communicator that was split off for the role.
struct MpiNodeInfo {
    int rank_id = -1;   // -1 means the role has not been set up yet
    int node_num = 0;
    MPI_Comm mpi_comm;
};
class MPIRuntimeEnvironment : public RuntimeEnvironment {
public:
MPIRuntimeEnvironment() {}
virtual ~MPIRuntimeEnvironment() {}
virtual int initialize(YAML::Node config) {
return 0;
}
//当前环境rank_idx
uint32_t MPIRuntimeEnvironment::rank_idx() {
virtual int wireup() {
int hr = MPI_Init(NULL, NULL);
if (MPI_SUCCESS != hr) {
LOG(FATAL) << "MPI_init failed with error code" << hr;
return -1;
}
_roles_node_info.resize(static_cast<int>(EnvironmentRole::ALL) + 1);
set_role(EnvironmentRole::ALL);
return 0;
}
void MPIRuntimeEnvironment::barrier_all() {
return;
virtual uint32_t rank_id(EnvironmentRole role) {
return mpi_node_info(role).rank_id;
}
virtual uint32_t node_num(EnvironmentRole role) {
return mpi_node_info(role).node_num;
}
void MPIRuntimeEnvironment::print_log(EnvironmentLogType type, EnvironmentLogLevel level, const std::string& log_str) {
if (type == EnvironmentLogType::MASTER_LOG && !is_master_node()) {
return;
virtual int set_role(EnvironmentRole role) {
auto& node_info = mpi_node_info(role);
if (node_info.rank_id < 0) {
if (role == EnvironmentRole::ALL) {
node_info.mpi_comm = MPI_COMM_WORLD;
} else {
MPI_Comm_split(MPI_COMM_WORLD, static_cast<int>(role),
mpi_node_info(EnvironmentRole::ALL).rank_id, &(node_info.mpi_comm));
}
MPI_Comm_rank(node_info.mpi_comm, &(node_info.rank_id));
MPI_Comm_size(node_info.mpi_comm, &(node_info.node_num));
}
VLOG(2) << log_str;
return;
return 0;
}
REGISTER_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment);
virtual void barrier(EnvironmentRole role) {
MPI_Barrier(mpi_node_info(role).mpi_comm);
}
virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) {
auto& node_info = mpi_node_info(role);
int len = (int)ar.length();
MPI_Bcast(&len, 1, MPI_INT, root_id, node_info.mpi_comm);
ar.resize(len);
ar.set_cursor(ar.buffer());
MPI_Bcast(ar.buffer(), len, MPI_BYTE, root, node_info.mpi_comm);
}
protected:
virtual void print_log(EnvironmentLogType type, EnvironmentLogLevel level, const std::string& log_str);
inline MpiNodeInfo& mpi_node_info(EnvironmentRole role) {
return _roles_node_info[static_cast<int>(role)];
}
private:
std::vector<MpiNodeInfo> _roles_node_info;
};
REGISTER_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment);
} // namespace feed
......
......@@ -6,6 +6,7 @@
*/
#pragma once
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
......@@ -21,26 +22,37 @@ enum class EnvironmentLogLevel {
};
// Where a log line is emitted from.
enum class EnvironmentLogType {
    MASTER_LOG = 0, // emitted only on the master node
    ALL_LOG = 1     // emitted on every node
};
// Keep the enum values contiguous and increasing, with ALL at the tail
// (code sizes per-role arrays as ALL + 1).
enum class EnvironmentRole {
    WORKER = 0, //training worker
    PSERVER = 1, //parameter server
    ALL = 2 //all roles; keep this at the end of the enum
};
class RuntimeEnvironment {
public:
RuntimeEnvironment() {}
virtual ~RuntimeEnvironment() {}
RuntimeEnvironment();
virtual ~RuntimeEnvironment();
//配置初始化
virtual int initialize(YAML::Node config) = 0;
//设置role
virtual int set_role(EnvironmentRole role) = 0;
//环境初始化,会在所有依赖模块initialize后调用
virtual int wireup() = 0;
//多线程可调用接口 Start
//当前环境rank_idx
virtual uint32_t rank_idx() = 0;
virtual uint32_t rank_id(EnvironmentRole role) = 0;
//运行环境节点数
virtual uint32_t node_num(EnvironmentRole role) = 0;
//环境内主节点
virtual bool is_master_node() {
return rank_idx() == 0;
}
virtual bool is_master_node(EnvironmentRole role);
//环境定制化log
template<class... ARGS>
void log(EnvironmentLogType type, EnvironmentLogLevel level,
......@@ -51,29 +63,22 @@ public:
//接口只允许在主线程调用 Start
//barrier
virtual void barrier_all() = 0;
//barrier 指定role的节点
virtual void barrier(EnvironmentRole role) = 0;
//bcast 广播
virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) = 0;
//接口只允许在主线程调用 End
protected:
virtual void print_log(EnvironmentLogType type, EnvironmentLogLevel level, const std::string& log_str) = 0;
};
REGISTER_REGISTERER(RuntimeEnvironment);
class MPIRuntimeEnvironment : public RuntimeEnvironment {
public:
MPIRuntimeEnvironment() {}
virtual ~MPIRuntimeEnvironment() {}
//配置初始化
virtual int initialize(YAML::Node config);
//环境初始化,会在所有依赖模块initialize后调用
virtual int wireup();
//当前环境rank_idx
virtual uint32_t rank_idx();
virtual void barrier_all();
protected:
virtual void print_log(EnvironmentLogType type, EnvironmentLogLevel level, const std::string& log_str);
};
std::string format_timestamp(time_t time, const char* format);
// Convenience overload taking std::string.
// fix: must be `inline` — a non-inline function defined in a header is
// an ODR violation (duplicate symbol) once the header is included from
// more than one translation unit.
inline std::string format_timestamp(time_t time, const std::string& format) {
    return format_timestamp(time, format.c_str());
}
} // namespace feed
} // namespace custom_trainer
......
......@@ -8,6 +8,7 @@
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pipeline.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
namespace paddle {
......@@ -36,6 +37,11 @@ public:
std::string data;//样本数据, maybe压缩格式
};
// Pipeline converting raw DataItem records into parsed SampleInstances.
typedef std::shared_ptr<Pipeline<DataItem, SampleInstance>> SampleInstancePipe;
// Factory for an (uninitialized) sample pipeline.
// NOTE(review): the name says "channel" but it returns a pipe — consider
// renaming once callers can be updated.
inline SampleInstancePipe make_sample_instance_channel() {
    return std::make_shared<Pipeline<DataItem, SampleInstance>>();
}
class DataParser {
public:
DataParser() {}
......@@ -56,8 +62,12 @@ public:
virtual int initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) = 0;
//判断样本数据是否已就绪,就绪表明可以开始download
virtual bool is_data_ready(const std::string& data_dir) = 0;
//读取数据样本流中
virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem> data_channel) = 0;
//读取dir下文件列表
virtual std::vector<std::string> data_file_list(const std::string& data_dir);
//读取目录下数据到样本流中
virtual int read_all(const std::string& data_dir, ::paddle::framework::Channel<DataItem>& data_channel) = 0;
//读取指定文件列表的数据到样本流中
virtual int read_all(const std::vector<std::string>& data_list, ::paddle::framework::Channel<DataItem>& data_channel) = 0;
virtual const DataParser* get_parser() {
return _parser.get();
}
......
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// Build one DatasetContainer per entry of the "data_list" config node,
// keyed by the entry's "name". Returns 0 on success, -1 on any failure.
int Dataset::initialize(
    const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    if (!config["data_list"]) {
        VLOG(0) << "miss data_list config in dataset, please check";
        return -1;
    }
    auto data_list_node = config["data_list"];
    int data_num = data_list_node.size();
    for (int data_idx = 0; data_idx < data_num; ++data_idx) {
        auto data_node = data_list_node[data_idx];
        std::string data_name = data_node["name"].as<std::string>();
        auto container = std::make_shared<DatasetContainer>();
        if (container->initialize(data_node, context) != 0) {
            VLOG(0) << "dataset initialize failed, name:" << data_name;
            return -1;
        }
        _data_containers[data_name] = container;
    }
    return 0;
}
// Forward the pre-detect trigger to the named container.
// fix: dropped the erroneous `inline` — this member is declared (non-inline)
// in dataset.h and defined here in the .cc, so an inline definition would be
// invisible to other translation units (undefined reference at link time).
// NOTE(review): operator[] default-inserts a null container for an unknown
// data_name and then dereferences it — confirm callers always pass
// configured names.
void Dataset::pre_detect_data(
    const std::string& data_name, uint64_t epoch_id) {
    _data_containers[data_name]->pre_detect_data(epoch_id);
    return;
}
// Data status of the named container for an epoch.
// fix: dropped the erroneous `inline` (defined in the .cc, declared
// non-inline in the header — inline here breaks linking from other TUs).
DatasetStatus Dataset::epoch_data_status(
    const std::string& data_name, uint64_t epoch_id) {
    return _data_containers[data_name]->epoch_data_status(epoch_id);
}
// Raw (possibly compressed) data channel of the named container.
// fix: dropped the erroneous `inline` (defined in the .cc, declared
// non-inline in the header — inline here breaks linking from other TUs).
::paddle::framework::Channel<DataItem> Dataset::fetch_data(
    const std::string& data_name, uint64_t epoch_id) {
    return _data_containers[data_name]->fetch(epoch_id);
}
// Build a pipeline that asynchronously parses the epoch's raw DataItems
// into SampleInstances, read in batches of batch_size.
SampleInstancePipe Dataset::fetch_sample(
    const std::string& data_name, uint32_t batch_size, uint64_t epoch_id) {
    auto* data_container = _data_containers[data_name].get();
    auto data_channel = data_container->fetch(epoch_id);
    const auto* data_parser = data_container->data_parser();
    PipelineOptions options;
    options.batch_size = batch_size;
    options.need_hold_input_data = true;  // keep the raw input for reuse
    options.buffer_data_num = batch_size * 10;
    SampleInstancePipe pipe = make_sample_instance_channel();
    pipe->initialize(options, data_channel,
        [data_parser] (const DataItem* data, SampleInstance* sample, size_t num) -> int {
        int ret = 0;
        // fix: loop index was int while num is size_t (signed/unsigned mix)
        for (size_t i = 0; i < num; ++i, ++data, ++sample) {
            ret |= data_parser->parse_to_sample(*data, *sample);
        }
        return ret;
    });
    return pipe;
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
#pragma once
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// Registry of named DatasetContainers; the facade trainers use to
// trigger detection, query status and fetch data/samples per epoch.
class Dataset {
public:
    Dataset() {}
    virtual ~Dataset() {}
    virtual int initialize(
        const YAML::Node& config, std::shared_ptr<TrainerContext> context);
    //trigger detection of prefetchable data
    virtual void pre_detect_data(const std::string& data_name, uint64_t epoch_id);
    //query the data status for an epoch
    virtual DatasetStatus epoch_data_status(const std::string& data_name, uint64_t epoch_id);
    //raw data held by the DataContainer (maybe in compressed form)
    virtual ::paddle::framework::Channel<DataItem> fetch_data(
        const std::string& data_name, uint64_t epoch_id);
    //standard sample stream as a pipeline; data is converted asynchronously inside
    virtual SampleInstancePipe fetch_sample(
        const std::string& data_name, uint32_t batch_size, uint64_t epoch_id);
private:
    std::unordered_map<std::string, std::shared_ptr<DatasetContainer>> _data_containers;
};
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
......@@ -8,31 +8,148 @@
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/framework/io/shell.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// Read the container's config (prefetch depth, data roots, split interval,
// path format, reader class) and initialize the DataReader.
// Returns the reader's initialize() result; 0 on success.
int DatasetContainer::initialize(
    const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
    _dataset_config = config;
    _trainer_context = context.get();
    //number of epochs of sample data to prefetch
    _prefetch_num = config["prefetch_num"].as<int>();
    _dataset_list.resize(_prefetch_num);
    // multiple space-separated root dirs are supported
    _data_root_paths = paddle::string::split_string(
        config["root_path"].as<std::string>(), " ");
    // NOTE(review): the YAML key "data_spit_interval" looks like a typo of
    // "data_split_interval" — confirm against existing config files before
    // renaming, since changing the key breaks deployed configs.
    _data_split_interval = config["data_spit_interval"].as<int>();
    _data_path_formater = config["data_path_formater"].as<std::string>();
    std::string data_reader_class = config["data_reader"].as<std::string>();
    DataReader* data_reader = CREATE_CLASS(DataReader, data_reader_class);
    _data_reader.reset(data_reader);
    return _data_reader->initialize(config, context);
}
// Map a data timestamp to its slot in the circular prefetch list:
// consecutive epoch intervals rotate through the _prefetch_num slots.
std::shared_ptr<DatasetInfo> DatasetContainer::dataset(uint64_t timestamp) {
    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
    auto data_idx = timestamp / epoch_accessor->epoch_time_interval();
    return _dataset_list[data_idx % _prefetch_num];
}
// Detect whether the data covering epoch_id is available: probe every
// (root path x data split) directory, collect this worker's share of the
// file list, and mark the slot Detected when everything is present.
void DatasetContainer::pre_detect_data(uint64_t epoch_id) {
    int status = 0;
    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
    time_t timestamp = epoch_accessor->epoch_timestamp(epoch_id);
    // epoch timestamps must be aligned to the epoch interval
    if (timestamp % epoch_accessor->epoch_time_interval() != 0) {
        LOG(FATAL) << "timestamp:" << timestamp << " don't match interval:" << epoch_accessor->epoch_time_interval();
        return;
    }
    // number of data splits one epoch spans
    size_t data_num = data_num_for_train(timestamp, epoch_accessor->epoch_time_interval(), _data_split_interval);
    // first split boundary at or after the epoch start (round up)
    uint64_t data_timestamp = timestamp % _data_split_interval == 0 ? timestamp : (timestamp / _data_split_interval + 1) * _data_split_interval;
    std::vector<std::string> data_path_list;
    // any failed probe (status != 0) aborts detection for this epoch
    for (int i = 0; i < _data_root_paths.size() && status == 0; ++i) {
        for (int j = 0; j < data_num && status == 0; ++j) {
            std::string path_suffix = format_timestamp(data_timestamp + j * _data_split_interval, _data_path_formater);
            std::string data_dir = _data_root_paths[i] + "/" + path_suffix;
            status = read_data_list(data_dir, data_path_list);
        }
    }
    if (status == 0) {
        auto dataset_info = dataset(timestamp);
        dataset_info->timestamp = timestamp;
        dataset_info->file_path_list = std::move(data_path_list);
        dataset_info->status = DatasetStatus::Detected;
    }
    return;
}
// Collectively read the file list of data_dir across WORKER nodes:
// the master checks readiness and lists files, broadcasts both, and each
// worker appends its round-robin share (by rank) to data_list.
// Returns 0 on success, -1 if the data is not ready.
// NOTE: must be called by every WORKER node — it contains collective
// bcast/barrier calls.
int DatasetContainer::read_data_list(const std::string& data_dir, std::vector<std::string>& data_list) {
    auto* environment = _trainer_context->environment.get();
    // check that the data is ready (master only, result broadcast)
    int data_status = -1;
    if (environment->is_master_node(EnvironmentRole::WORKER)) {
        if (_data_reader->is_data_ready(data_dir)) {
            data_status = 0;
        }
    }
    paddle::framework::BinaryArchive ar;
    ar << data_status;
    environment->bcast(ar, 0, EnvironmentRole::WORKER);
    ar >> data_status;
    if (data_status != 0) {
        return -1;
    }
    // read the file list (master only, list broadcast to all workers)
    ar.Clear();
    std::vector<std::string> data_path_list;
    if (environment->is_master_node(EnvironmentRole::WORKER)) {
        data_path_list = _data_reader->data_file_list(data_dir);
        ar << data_path_list;
    }
    environment->bcast(ar, 0, EnvironmentRole::WORKER);
    ar >> data_path_list;
    // each worker takes every worker_num-th file starting at its rank
    auto worker_id = environment->rank_id(EnvironmentRole::WORKER);
    auto worker_num = environment->node_num(EnvironmentRole::WORKER);
    for (int i = worker_id; i < data_path_list.size(); i+=worker_num) {
        data_list.push_back(data_path_list[i]);
    }
    environment->barrier(EnvironmentRole::WORKER);
    return 0;
}
// Data status for an epoch: translate the epoch id to its data timestamp
// and look up the matching prefetch slot.
DatasetStatus DatasetContainer::epoch_data_status(uint64_t epoch_id) {
    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
    time_t timestamp = epoch_accessor->epoch_timestamp(epoch_id);
    return data_status(timestamp);
}
// Status of the slot holding `timestamp`. A slot whose recorded timestamp
// differs has been reused for another epoch, so the data counts as Empty.
DatasetStatus DatasetContainer::data_status(uint64_t timestamp) {
    auto dataset_info = dataset(timestamp);
    if (dataset_info->timestamp != timestamp) {
        return DatasetStatus::Empty;
    }
    return dataset_info->status;
}
// Data channel for an epoch. Returns a null (default-constructed)
// channel when the epoch's data is not yet Ready.
paddle::framework::Channel<DataItem> DatasetContainer::fetch(uint64_t epoch_id) {
    paddle::framework::Channel<DataItem> result;
    auto* epoch_accessor = _trainer_context->epoch_accessor.get();
    time_t timestamp = epoch_accessor->epoch_timestamp(epoch_id);
    if (data_status(timestamp) != DatasetStatus::Ready) {
        return result;
    }
    auto dataset_info = dataset(timestamp);
    return dataset_info->data_channel;
}
void DatasetContainer::async_download_data() {
void DatasetContainer::async_download_data(uint64_t start_timestamp) {
auto* epoch_accessor = _trainer_context->epoch_accessor.get();
if (start_timestamp % epoch_accessor->epoch_time_interval() != 0) {
LOG(FATAL) << "timestamp:" << start_timestamp << " don't match interval:" << epoch_accessor->epoch_time_interval();
return;
}
while (true) {
//do download
sleep(30);
auto dataset_info = dataset(start_timestamp);
while (data_status(start_timestamp) != DatasetStatus::Detected) {
sleep(30);
}
const auto& file_list = dataset_info->file_path_list;
dataset_info->data_channel->Clear();
while (_data_reader->read_all(file_list, dataset_info->data_channel) != 0) {
dataset_info->data_channel->Clear();
VLOG(0) << "timestamp:" << start_timestamp << " data read failed, retry";
sleep(30);
}
start_timestamp += epoch_accessor->epoch_time_interval();
}
}
}//namespace feed
}//namespace custom_trainer
}//namespace paddle
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
......@@ -16,38 +16,61 @@ namespace paddle {
namespace custom_trainer {
namespace feed {
// Number of data splits a training pass consumes: counts the
// data_time_interval-aligned boundaries (rounded up) that fall in
// [train_begin_timestamp, train_begin_timestamp + train_time_interval).
inline int data_num_for_train(uint64_t train_begin_timestamp, uint32_t train_time_interval, uint32_t data_time_interval) {
    const uint64_t interval = data_time_interval;
    const auto ceil_split_idx = [interval](uint64_t ts) {
        return (ts + interval - 1) / interval;
    };
    const uint64_t first_idx = ceil_split_idx(train_begin_timestamp);
    const uint64_t last_idx = ceil_split_idx(train_begin_timestamp + train_time_interval);
    return static_cast<int>(last_idx - first_idx);
}
// Lifecycle of one prefetch slot's data.
// NOTE(review): "Downloding" is a typo of "Downloading"; kept as-is because
// renaming an enumerator would break any code already referring to it.
enum class DatasetStatus {
    Empty = 0,
    Detected = 1,
    Downloding = 2,
    Ready = 3
};
// One prefetch slot: the data timestamp it currently holds, the file list
// detected for it, its lifecycle status and the channel carrying the data.
struct DatasetInfo {
    uint64_t timestamp = 0;
    std::vector<std::string> file_path_list;
    DatasetStatus status = DatasetStatus::Empty;
    ::paddle::framework::Channel<DataItem> data_channel = ::paddle::framework::MakeChannel<DataItem>();
};
// Manages one named data source: detects, downloads and serves epoch data
// through a circular list of _prefetch_num DatasetInfo slots.
class DatasetContainer {
public:
    DatasetContainer() {}
    virtual ~DatasetContainer() {}
    virtual int initialize(
        const YAML::Node& config, std::shared_ptr<TrainerContext> context);
    virtual void run();
    //trigger detection of prefetchable data
    virtual void pre_detect_data(uint64_t epoch_id);
    //data status for an epoch
    virtual DatasetStatus epoch_data_status(uint64_t epoch_id);
    //samples for an epoch; the channel is null when data is not ready
    virtual ::paddle::framework::Channel<DataItem> fetch(uint64_t epoch_id);
    //parser for DataItem records
    virtual const DataParser* data_parser() {
        return _data_reader->get_parser();
    }
protected:
    virtual DatasetStatus data_status(uint64_t timestamp);
    virtual int read_data_list(const std::string& data_dir, std::vector<std::string>& data_list);
    //asynchronous sample download loop
    virtual void async_download_data(uint64_t start_timestamp);
    virtual std::shared_ptr<DatasetInfo> dataset(uint64_t timestamp);

    int _prefetch_num = 0;
    int _data_split_interval = 60; //data split period (seconds)
    YAML::Node _dataset_config;
    std::string _data_path_formater;
    std::vector<std::string> _data_root_paths; //multiple root dirs supported
    TrainerContext* _trainer_context;
    std::shared_ptr<DataReader> _data_reader;
    std::shared_ptr<std::thread> _downloader_thread;
    std::vector<std::shared_ptr<DatasetInfo>> _dataset_list; //prefetched data slots
};
}//namespace feed
......
......@@ -35,7 +35,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
return 0;
}
std::future<int> LearnerProcess::save_model(int epoch_id, int table_id, ModelSaveWay way) {
std::future<int> LearnerProcess::save_model(uint64_t epoch_id, int table_id, ModelSaveWay way) {
std::promise<int> p;
auto ret = p.get_future();
if (_context_ptr->epoch_accessor->need_save_model(epoch_id, way)) {
......@@ -47,7 +47,7 @@ std::future<int> LearnerProcess::save_model(int epoch_id, int table_id, ModelSav
return ret;
}
int LearnerProcess::wait_save_model(int epoch_id, ModelSaveWay way) {
int LearnerProcess::wait_save_model(uint64_t epoch_id, ModelSaveWay way) {
auto* environment = _context_ptr->environment.get();
if (!environment->is_master_node()) {
return 0;
......@@ -71,7 +71,7 @@ int LearnerProcess::wait_save_model(int epoch_id, ModelSaveWay way) {
int LearnerProcess::run() {
auto* environment = _context_ptr->environment.get();
auto* epoch_accessor = _context_ptr->epoch_accessor.get();
int epoch_id = epoch_accessor->current_epoch_id();
uint64_t epoch_id = epoch_accessor->current_epoch_id();
environment->log(EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"Resume traine with epoch_id:%d label:%s", epoch_id, _context_ptr->epoch_accessor->text(epoch_id).c_str());
......
......@@ -21,9 +21,9 @@ public:
protected:
//同步保存所有模型
virtual int wait_save_model(int epoch_id, ModelSaveWay way);
virtual int wait_save_model(uint64_t epoch_id, ModelSaveWay way);
//异步保存指定模型
virtual std::future<int> save_model(int epoch_id, int table_id, ModelSaveWay way);
virtual std::future<int> save_model(uint64_t epoch_id, int table_id, ModelSaveWay way);
//执行指定训练网络
virtual int run_executor(Executor* executor);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册