提交 f6bd8cfd 编写于 作者: X xiexionghang

for mpi trainer

上级 d7ee6ba1
......@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch')
CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('baidu/paddlepaddle/pslib@master@git_branch')
CONFIGS('third-64/gtest@base')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
......
......@@ -88,7 +88,7 @@ namespace feed {
return "";
}
REGISTER_CLASS(EpochAccessor, HourlyEpochAccessor);
REGIST_CLASS(EpochAccessor, HourlyEpochAccessor);
} // namespace feed
} // namespace custom_trainer
......
......@@ -60,7 +60,7 @@ protected:
std::vector<std::string> _done_status; //当前完成状态,统一存成string
};
REGISTER_REGISTERER(EpochAccessor);
REGIST_REGISTERER(EpochAccessor);
class HourlyEpochAccessor : public EpochAccessor {
public:
......
#pragma once
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/accessor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// Base accessor for moving sample data between SampleInstance buffers and the
// training Scope, keyed by a parameter table id (see forward/backward).
class DataInputAccessor : public Accessor {
public:
DataInputAccessor() {}
virtual ~DataInputAccessor() {}
// Reads accessor-specific settings from `config`. Returns 0 on success.
virtual int initialize(const YAML::Node& config,
std::shared_ptr<TrainerContext> context_ptr);
// Forward pass: typically fills network inputs; called before the
// training network executes.
virtual int32_t forward(const SampleInstance* samples,
::paddle::framework::Scope* scope, size_t table_id, size_t num) = 0;
// Backward pass: typically pushes/updates gradients; called after the
// training network executes.
virtual int32_t backward(const SampleInstance* samples,
::paddle::framework::Scope* scope, size_t table_id, size_t num) = 0;
protected:
};
REGIST_REGISTERER(DataInputAccessor);
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
#include <vector>
#include <utility>
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// Accessor for sparse (slot-based) inputs: maps each input variable name to
// the list of slot ids it is fed from, per the "sparse_input" YAML section.
class CommonSparseInputAccessor : public DataInputAccessor {
public:
    CommonSparseInputAccessor() {}
    virtual ~CommonSparseInputAccessor() {}

    // Parses config["sparse_input"], a map of
    //   variable_name -> { slots: "id1,id2,..." }
    // and records each <variable_name, slot_id_list> pair in _x_variables.
    // Returns 0 on success; CHECK-fails when "sparse_input" is absent or not
    // a map.
    virtual int initialize(const YAML::Node& config,
        std::shared_ptr<TrainerContext> context_ptr) {
        CHECK(config["sparse_input"] && config["sparse_input"].Type() == YAML::NodeType::Map);
        _x_variables.clear();  // make repeated initialization idempotent
        for (const auto& input : config["sparse_input"]) {
            std::pair<std::string, std::vector<uint16_t>> sparse_slots;
            sparse_slots.first = input.first.as<std::string>();
            std::string slots_str = input.second["slots"].as<std::string>();
            std::vector<std::string> slots = paddle::string::split_string(slots_str, ",");
            for (size_t i = 0; i < slots.size(); ++i) {
                // NOTE: atoi yields 0 on malformed text — assumes slot ids in
                // the config are purely numeric.
                sparse_slots.second.push_back((uint16_t)atoi(slots[i].c_str()));
            }
            // BUG FIX: the parsed mapping was previously built and then
            // discarded; store it so forward/backward can consume it.
            _x_variables.push_back(std::move(sparse_slots));
        }
        return 0;
    }

    // Forward: intended to pull sparse parameters and fill input variables.
    // TODO: pull not implemented yet.
    virtual int32_t forward(const SampleInstance* samples,
        ::paddle::framework::Scope* scope, size_t table_id, size_t num) {
        // pull
        return 0;
    }

    // Backward: intended to push gradients for the sparse parameters.
    // TODO: push not implemented yet.
    virtual int32_t backward(const SampleInstance* samples,
        ::paddle::framework::Scope* scope, size_t table_id, size_t num) {
        return 0;
    }

protected:
    // Input variable list: <variable_name, slot_id_list>
    std::vector<std::pair<std::string, std::vector<uint16_t>>> _x_variables;
};
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"

#include <fcntl.h>

#include <fstream>
#include <sstream>
#include <vector>

#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>

#include "json2pb/json_to_pb.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
int PSlib::initialize(const std::string& conf_path,
RuntimeEnvironment* environment, EnvironmentRole role) {
init_gflag();
int file_descriptor = open(conf_path.c_str(), O_RDONLY);
if (file_descriptor == -1){
LOG(ERROR) << "FATAL: cant open " << conf_path;
return -1;
}
google::protobuf::io::FileInputStream fileInput(file_descriptor);
if (!google::protobuf::TextFormat::Parse(&fileInput, &_ps_param)) {
LOG(ERROR) << "FATAL: fail to parse " << conf_path;
return -1;
}
close(file_descriptor);
init_server(role);
init_client(EnvironmentRole::ALL);
return 0;
}
// Starts a parameter server on PSERVER-role nodes only; every node then
// calls gather_ps_servers() (presumably a collective that publishes the
// server endpoint list — confirm against pslib). Always returns 0.
int PSlib::init_server(EnvironmentRole role) {
if (role == EnvironmentRole::PSERVER) {
_server_ptr.reset(paddle::ps::PSServerFactory::create(_ps_param));
_server_ptr->configure(_ps_param, *(_environment->ps_environment()),
_environment->rank_id(role));
_server_ptr->start();
}
_environment->ps_environment()->gather_ps_servers();
return 0;
}
// Creates and configures the PS client for this node using the already
// parsed _ps_param and the stored runtime environment. Always returns 0.
// NOTE(review): requires _environment to be set before this is called.
int PSlib::init_client(EnvironmentRole role) {
_client_ptr.reset(paddle::ps::PSClientFactory::create(_ps_param));
_client_ptr->configure(_ps_param, *(_environment->ps_environment()),
_environment->rank_id(role));
return 0;
}
// Non-owning accessor; nullptr unless init_server ran with PSERVER role.
paddle::ps::PSServer* PSlib::ps_server() {
return _server_ptr.get();
}
// Non-owning accessor; nullptr until init_client has run.
paddle::ps::PSClient* PSlib::ps_client() {
return _client_ptr.get();
}
// Non-owning pointer to the parsed PS parameter config (valid for the
// lifetime of this PSlib instance).
paddle::PSParameter* PSlib::get_param() {
return &_ps_param;
}
// Feeds a fixed synthetic command line to gflags (brpc tuning knobs).
// The argv storage must be mutable and outlive the parse call, since
// ParseCommandLineFlags may rewrite/permute argv.
void PSlib::init_gflag() {
    char p0[] = "exe default";
    char p1[] = "-max_body_size=314217728";
    char p2[] = "-bthread_concurrency=40";
    char p3[] = "-socket_max_unwritten_bytes=2048000000";
    // BUG FIX: previously std::shared_ptr<char*>(new char*[cnt]) destroyed
    // array-new storage with scalar delete — undefined behavior. A vector
    // owns the array safely.
    std::vector<char*> params = {p0, p1, p2, p3};
    int cnt = static_cast<int>(params.size());
    char** params_ptr = params.data();
    ::google::ParseCommandLineFlags(&cnt, &params_ptr, true);
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "communicate/ps_server.h"
#include "communicate/ps_client.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
class RuntimeEnvironment;
enum class EnvironmentRole;
// Thin wrapper around the pslib parameter-server runtime: parses the PS
// config, starts the server on PSERVER nodes, and creates the client.
class PSlib {
public:
    PSlib() {}
    virtual ~PSlib() {}

    // Loads the text-format config at conf_path, then initializes server
    // and client for this node. Returns 0 on success, -1 on failure.
    int initialize(const std::string& conf_path,
        RuntimeEnvironment* environment, EnvironmentRole role);

    virtual paddle::ps::PSServer* ps_server();   // non-owning; may be nullptr
    virtual paddle::ps::PSClient* ps_client();   // non-owning; may be nullptr
    virtual paddle::PSParameter* get_param();    // non-owning

private:
    void init_gflag();
    virtual int init_server(EnvironmentRole role);
    virtual int init_client(EnvironmentRole role);

    paddle::PSParameter _ps_param;
    // BUG FIX: default to nullptr so a mis-ordered call sequence fails
    // detectably instead of dereferencing an indeterminate pointer.
    RuntimeEnvironment* _environment = nullptr;
    std::shared_ptr<paddle::ps::PSServer> _server_ptr;
    std::shared_ptr<paddle::ps::PSClient> _client_ptr;
};
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
......@@ -3,12 +3,12 @@ namespace paddle {
namespace custom_trainer {
namespace feed {
BaseClassMap& global_factory_map() {
BaseClassMap& global_reg_factory_map() {
static BaseClassMap *base_class = new BaseClassMap();
return *base_class;
}
BaseClassMap& global_factory_map_cpp() {
return global_factory_map();
BaseClassMap& global_reg_factory_map_cpp() {
return global_reg_factory_map();
}
}// feed
......
......@@ -63,23 +63,23 @@ typedef std::map<std::string, FactoryMap> BaseClassMap;
#ifdef __cplusplus
extern "C" {
#endif
BaseClassMap& global_factory_map();
BaseClassMap& global_reg_factory_map();
#ifdef __cplusplus
}
#endif
BaseClassMap& global_factory_map_cpp();
BaseClassMap& global_reg_factory_map_cpp();
#define REGISTER_REGISTERER(base_class) \
#define REGIST_REGISTERER(base_class) \
class base_class ## Registerer { \
public: \
static base_class *CreateInstanceByName(const ::std::string &name) { \
if (global_factory_map_cpp().find(#base_class) \
== global_factory_map_cpp().end()) { \
if (global_reg_factory_map_cpp().find(#base_class) \
== global_reg_factory_map_cpp().end()) { \
LOG(ERROR) << "Can't Find BaseClass For CreateClass with:" << #base_class; \
return NULL; \
} \
FactoryMap &map = global_factory_map_cpp()[#base_class]; \
FactoryMap &map = global_reg_factory_map_cpp()[#base_class]; \
FactoryMap::iterator iter = map.find(name); \
if (iter == map.end()) { \
LOG(ERROR) << "Can't Find Class For Create with:" << name; \
......@@ -90,7 +90,7 @@ BaseClassMap& global_factory_map_cpp();
} \
};
#define REGISTER_CLASS(clazz, name) \
#define REGIST_CLASS(clazz, name) \
class ObjectFactory##name : public ObjectFactory { \
public: \
Any NewInstance() { \
......@@ -98,14 +98,14 @@ BaseClassMap& global_factory_map_cpp();
} \
}; \
void register_factory_##name() { \
FactoryMap &map = global_factory_map_cpp()[#clazz]; \
FactoryMap &map = global_reg_factory_map_cpp()[#clazz]; \
if (map.find(#name) == map.end()) { \
map[#name] = new ObjectFactory##name(); \
} \
} \
void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \
#define CREATE_INSTANCE(base_class, name) \
base_class##Registerer::CreateInstanceByName(name)
}//namespace feed
......
......@@ -44,6 +44,11 @@ public:
set_role(EnvironmentRole::ALL);
return 0;
}
virtual paddle::ps::PSEnvironment* ps_environment() {
static paddle::ps::MpiPSEnvironment ps_environment;
return &ps_environment;
}
virtual uint32_t rank_id(EnvironmentRole role) {
return mpi_node_info(role).rank_id;
......@@ -95,7 +100,7 @@ protected:
private:
std::vector<MpiNodeInfo> _roles_node_info;
};
REGISTER_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment);
REGIST_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment);
//用于本地模式单机训练
class LocalRuntimeEnvironment : public RuntimeEnvironment {
......@@ -108,6 +113,10 @@ public:
virtual int wireup() {
return 0;
}
virtual paddle::ps::PSEnvironment* ps_environment() {
static paddle::ps::LocalPSEnvironment ps_environment;
return &ps_environment;
}
virtual uint32_t rank_id(EnvironmentRole role) {
return 0;
}
......@@ -129,7 +138,7 @@ protected:
VLOG(static_cast<int>(level)) << log_str;
}
};
REGISTER_CLASS(RuntimeEnvironment, LocalRuntimeEnvironment);
REGIST_CLASS(RuntimeEnvironment, LocalRuntimeEnvironment);
} // namespace feed
} // namespace custom_trainer
......
......@@ -6,6 +6,7 @@
*/
#pragma once
#include <yaml-cpp/yaml.h>
#include "communicate/ps_env.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
......@@ -14,6 +15,8 @@ namespace paddle {
namespace custom_trainer {
namespace feed {
class paddle::ps::PSEnvironment;
enum class EnvironmentLogLevel {
FATAL = 0,
ERROR = 1,
......@@ -38,41 +41,43 @@ class RuntimeEnvironment {
public:
RuntimeEnvironment();
virtual ~RuntimeEnvironment();
//配置初始化
// 配置初始化
virtual int initialize(YAML::Node config) = 0;
//设置role
// 设置role
virtual int set_role(EnvironmentRole role) = 0;
//环境初始化,会在所有依赖模块initialize后调用
// 环境初始化,会在所有依赖模块initialize后调用
virtual int wireup() = 0;
//多线程可调用接口 Start
//当前环境rank_idx
// 多线程可调用接口 Start
// 当前环境rank_idx
virtual uint32_t rank_id(EnvironmentRole role) = 0;
//运行环境节点数
// 运行环境节点数
virtual uint32_t node_num(EnvironmentRole role) = 0;
//环境内主节点
// 环境内主节点
virtual bool is_master_node(EnvironmentRole role);
//For PS
virtual paddle::ps::PSEnvironment* ps_environment() = 0;
//环境定制化log
// 环境定制化log
template<class... ARGS>
void log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const char* fmt, ARGS && ... args) {
print_log(role, type, level, paddle::string::format_string(fmt, args...));
}
//多线程可调用接口 End
// 多线程可调用接口 End
//接口只允许在主线程调用 Start
//barrier 指定role的节点
// 接口只允许在主线程调用 Start
// barrier 指定role的节点
virtual void barrier(EnvironmentRole role) = 0;
//bcast 广播
// bcast 广播
virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) = 0;
//接口只允许在主线程调用 End
// 接口只允许在主线程调用 End
protected:
virtual void print_log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const std::string& log_str) = 0;
};
REGISTER_REGISTERER(RuntimeEnvironment);
REGIST_REGISTERER(RuntimeEnvironment);
std::string format_timestamp(time_t time, const char* format);
inline std::string format_timestamp(time_t time, const std::string& format) {
......
......@@ -56,10 +56,10 @@ public:
return 0;
}
};
REGISTER_CLASS(DataParser, LineDataParser);
REGIST_CLASS(DataParser, LineDataParser);
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
_parser.reset(CREATE_CLASS(DataParser, config["parser"]["class"].as<std::string>()));
_parser.reset(CREATE_INSTANCE(DataParser, config["parser"]["class"].as<std::string>()));
if (_parser == nullptr) {
VLOG(2) << "fail to get parser: " << config["parser"]["class"].as<std::string>();
return -1;
......@@ -85,7 +85,7 @@ public:
if (config["file_system"] && config["file_system"]["class"]) {
_file_system.reset(
CREATE_CLASS(FileSystem, config["file_system"]["class"].as<std::string>()));
CREATE_INSTANCE(FileSystem, config["file_system"]["class"].as<std::string>()));
if (_file_system == nullptr ||
_file_system->initialize(config["file_system"], context) != 0) {
VLOG(2) << "fail to create class: "
......@@ -95,7 +95,7 @@ public:
} else if (context->file_system != nullptr) {
_file_system = context->file_system;
} else {
_file_system.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
_file_system.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) {
VLOG(2) << "fail to init file system";
return -1;
......@@ -203,7 +203,7 @@ private:
std::string _filename_prefix;
std::shared_ptr<FileSystem> _file_system;
};
REGISTER_CLASS(DataReader, LineDataReader);
REGIST_CLASS(DataReader, LineDataReader);
} // namespace feed
} // namespace custom_trainer
......
......@@ -54,7 +54,7 @@ public:
virtual int parse(const char* str, DataItem& data) const = 0;
virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;
};
REGISTER_REGISTERER(DataParser);
REGIST_REGISTERER(DataParser);
class DataReader {
public:
......@@ -76,7 +76,7 @@ protected:
std::shared_ptr<DataParser> _parser;//数据格式转换
std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入
};
REGISTER_REGISTERER(DataReader);
REGIST_REGISTERER(DataReader);
}//namespace feed
}//namespace custom_trainer
......
......@@ -32,7 +32,7 @@ int DatasetContainer::initialize(
_data_split_interval = config["data_spit_interval"].as<int>();
_data_path_formater = config["data_path_formater"].as<std::string>();
std::string data_reader_class = config["data_reader"].as<std::string>();
DataReader* data_reader = CREATE_CLASS(DataReader, data_reader_class);
DataReader* data_reader = CREATE_INSTANCE(DataReader, data_reader_class);
_data_reader.reset(data_reader);
return _data_reader->initialize(config, context);
}
......
......@@ -121,7 +121,7 @@ protected:
std::unique_ptr<Context> _context;
};
REGISTER_CLASS(Executor, SimpleExecutor);
REGIST_CLASS(Executor, SimpleExecutor);
} // namespace feed
} // namespace custom_trainer
......
......@@ -40,7 +40,7 @@ public:
protected:
::paddle::framework::Scope _scope;
};
REGISTER_REGISTERER(Executor);
REGIST_REGISTERER(Executor);
} // namespace feed
} // namespace custom_trainer
......
......@@ -31,7 +31,7 @@ public:
_file_system.clear();
if (config && config["file_systems"] && config["file_systems"].Type() == YAML::NodeType::Map) {
for (auto& prefix_fs: config["file_systems"]) {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, prefix_fs.second["class"].as<std::string>("")));
std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, prefix_fs.second["class"].as<std::string>("")));
if (fs == nullptr) {
VLOG(2) << "fail to create class: " << prefix_fs.second["class"].as<std::string>("");
return -1;
......@@ -44,7 +44,7 @@ public:
}
}
if (_file_system.find("default") == _file_system.end()) {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
if (fs == nullptr || fs->initialize(YAML::Load(""), context) != 0) {
return -1;
}
......@@ -122,7 +122,7 @@ public:
private:
std::unordered_map<std::string, std::unique_ptr<FileSystem>> _file_system;
};
REGISTER_CLASS(FileSystem, AutoFileSystem);
REGIST_CLASS(FileSystem, AutoFileSystem);
} // namespace feed
} // namespace custom_trainer
......
......@@ -52,7 +52,7 @@ public:
protected:
int _err_no = 0;
};
REGISTER_REGISTERER(FileSystem);
REGIST_REGISTERER(FileSystem);
} // namespace feed
} // namespace custom_trainer
......
......@@ -203,7 +203,7 @@ private:
std::string _hdfs_command;
std::unordered_map<std::string, std::string> _ugi;
};
REGISTER_CLASS(FileSystem, HadoopFileSystem);
REGIST_CLASS(FileSystem, HadoopFileSystem);
} // namespace feed
} // namespace custom_trainer
......
......@@ -129,7 +129,7 @@ public:
private:
size_t _buffer_size = 0;
};
REGISTER_CLASS(FileSystem, LocalFileSystem);
REGIST_CLASS(FileSystem, LocalFileSystem);
} // namespace feed
} // namespace custom_trainer
......
#include <time.h>
#include <fstream>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/train/custom_trainer/feed/process/process.h"
#include "paddle/fluid/train/custom_trainer/feed/process/init_env_process.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -21,27 +21,56 @@ int main(int argc, char* argv[]) {
//load trainer config
auto trainer_context_ptr = std::make_shared<TrainerContext>();
trainer_context_ptr->trainer_config = YAML::LoadFile(FLAGS_feed_trainer_conf_path);
//environment
auto& config = trainer_context_ptr->trainer_config;
std::string env_class = config["environment"]["environment_class"].as<std::string>();
trainer_context_ptr->environment.reset(CREATE_INSTANCE(RuntimeEnvironment, env_class));
if (trainer_context_ptr->environment->initialize(config["environment"]) != 0) {
return -1;
}
EnvironmentRole role;
auto* environment = trainer_context_ptr->environment.get();
environment->wireup();
if (environment->rank_id(EnvironmentRole::ALL) % 2 == 0) {
role = EnvironmentRole::WORKER;
} else {
role = EnvironmentRole::PSERVER;
}
environment->set_role(role);
trainer_context_ptr->pslib.reset(new PSlib());
std::string ps_config = config["environment"]["ps"].as<std::string>();
trainer_context_ptr->pslib->initialize(ps_config, environment, role);
//VLOG(3) << "Node Start With Role:" << role;
std::vector<std::string> process_name_list = {
"InitEnvProcess",
"LearnerProcess"
};
for (const auto& process_name : process_name_list) {
Process* process = CREATE_CLASS(Process, process_name);
if (process == NULL) {
VLOG(1) << "Process:" << process_name << " does not exist";
return -1;
switch (role) {
case EnvironmentRole::WORKER:
for (const auto& process_name : process_name_list) {
Process* process = CREATE_INSTANCE(Process, process_name);
if (process == NULL) {
VLOG(1) << "Process:" << process_name << " does not exist";
return -1;
}
if (process->initialize(trainer_context_ptr) != 0) {
VLOG(1) << "Process:" << process_name << " initialize failed";
return -1;
}
trainer_context_ptr->process_list.push_back(std::shared_ptr<Process>(process));
}
for (auto& process : trainer_context_ptr->process_list) {
process->run();
}
if (process->initialize(trainer_context_ptr) != 0) {
VLOG(1) << "Process:" << process_name << " initialize failed";
return -1;
break;
case EnvironmentRole::PSERVER:
//wait server done
while (true) {
sleep(10000);
}
trainer_context_ptr->process_list.push_back(std::shared_ptr<Process>(process));
}
for (auto& process : trainer_context_ptr->process_list) {
process->run();
break;
}
return 0;
......
......@@ -39,7 +39,7 @@ protected:
std::string _name;
};
REGISTER_REGISTERER(Monitor);
REGIST_REGISTERER(Monitor);
} // namespace feed
} // namespace custom_trainer
......
......@@ -20,22 +20,16 @@ int InitEnvProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
context_ptr->cpu_place = paddle::platform::CPUPlace();
YAML::Node config = _context_ptr->trainer_config;
//environment
std::string env_class = config["environment"]["environment_class"].as<std::string>();
context_ptr->environment.reset(CREATE_CLASS(RuntimeEnvironment, env_class));
if (context_ptr->environment->initialize(config["environment"]) != 0) {
return -1;
}
//file_system
context_ptr->file_system.reset(CREATE_CLASS(FileSystem, "AutoFileSystem"));
context_ptr->file_system.reset(CREATE_INSTANCE(FileSystem, "AutoFileSystem"));
if (context_ptr->file_system->initialize(config["io"], context_ptr) != 0) {
return -1;
}
//epoch
std::string epoch_class = config["epoch"]["epoch_class"].as<std::string>();
context_ptr->epoch_accessor.reset(CREATE_CLASS(EpochAccessor, epoch_class));
context_ptr->epoch_accessor.reset(CREATE_INSTANCE(EpochAccessor, epoch_class));
if (context_ptr->epoch_accessor->initialize(config["epoch"], context_ptr) != 0) {
return -1;
}
......@@ -55,10 +49,12 @@ int InitEnvProcess::run() {
VLOG(3) << "Trainer Resume From epoch:" << epoch_accessor->current_epoch_id();
auto next_epoch_id = epoch_accessor->next_epoch_id(epoch_accessor->current_epoch_id());
_context_ptr->dataset->pre_detect_data(next_epoch_id);
//step 1. psserver init
//step2. psserver load
VLOG(3) << "Psserver Start Success";
if (epoch_accessor->checkpoint_path().size() > 0) {
//Load Model
} else {
//Random Init Model
}
//context_ptr->pslib_client()->load_model();
VLOG(3) << "Psserver Load Model Success";
return 0;
......
......@@ -25,7 +25,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
_threads_executor[i].resize(_executor_num);
for (int e = 0; e < _executor_num; ++e) {
auto e_class = config["executor"][e]["class"].as<std::string>();
auto* e_ptr = CREATE_CLASS(Executor, e_class);
auto* e_ptr = CREATE_INSTANCE(Executor, e_class);
_threads_executor[i][e].reset(e_ptr);
if (e_ptr->initialize(config["executor"][e], context_ptr) != 0) {
ret = -1;
......@@ -84,53 +84,59 @@ int LearnerProcess::run() {
while (true) {
epoch_accessor->next_epoch();
bool already_dump_inference_model = false;
epoch_id = epoch_accessor->current_epoch_id();
std::string epoch_log_title= paddle::string::format_string(
"train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
//Step1. 等待样本ready
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"Start %s, wait data ready", epoch_log_title.c_str());
while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
sleep(30);
dataset->pre_detect_data(epoch_id);
{
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"%s, data not ready, wait 30s", epoch_log_title.c_str());
}
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"%s, data is ready, start traning", epoch_log_title.c_str());
environment->barrier(EnvironmentRole::WORKER);
"Start %s, wait data ready", epoch_log_title.c_str());
while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
sleep(30);
dataset->pre_detect_data(epoch_id);
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"%s, data not ready, wait 30s", epoch_log_title.c_str());
}
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"%s, data is ready, start traning", epoch_log_title.c_str());
environment->barrier(EnvironmentRole::WORKER);
}
//Step2. 运行训练网络
bool already_dump_inference_model = false;
for (int i = 0; i < _executor_num; ++i) {
std::vector<std::shared_ptr<std::thread>> train_threads(_train_thread_num);
for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) {
train_threads[i].reset(new std::thread([this](int exe_idx, int thread_idx) {
auto* executor = _threads_executor[thread_idx][exe_idx].get();
run_executor(executor);
}, i, thread_id));
}
for (int i = 0; i < _train_thread_num; ++i) {
train_threads[i]->join();
{
for (int i = 0; i < _executor_num; ++i) {
std::vector<std::shared_ptr<std::thread>> train_threads(_train_thread_num);
for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) {
train_threads[i].reset(new std::thread([this](int exe_idx, int thread_idx) {
auto* executor = _threads_executor[thread_idx][exe_idx].get();
run_executor(executor);
}, i, thread_id));
}
for (int i = 0; i < _train_thread_num; ++i) {
train_threads[i]->join();
}
environment->barrier(EnvironmentRole::WORKER);
if (_threads_executor[0][i]->is_dump_all_model()) {
already_dump_inference_model = true;
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
}
environment->barrier(EnvironmentRole::WORKER);
}
environment->barrier(EnvironmentRole::WORKER);
}
if (_threads_executor[0][i]->is_dump_all_model()) {
//Step3. Dump Model For Delta&&Checkpoint
{
if (!already_dump_inference_model) {
already_dump_inference_model = true;
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
}
}
wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
environment->barrier(EnvironmentRole::WORKER);
}
//Step3. Dump Model For Delta&&Checkpoint
if (!already_dump_inference_model) {
already_dump_inference_model = true;
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
}
wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
environment->barrier(EnvironmentRole::WORKER);
//Step4. Output Monitor && RunStatus
//TODO
}
......
......@@ -5,8 +5,8 @@
namespace paddle {
namespace custom_trainer {
namespace feed {
REGISTER_CLASS(Process, InitEnvProcess);
REGISTER_CLASS(Process, LearnerProcess);
REGIST_CLASS(Process, InitEnvProcess);
REGIST_CLASS(Process, LearnerProcess);
int Process::run() {
return 0;
}
......
......@@ -18,7 +18,7 @@ public:
protected:
TrainerContext* _context_ptr = NULL;
};
REGISTER_REGISTERER(Process);
REGIST_REGISTERER(Process);
} // namespace feed
} // namespace custom_trainer
......
......@@ -4,6 +4,7 @@
#include <vector>
#include <yaml-cpp/yaml.h>
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
......@@ -38,6 +39,7 @@ public:
YAML::Node trainer_config;
paddle::platform::CPUPlace cpu_place;
std::shared_ptr<PSlib> pslib;
std::shared_ptr<Dataset> dataset; //训练样本
std::shared_ptr<FileSystem> file_system; //文件操作辅助类
std::vector<TableMeta> params_table_list; //参数表
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -36,7 +37,7 @@ const char test_data_dir[] = "test_data";
class DataReaderTest : public testing::Test {
public:
static void SetUpTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
fs->mkdir(test_data_dir);
shell_set_verbose(true);
......@@ -56,14 +57,14 @@ public:
}
static void TearDownTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
fs->remove(test_data_dir);
}
virtual void SetUp() {
thread_num = omp_get_max_threads();
omp_set_num_threads(1);
fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
context_ptr.reset(new TrainerContext());
}
......@@ -79,7 +80,7 @@ public:
};
TEST_F(DataReaderTest, LineDataParser) {
std::unique_ptr<DataParser> data_parser(CREATE_CLASS(DataParser, "LineDataParser"));
std::unique_ptr<DataParser> data_parser(CREATE_INSTANCE(DataParser, "LineDataParser"));
ASSERT_NE(nullptr, data_parser);
auto config = YAML::Load("");
......@@ -108,7 +109,7 @@ TEST_F(DataReaderTest, LineDataParser) {
}
TEST_F(DataReaderTest, LineDataReader) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
std::unique_ptr<DataReader> data_reader(CREATE_INSTANCE(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load(
......@@ -161,7 +162,7 @@ TEST_F(DataReaderTest, LineDataReader) {
}
TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
std::unique_ptr<DataReader> data_reader(CREATE_INSTANCE(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load(
"parser:\n"
......@@ -196,7 +197,7 @@ TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
}
TEST_F(DataReaderTest, LineDataReader_FileSystem) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
std::unique_ptr<DataReader> data_reader(CREATE_INSTANCE(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load(
"parser:\n"
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -37,7 +38,7 @@ const char test_data_dir[] = "test_data";
class DataReaderOmpTest : public testing::Test {
public:
static void SetUpTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
fs->mkdir(test_data_dir);
shell_set_verbose(true);
std_items.clear();
......@@ -61,14 +62,14 @@ public:
}
static void TearDownTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
fs->remove(test_data_dir);
}
virtual void SetUp() {
thread_num = omp_get_max_threads();
omp_set_num_threads(1);
fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem"));
fs.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
context_ptr.reset(new TrainerContext());
}
......@@ -111,7 +112,7 @@ std::vector<DataItem> DataReaderOmpTest::std_items;
std::vector<DataItem> DataReaderOmpTest::sorted_std_items;
TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
std::unique_ptr<DataReader> data_reader(CREATE_INSTANCE(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load(
......@@ -148,7 +149,7 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
}
TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader"));
std::unique_ptr<DataReader> data_reader(CREATE_INSTANCE(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load(
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <fstream>
#include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h"
......@@ -38,7 +39,7 @@ class SimpleExecutorTest : public testing::Test
public:
static void SetUpTestCase()
{
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
fs->mkdir(test_data_dir);
shell_set_verbose(true);
......@@ -70,7 +71,7 @@ public:
static void TearDownTestCase()
{
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem"));
std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
fs->remove(test_data_dir);
}
......@@ -88,7 +89,7 @@ public:
};
TEST_F(SimpleExecutorTest, initialize) {
std::unique_ptr<Executor> executor(CREATE_CLASS(Executor, "SimpleExecutor"));
std::unique_ptr<Executor> executor(CREATE_INSTANCE(Executor, "SimpleExecutor"));
ASSERT_NE(nullptr, executor);
YAML::Node config = YAML::Load("[1, 2, 3]");
ASSERT_NE(0, executor->initialize(config, context_ptr));
......@@ -99,7 +100,7 @@ TEST_F(SimpleExecutorTest, initialize) {
}
TEST_F(SimpleExecutorTest, run) {
std::unique_ptr<Executor> executor(CREATE_CLASS(Executor, "SimpleExecutor"));
std::unique_ptr<Executor> executor(CREATE_INSTANCE(Executor, "SimpleExecutor"));
ASSERT_NE(nullptr, executor);
auto config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册