Commit f6bd8cfd authored by xiexionghang

for mpi trainer

Parent d7ee6ba1
...@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch') ...@@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch')
CONFIGS('baidu/third-party/python@gcc482output@git_branch') CONFIGS('baidu/third-party/python@gcc482output@git_branch')
CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag') CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag')
CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch') CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch')
CONFIGS('baidu/paddlepaddle/pslib@master@git_branch')
CONFIGS('third-64/gtest@base') CONFIGS('third-64/gtest@base')
HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/') HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/')
......
...@@ -88,7 +88,7 @@ namespace feed { ...@@ -88,7 +88,7 @@ namespace feed {
return ""; return "";
} }
REGISTER_CLASS(EpochAccessor, HourlyEpochAccessor); REGIST_CLASS(EpochAccessor, HourlyEpochAccessor);
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
......
...@@ -60,7 +60,7 @@ protected: ...@@ -60,7 +60,7 @@ protected:
std::vector<std::string> _done_status; //current done status, uniformly stored as strings std::vector<std::string> _done_status; //current done status, uniformly stored as strings
}; };
REGISTER_REGISTERER(EpochAccessor); REGIST_REGISTERER(EpochAccessor);
class HourlyEpochAccessor : public EpochAccessor { class HourlyEpochAccessor : public EpochAccessor {
public: public:
......
#pragma once
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/accessor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
class DataInputAccessor : public Accessor {
public:
DataInputAccessor() {}
virtual ~DataInputAccessor() {}
virtual int initialize(const YAML::Node& config,
std::shared_ptr<TrainerContext> context_ptr);
// Forward pass: typically used to fill inputs; called before the training network runs
virtual int32_t forward(const SampleInstance* samples,
::paddle::framework::Scope* scope, size_t table_id, size_t num) = 0;
// Backward pass: typically used to push gradient updates; called after the training network runs
virtual int32_t backward(const SampleInstance* samples,
::paddle::framework::Scope* scope, size_t table_id, size_t num) = 0;
protected:
};
REGIST_REGISTERER(DataInputAccessor);
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
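The header above only declares the abstract input accessor and its registerer; concrete accessors plug into the factory by name. The sketch below is not part of this commit: DemoInputAccessor is a hypothetical subclass, used only to illustrate how the renamed REGIST_CLASS / CREATE_INSTANCE macros (defined in registerer.h further down) would be applied to this base class.
#include "paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// Hypothetical accessor, for illustration only.
class DemoInputAccessor : public DataInputAccessor {
public:
    virtual int32_t forward(const SampleInstance* samples,
        ::paddle::framework::Scope* scope, size_t table_id, size_t num) {
        return 0;  // fill the input variables in the scope before the network runs
    }
    virtual int32_t backward(const SampleInstance* samples,
        ::paddle::framework::Scope* scope, size_t table_id, size_t num) {
        return 0;  // push gradients after the network has run
    }
};
// Registers a factory at load time (via __attribute__((constructor))), so the class
// can later be created with CREATE_INSTANCE(DataInputAccessor, "DemoInputAccessor").
REGIST_CLASS(DataInputAccessor, DemoInputAccessor);
}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle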
#include <vector>
#include <utility>
#include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
class CommonSparseInputAccessor : public DataInputAccessor {
public:
CommonSparseInputAccessor() {}
virtual ~CommonSparseInputAccessor() {}
virtual int initialize(const YAML::Node& config,
std::shared_ptr<TrainerContext> context_ptr) {
CHECK(config["sparse_input"] && config["sparse_input"].Type() == YAML::NodeType::Map);
for (auto& input : config["sparse_input"]) {
std::pair<std::string, std::vector<uint16_t>> sparse_slots;
sparse_slots.first = input.first.as<std::string>();
std::string slots_str = input.second["slots"].as<std::string>();
std::vector<std::string> slots = paddle::string::split_string(slots_str, ",");
for (int i = 0; i < slots.size(); ++i) {
sparse_slots.second.push_back((uint16_t)atoi(slots[i].c_str()));
}
}
return 0;
}
// Pull sparse data
virtual int32_t forward(const SampleInstance* samples,
::paddle::framework::Scope* scope, size_t table_id, size_t num) {
// pull
return 0;
}
// Update (push) sparse data
virtual int32_t backward(const SampleInstance* samples,
::paddle::framework::Scope* scope, size_t table_id, size_t num) {
return 0;
}
protected:
// List of input layers
// <data_name, slot_id_list>
std::vector<std::pair<std::string, std::vector<uint16_t> > > _x_variables;
};
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
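CommonSparseInputAccessor::initialize() above walks a "sparse_input" map and splits each entry's "slots" string on commas. The actual configuration file is not shown in this commit; the snippet below is an assumed minimal layout that would satisfy that parsing code. Only the keys "sparse_input" and "slots" come from the source; the input names and slot ids are placeholders.
#include <yaml-cpp/yaml.h>
// Assumed config shape for CommonSparseInputAccessor::initialize().
inline YAML::Node demo_sparse_input_config() {
    return YAML::Load(
        "sparse_input:\n"
        "  sparse_embed_input:\n"
        "    slots: \"100,101,102\"\n"
        "  sparse_context_input:\n"
        "    slots: \"200,201\"\n");
}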
#include <fcntl.h>
#include <fstream>
#include <sstream>
#include "json2pb/json_to_pb.h"
#include <google/protobuf/text_format.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
int PSlib::initialize(const std::string& conf_path,
RuntimeEnvironment* environment, EnvironmentRole role) {
init_gflag();
int file_descriptor = open(conf_path.c_str(), O_RDONLY);
if (file_descriptor == -1){
LOG(ERROR) << "FATAL: cant open " << conf_path;
return -1;
}
google::protobuf::io::FileInputStream fileInput(file_descriptor);
if (!google::protobuf::TextFormat::Parse(&fileInput, &_ps_param)) {
LOG(ERROR) << "FATAL: fail to parse " << conf_path;
return -1;
}
close(file_descriptor);
init_server(role);
init_client(EnvironmentRole::ALL);
return 0;
}
int PSlib::init_server(EnvironmentRole role) {
if (role == EnvironmentRole::PSERVER) {
_server_ptr.reset(paddle::ps::PSServerFactory::create(_ps_param));
_server_ptr->configure(_ps_param, *(_environment->ps_environment()),
_environment->rank_id(role));
_server_ptr->start();
}
_environment->ps_environment()->gather_ps_servers();
return 0;
}
int PSlib::init_client(EnvironmentRole role) {
_client_ptr.reset(paddle::ps::PSClientFactory::create(_ps_param));
_client_ptr->configure(_ps_param, *(_environment->ps_environment()),
_environment->rank_id(role));
return 0;
}
paddle::ps::PSServer* PSlib::ps_server() {
return _server_ptr.get();
}
paddle::ps::PSClient* PSlib::ps_client() {
return _client_ptr.get();
}
paddle::PSParameter* PSlib::get_param() {
return &_ps_param;
}
void PSlib::init_gflag() {
int cnt = 4;
std::shared_ptr<char*> params(new char*[cnt]);
char** params_ptr = params.get();
char p0[] = "exe default";
char p1[] = "-max_body_size=314217728";
char p2[] = "-bthread_concurrency=40";
char p3[] = "-socket_max_unwritten_bytes=2048000000";
params_ptr[0] = p0;
params_ptr[1] = p1;
params_ptr[2] = p2;
params_ptr[3] = p3;
::google::ParseCommandLineFlags(&cnt, &params_ptr, true);
}
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
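The wrapper's entry point is initialize(): it injects the brpc-related gflags above, parses the text-format PSParameter, starts a PSServer only on PSERVER nodes, and then creates a PSClient on every node. A condensed caller-side sketch; bring_up_pslib is a hypothetical helper and the config path is a placeholder (the real path comes from the trainer YAML, as the main.cc change later in this commit shows).
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
// Hypothetical helper: brings up the parameter-server wrapper on one node.
inline int bring_up_pslib(PSlib& pslib, RuntimeEnvironment* env, EnvironmentRole role) {
    // "./conf/ps.prototxt" is a placeholder path to a text-format PSParameter file.
    return pslib.initialize("./conf/ps.prototxt", env, role);
}
}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle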
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "communicate/ps_server.h"
#include "communicate/ps_client.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
class RuntimeEnvironment;
enum class EnvironmentRole;
class PSlib {
public:
PSlib() {}
virtual ~PSlib() {}
int initialize(const std::string& conf_path,
RuntimeEnvironment* environment, EnvironmentRole role);
virtual paddle::ps::PSServer* ps_server();
virtual paddle::ps::PSClient* ps_client();
virtual paddle::PSParameter* get_param();
private:
void init_gflag();
virtual int init_server(EnvironmentRole role);
virtual int init_client(EnvironmentRole role);
paddle::PSParameter _ps_param;
RuntimeEnvironment* _environment;
std::shared_ptr<paddle::ps::PSServer> _server_ptr;
std::shared_ptr<paddle::ps::PSClient> _client_ptr;
};
} // namespace feed
} // namespace custom_trainer
} // namespace paddle
...@@ -3,12 +3,12 @@ namespace paddle { ...@@ -3,12 +3,12 @@ namespace paddle {
namespace custom_trainer { namespace custom_trainer {
namespace feed { namespace feed {
BaseClassMap& global_factory_map() { BaseClassMap& global_reg_factory_map() {
static BaseClassMap *base_class = new BaseClassMap(); static BaseClassMap *base_class = new BaseClassMap();
return *base_class; return *base_class;
} }
BaseClassMap& global_factory_map_cpp() { BaseClassMap& global_reg_factory_map_cpp() {
return global_factory_map(); return global_reg_factory_map();
} }
}// feed }// feed
......
...@@ -63,23 +63,23 @@ typedef std::map<std::string, FactoryMap> BaseClassMap; ...@@ -63,23 +63,23 @@ typedef std::map<std::string, FactoryMap> BaseClassMap;
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
BaseClassMap& global_factory_map(); BaseClassMap& global_reg_factory_map();
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
BaseClassMap& global_factory_map_cpp(); BaseClassMap& global_reg_factory_map_cpp();
#define REGISTER_REGISTERER(base_class) \ #define REGIST_REGISTERER(base_class) \
class base_class ## Registerer { \ class base_class ## Registerer { \
public: \ public: \
static base_class *CreateInstanceByName(const ::std::string &name) { \ static base_class *CreateInstanceByName(const ::std::string &name) { \
if (global_factory_map_cpp().find(#base_class) \ if (global_reg_factory_map_cpp().find(#base_class) \
== global_factory_map_cpp().end()) { \ == global_reg_factory_map_cpp().end()) { \
LOG(ERROR) << "Can't Find BaseClass For CreateClass with:" << #base_class; \ LOG(ERROR) << "Can't Find BaseClass For CreateClass with:" << #base_class; \
return NULL; \ return NULL; \
} \ } \
FactoryMap &map = global_factory_map_cpp()[#base_class]; \ FactoryMap &map = global_reg_factory_map_cpp()[#base_class]; \
FactoryMap::iterator iter = map.find(name); \ FactoryMap::iterator iter = map.find(name); \
if (iter == map.end()) { \ if (iter == map.end()) { \
LOG(ERROR) << "Can't Find Class For Create with:" << name; \ LOG(ERROR) << "Can't Find Class For Create with:" << name; \
...@@ -90,7 +90,7 @@ BaseClassMap& global_factory_map_cpp(); ...@@ -90,7 +90,7 @@ BaseClassMap& global_factory_map_cpp();
} \ } \
}; };
#define REGISTER_CLASS(clazz, name) \ #define REGIST_CLASS(clazz, name) \
class ObjectFactory##name : public ObjectFactory { \ class ObjectFactory##name : public ObjectFactory { \
public: \ public: \
Any NewInstance() { \ Any NewInstance() { \
...@@ -98,14 +98,14 @@ BaseClassMap& global_factory_map_cpp(); ...@@ -98,14 +98,14 @@ BaseClassMap& global_factory_map_cpp();
} \ } \
}; \ }; \
void register_factory_##name() { \ void register_factory_##name() { \
FactoryMap &map = global_factory_map_cpp()[#clazz]; \ FactoryMap &map = global_reg_factory_map_cpp()[#clazz]; \
if (map.find(#name) == map.end()) { \ if (map.find(#name) == map.end()) { \
map[#name] = new ObjectFactory##name(); \ map[#name] = new ObjectFactory##name(); \
} \ } \
} \ } \
void register_factory_##name() __attribute__((constructor)); void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \ #define CREATE_INSTANCE(base_class, name) \
base_class##Registerer::CreateInstanceByName(name) base_class##Registerer::CreateInstanceByName(name)
}//namespace feed }//namespace feed
......
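At a call site, the renamed macros read as in the sketch below. LocalRuntimeEnvironment is one of the classes this commit registers via REGIST_CLASS(RuntimeEnvironment, LocalRuntimeEnvironment); make_local_env is a hypothetical helper added purely for illustration.
#include <memory>
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
namespace paddle {
namespace custom_trainer {
namespace feed {
inline std::unique_ptr<RuntimeEnvironment> make_local_env() {
    // The factory was registered at load time by REGIST_CLASS, so lookup by name
    // needs no explicit setup; CREATE_INSTANCE returns NULL for unknown names.
    std::unique_ptr<RuntimeEnvironment> env(
        CREATE_INSTANCE(RuntimeEnvironment, "LocalRuntimeEnvironment"));
    return env;
}
}  // namespace feed
}  // namespace custom_trainer
}  // namespace paddle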
...@@ -44,6 +44,11 @@ public: ...@@ -44,6 +44,11 @@ public:
set_role(EnvironmentRole::ALL); set_role(EnvironmentRole::ALL);
return 0; return 0;
} }
virtual paddle::ps::PSEnvironment* ps_environment() {
static paddle::ps::MpiPSEnvironment ps_environment;
return &ps_environment;
}
virtual uint32_t rank_id(EnvironmentRole role) { virtual uint32_t rank_id(EnvironmentRole role) {
return mpi_node_info(role).rank_id; return mpi_node_info(role).rank_id;
...@@ -95,7 +100,7 @@ protected: ...@@ -95,7 +100,7 @@ protected:
private: private:
std::vector<MpiNodeInfo> _roles_node_info; std::vector<MpiNodeInfo> _roles_node_info;
}; };
REGISTER_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment); REGIST_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment);
//for local-mode single-machine training //for local-mode single-machine training
class LocalRuntimeEnvironment : public RuntimeEnvironment { class LocalRuntimeEnvironment : public RuntimeEnvironment {
...@@ -108,6 +113,10 @@ public: ...@@ -108,6 +113,10 @@ public:
virtual int wireup() { virtual int wireup() {
return 0; return 0;
} }
virtual paddle::ps::PSEnvironment* ps_environment() {
static paddle::ps::LocalPSEnvironment ps_environment;
return &ps_environment;
}
virtual uint32_t rank_id(EnvironmentRole role) { virtual uint32_t rank_id(EnvironmentRole role) {
return 0; return 0;
} }
...@@ -129,7 +138,7 @@ protected: ...@@ -129,7 +138,7 @@ protected:
VLOG(static_cast<int>(level)) << log_str; VLOG(static_cast<int>(level)) << log_str;
} }
}; };
REGISTER_CLASS(RuntimeEnvironment, LocalRuntimeEnvironment); REGIST_CLASS(RuntimeEnvironment, LocalRuntimeEnvironment);
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
*/ */
#pragma once #pragma once
#include <yaml-cpp/yaml.h> #include <yaml-cpp/yaml.h>
#include "communicate/ps_env.h"
#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/registerer.h" #include "paddle/fluid/train/custom_trainer/feed/common/registerer.h"
...@@ -14,6 +15,8 @@ namespace paddle { ...@@ -14,6 +15,8 @@ namespace paddle {
namespace custom_trainer { namespace custom_trainer {
namespace feed { namespace feed {
class paddle::ps::PSEnvironment;
enum class EnvironmentLogLevel { enum class EnvironmentLogLevel {
FATAL = 0, FATAL = 0,
ERROR = 1, ERROR = 1,
...@@ -38,41 +41,43 @@ class RuntimeEnvironment { ...@@ -38,41 +41,43 @@ class RuntimeEnvironment {
public: public:
RuntimeEnvironment(); RuntimeEnvironment();
virtual ~RuntimeEnvironment(); virtual ~RuntimeEnvironment();
//Configuration initialization // Configuration initialization
virtual int initialize(YAML::Node config) = 0; virtual int initialize(YAML::Node config) = 0;
//Set the role // Set the role
virtual int set_role(EnvironmentRole role) = 0; virtual int set_role(EnvironmentRole role) = 0;
//Environment setup; called after all dependent modules have run initialize // Environment setup; called after all dependent modules have run initialize
virtual int wireup() = 0; virtual int wireup() = 0;
//Interfaces callable from multiple threads: Start // Interfaces callable from multiple threads: Start
//rank_idx in the current environment // rank_idx in the current environment
virtual uint32_t rank_id(EnvironmentRole role) = 0; virtual uint32_t rank_id(EnvironmentRole role) = 0;
//Number of nodes in the runtime environment // Number of nodes in the runtime environment
virtual uint32_t node_num(EnvironmentRole role) = 0; virtual uint32_t node_num(EnvironmentRole role) = 0;
//Master node within the environment // Master node within the environment
virtual bool is_master_node(EnvironmentRole role); virtual bool is_master_node(EnvironmentRole role);
//For PS
virtual paddle::ps::PSEnvironment* ps_environment() = 0;
//Environment-customized logging // Environment-customized logging
template<class... ARGS> template<class... ARGS>
void log(EnvironmentRole role, EnvironmentLogType type, void log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const char* fmt, ARGS && ... args) { EnvironmentLogLevel level, const char* fmt, ARGS && ... args) {
print_log(role, type, level, paddle::string::format_string(fmt, args...)); print_log(role, type, level, paddle::string::format_string(fmt, args...));
} }
//Interfaces callable from multiple threads: End // Interfaces callable from multiple threads: End
//Interfaces that must only be called from the main thread: Start // Interfaces that must only be called from the main thread: Start
//barrier across the nodes of the given role // barrier across the nodes of the given role
virtual void barrier(EnvironmentRole role) = 0; virtual void barrier(EnvironmentRole role) = 0;
//bcast: broadcast // bcast: broadcast
virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) = 0; virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) = 0;
//Interfaces that must only be called from the main thread: End // Interfaces that must only be called from the main thread: End
protected: protected:
virtual void print_log(EnvironmentRole role, EnvironmentLogType type, virtual void print_log(EnvironmentRole role, EnvironmentLogType type,
EnvironmentLogLevel level, const std::string& log_str) = 0; EnvironmentLogLevel level, const std::string& log_str) = 0;
}; };
REGISTER_REGISTERER(RuntimeEnvironment); REGIST_REGISTERER(RuntimeEnvironment);
std::string format_timestamp(time_t time, const char* format); std::string format_timestamp(time_t time, const char* format);
inline std::string format_timestamp(time_t time, const std::string& format) { inline std::string format_timestamp(time_t time, const std::string& format) {
......
...@@ -56,10 +56,10 @@ public: ...@@ -56,10 +56,10 @@ public:
return 0; return 0;
} }
}; };
REGISTER_CLASS(DataParser, LineDataParser); REGIST_CLASS(DataParser, LineDataParser);
int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) { int DataReader::initialize(const YAML::Node& config, std::shared_ptr<TrainerContext> context) {
_parser.reset(CREATE_CLASS(DataParser, config["parser"]["class"].as<std::string>())); _parser.reset(CREATE_INSTANCE(DataParser, config["parser"]["class"].as<std::string>()));
if (_parser == nullptr) { if (_parser == nullptr) {
VLOG(2) << "fail to get parser: " << config["parser"]["class"].as<std::string>(); VLOG(2) << "fail to get parser: " << config["parser"]["class"].as<std::string>();
return -1; return -1;
...@@ -85,7 +85,7 @@ public: ...@@ -85,7 +85,7 @@ public:
if (config["file_system"] && config["file_system"]["class"]) { if (config["file_system"] && config["file_system"]["class"]) {
_file_system.reset( _file_system.reset(
CREATE_CLASS(FileSystem, config["file_system"]["class"].as<std::string>())); CREATE_INSTANCE(FileSystem, config["file_system"]["class"].as<std::string>()));
if (_file_system == nullptr || if (_file_system == nullptr ||
_file_system->initialize(config["file_system"], context) != 0) { _file_system->initialize(config["file_system"], context) != 0) {
VLOG(2) << "fail to create class: " VLOG(2) << "fail to create class: "
...@@ -95,7 +95,7 @@ public: ...@@ -95,7 +95,7 @@ public:
} else if (context->file_system != nullptr) { } else if (context->file_system != nullptr) {
_file_system = context->file_system; _file_system = context->file_system;
} else { } else {
_file_system.reset(CREATE_CLASS(FileSystem, "LocalFileSystem")); _file_system.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) { if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) {
VLOG(2) << "fail to init file system"; VLOG(2) << "fail to init file system";
return -1; return -1;
...@@ -203,7 +203,7 @@ private: ...@@ -203,7 +203,7 @@ private:
std::string _filename_prefix; std::string _filename_prefix;
std::shared_ptr<FileSystem> _file_system; std::shared_ptr<FileSystem> _file_system;
}; };
REGISTER_CLASS(DataReader, LineDataReader); REGIST_CLASS(DataReader, LineDataReader);
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
......
...@@ -54,7 +54,7 @@ public: ...@@ -54,7 +54,7 @@ public:
virtual int parse(const char* str, DataItem& data) const = 0; virtual int parse(const char* str, DataItem& data) const = 0;
virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0; virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0;
}; };
REGISTER_REGISTERER(DataParser); REGIST_REGISTERER(DataParser);
class DataReader { class DataReader {
public: public:
...@@ -76,7 +76,7 @@ protected: ...@@ -76,7 +76,7 @@ protected:
std::shared_ptr<DataParser> _parser;//data format conversion std::shared_ptr<DataParser> _parser;//data format conversion
std::string _pipeline_cmd; //redirect the file stream into pipeline_cmd, then read from it std::string _pipeline_cmd; //redirect the file stream into pipeline_cmd, then read from it
}; };
REGISTER_REGISTERER(DataReader); REGIST_REGISTERER(DataReader);
}//namespace feed }//namespace feed
}//namespace custom_trainer }//namespace custom_trainer
......
...@@ -32,7 +32,7 @@ int DatasetContainer::initialize( ...@@ -32,7 +32,7 @@ int DatasetContainer::initialize(
_data_split_interval = config["data_spit_interval"].as<int>(); _data_split_interval = config["data_spit_interval"].as<int>();
_data_path_formater = config["data_path_formater"].as<std::string>(); _data_path_formater = config["data_path_formater"].as<std::string>();
std::string data_reader_class = config["data_reader"].as<std::string>(); std::string data_reader_class = config["data_reader"].as<std::string>();
DataReader* data_reader = CREATE_CLASS(DataReader, data_reader_class); DataReader* data_reader = CREATE_INSTANCE(DataReader, data_reader_class);
_data_reader.reset(data_reader); _data_reader.reset(data_reader);
return _data_reader->initialize(config, context); return _data_reader->initialize(config, context);
} }
......
...@@ -121,7 +121,7 @@ protected: ...@@ -121,7 +121,7 @@ protected:
std::unique_ptr<Context> _context; std::unique_ptr<Context> _context;
}; };
REGISTER_CLASS(Executor, SimpleExecutor); REGIST_CLASS(Executor, SimpleExecutor);
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
......
...@@ -40,7 +40,7 @@ public: ...@@ -40,7 +40,7 @@ public:
protected: protected:
::paddle::framework::Scope _scope; ::paddle::framework::Scope _scope;
}; };
REGISTER_REGISTERER(Executor); REGIST_REGISTERER(Executor);
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
......
...@@ -31,7 +31,7 @@ public: ...@@ -31,7 +31,7 @@ public:
_file_system.clear(); _file_system.clear();
if (config && config["file_systems"] && config["file_systems"].Type() == YAML::NodeType::Map) { if (config && config["file_systems"] && config["file_systems"].Type() == YAML::NodeType::Map) {
for (auto& prefix_fs: config["file_systems"]) { for (auto& prefix_fs: config["file_systems"]) {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, prefix_fs.second["class"].as<std::string>(""))); std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, prefix_fs.second["class"].as<std::string>("")));
if (fs == nullptr) { if (fs == nullptr) {
VLOG(2) << "fail to create class: " << prefix_fs.second["class"].as<std::string>(""); VLOG(2) << "fail to create class: " << prefix_fs.second["class"].as<std::string>("");
return -1; return -1;
...@@ -44,7 +44,7 @@ public: ...@@ -44,7 +44,7 @@ public:
} }
} }
if (_file_system.find("default") == _file_system.end()) { if (_file_system.find("default") == _file_system.end()) {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
if (fs == nullptr || fs->initialize(YAML::Load(""), context) != 0) { if (fs == nullptr || fs->initialize(YAML::Load(""), context) != 0) {
return -1; return -1;
} }
...@@ -122,7 +122,7 @@ public: ...@@ -122,7 +122,7 @@ public:
private: private:
std::unordered_map<std::string, std::unique_ptr<FileSystem>> _file_system; std::unordered_map<std::string, std::unique_ptr<FileSystem>> _file_system;
}; };
REGISTER_CLASS(FileSystem, AutoFileSystem); REGIST_CLASS(FileSystem, AutoFileSystem);
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
......
...@@ -52,7 +52,7 @@ public: ...@@ -52,7 +52,7 @@ public:
protected: protected:
int _err_no = 0; int _err_no = 0;
}; };
REGISTER_REGISTERER(FileSystem); REGIST_REGISTERER(FileSystem);
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
......
...@@ -203,7 +203,7 @@ private: ...@@ -203,7 +203,7 @@ private:
std::string _hdfs_command; std::string _hdfs_command;
std::unordered_map<std::string, std::string> _ugi; std::unordered_map<std::string, std::string> _ugi;
}; };
REGISTER_CLASS(FileSystem, HadoopFileSystem); REGIST_CLASS(FileSystem, HadoopFileSystem);
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
......
...@@ -129,7 +129,7 @@ public: ...@@ -129,7 +129,7 @@ public:
private: private:
size_t _buffer_size = 0; size_t _buffer_size = 0;
}; };
REGISTER_CLASS(FileSystem, LocalFileSystem); REGIST_CLASS(FileSystem, LocalFileSystem);
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
......
#include <time.h> #include <time.h>
#include <fstream> #include <fstream>
#include <yaml-cpp/yaml.h> #include <yaml-cpp/yaml.h>
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/train/custom_trainer/feed/process/process.h" #include "paddle/fluid/train/custom_trainer/feed/process/process.h"
#include "paddle/fluid/train/custom_trainer/feed/process/init_env_process.h" #include "paddle/fluid/train/custom_trainer/feed/process/init_env_process.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
...@@ -21,27 +21,56 @@ int main(int argc, char* argv[]) { ...@@ -21,27 +21,56 @@ int main(int argc, char* argv[]) {
//load trainer config //load trainer config
auto trainer_context_ptr = std::make_shared<TrainerContext>(); auto trainer_context_ptr = std::make_shared<TrainerContext>();
trainer_context_ptr->trainer_config = YAML::LoadFile(FLAGS_feed_trainer_conf_path); trainer_context_ptr->trainer_config = YAML::LoadFile(FLAGS_feed_trainer_conf_path);
//environment
auto& config = trainer_context_ptr->trainer_config;
std::string env_class = config["environment"]["environment_class"].as<std::string>();
trainer_context_ptr->environment.reset(CREATE_INSTANCE(RuntimeEnvironment, env_class));
if (trainer_context_ptr->environment->initialize(config["environment"]) != 0) {
return -1;
}
EnvironmentRole role;
auto* environment = trainer_context_ptr->environment.get();
environment->wireup();
if (environment->rank_id(EnvironmentRole::ALL) % 2 == 0) {
role = EnvironmentRole::WORKER;
} else {
role = EnvironmentRole::PSERVER;
}
environment->set_role(role);
trainer_context_ptr->pslib.reset(new PSlib());
std::string ps_config = config["environment"]["ps"].as<std::string>();
trainer_context_ptr->pslib->initialize(ps_config, environment, role);
//VLOG(3) << "Node Start With Role:" << role;
std::vector<std::string> process_name_list = { std::vector<std::string> process_name_list = {
"InitEnvProcess", "InitEnvProcess",
"LearnerProcess" "LearnerProcess"
}; };
switch (role) {
for (const auto& process_name : process_name_list) { case EnvironmentRole::WORKER:
Process* process = CREATE_CLASS(Process, process_name); for (const auto& process_name : process_name_list) {
if (process == NULL) { Process* process = CREATE_INSTANCE(Process, process_name);
VLOG(1) << "Process:" << process_name << " does not exist"; if (process == NULL) {
return -1; VLOG(1) << "Process:" << process_name << " does not exist";
return -1;
}
if (process->initialize(trainer_context_ptr) != 0) {
VLOG(1) << "Process:" << process_name << " initialize failed";
return -1;
}
trainer_context_ptr->process_list.push_back(std::shared_ptr<Process>(process));
}
for (auto& process : trainer_context_ptr->process_list) {
process->run();
} }
if (process->initialize(trainer_context_ptr) != 0) { break;
VLOG(1) << "Process:" << process_name << " initialize failed"; case EnvironmentRole::PSERVER:
return -1; //wait server done
while (true) {
sleep(10000);
} }
trainer_context_ptr->process_list.push_back(std::shared_ptr<Process>(process)); break;
}
for (auto& process : trainer_context_ptr->process_list) {
process->run();
} }
return 0; return 0;
......
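main() now reads two new keys from the trainer config: environment.environment_class picks a registered RuntimeEnvironment implementation, and environment.ps points at the pslib config passed to PSlib::initialize(). An assumed minimal layout is sketched below; MPIRuntimeEnvironment is a class name registered in this commit, while the ps path is a placeholder, not taken from the source.
#include <yaml-cpp/yaml.h>
inline YAML::Node demo_environment_config() {
    // environment_class must name a class registered with REGIST_CLASS; the ps
    // value is a placeholder path to a text-format pslib config.
    return YAML::Load(
        "environment:\n"
        "  environment_class: MPIRuntimeEnvironment\n"
        "  ps: ./conf/ps_table.prototxt\n");
}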
...@@ -39,7 +39,7 @@ protected: ...@@ -39,7 +39,7 @@ protected:
std::string _name; std::string _name;
}; };
REGISTER_REGISTERER(Monitor); REGIST_REGISTERER(Monitor);
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
......
...@@ -20,22 +20,16 @@ int InitEnvProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) { ...@@ -20,22 +20,16 @@ int InitEnvProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
context_ptr->cpu_place = paddle::platform::CPUPlace(); context_ptr->cpu_place = paddle::platform::CPUPlace();
YAML::Node config = _context_ptr->trainer_config; YAML::Node config = _context_ptr->trainer_config;
//environment
std::string env_class = config["environment"]["environment_class"].as<std::string>();
context_ptr->environment.reset(CREATE_CLASS(RuntimeEnvironment, env_class));
if (context_ptr->environment->initialize(config["environment"]) != 0) {
return -1;
}
//file_system //file_system
context_ptr->file_system.reset(CREATE_CLASS(FileSystem, "AutoFileSystem")); context_ptr->file_system.reset(CREATE_INSTANCE(FileSystem, "AutoFileSystem"));
if (context_ptr->file_system->initialize(config["io"], context_ptr) != 0) { if (context_ptr->file_system->initialize(config["io"], context_ptr) != 0) {
return -1; return -1;
} }
//epoch //epoch
std::string epoch_class = config["epoch"]["epoch_class"].as<std::string>(); std::string epoch_class = config["epoch"]["epoch_class"].as<std::string>();
context_ptr->epoch_accessor.reset(CREATE_CLASS(EpochAccessor, epoch_class)); context_ptr->epoch_accessor.reset(CREATE_INSTANCE(EpochAccessor, epoch_class));
if (context_ptr->epoch_accessor->initialize(config["epoch"], context_ptr) != 0) { if (context_ptr->epoch_accessor->initialize(config["epoch"], context_ptr) != 0) {
return -1; return -1;
} }
...@@ -55,10 +49,12 @@ int InitEnvProcess::run() { ...@@ -55,10 +49,12 @@ int InitEnvProcess::run() {
VLOG(3) << "Trainer Resume From epoch:" << epoch_accessor->current_epoch_id(); VLOG(3) << "Trainer Resume From epoch:" << epoch_accessor->current_epoch_id();
auto next_epoch_id = epoch_accessor->next_epoch_id(epoch_accessor->current_epoch_id()); auto next_epoch_id = epoch_accessor->next_epoch_id(epoch_accessor->current_epoch_id());
_context_ptr->dataset->pre_detect_data(next_epoch_id); _context_ptr->dataset->pre_detect_data(next_epoch_id);
//step 1. psserver init
//step2. psserver load if (epoch_accessor->checkpoint_path().size() > 0) {
VLOG(3) << "Psserver Start Success"; //Load Model
} else {
//Random Init Model
}
//context_ptr->pslib_client()->load_model(); //context_ptr->pslib_client()->load_model();
VLOG(3) << "Psserver Load Model Success"; VLOG(3) << "Psserver Load Model Success";
return 0; return 0;
......
...@@ -25,7 +25,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) { ...@@ -25,7 +25,7 @@ int LearnerProcess::initialize(std::shared_ptr<TrainerContext> context_ptr) {
_threads_executor[i].resize(_executor_num); _threads_executor[i].resize(_executor_num);
for (int e = 0; e < _executor_num; ++e) { for (int e = 0; e < _executor_num; ++e) {
auto e_class = config["executor"][e]["class"].as<std::string>(); auto e_class = config["executor"][e]["class"].as<std::string>();
auto* e_ptr = CREATE_CLASS(Executor, e_class); auto* e_ptr = CREATE_INSTANCE(Executor, e_class);
_threads_executor[i][e].reset(e_ptr); _threads_executor[i][e].reset(e_ptr);
if (e_ptr->initialize(config["executor"][e], context_ptr) != 0) { if (e_ptr->initialize(config["executor"][e], context_ptr) != 0) {
ret = -1; ret = -1;
...@@ -84,53 +84,59 @@ int LearnerProcess::run() { ...@@ -84,53 +84,59 @@ int LearnerProcess::run() {
while (true) { while (true) {
epoch_accessor->next_epoch(); epoch_accessor->next_epoch();
bool already_dump_inference_model = false;
epoch_id = epoch_accessor->current_epoch_id(); epoch_id = epoch_accessor->current_epoch_id();
std::string epoch_log_title= paddle::string::format_string( std::string epoch_log_title= paddle::string::format_string(
"train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str()); "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str());
//Step1. Wait for the samples to be ready //Step1. Wait for the samples to be ready
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, {
"Start %s, wait data ready", epoch_log_title.c_str());
while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
sleep(30);
dataset->pre_detect_data(epoch_id);
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"%s, data not ready, wait 30s", epoch_log_title.c_str()); "Start %s, wait data ready", epoch_log_title.c_str());
} while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) {
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, sleep(30);
"%s, data is ready, start traning", epoch_log_title.c_str()); dataset->pre_detect_data(epoch_id);
environment->barrier(EnvironmentRole::WORKER); environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"%s, data not ready, wait 30s", epoch_log_title.c_str());
}
environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE,
"%s, data is ready, start traning", epoch_log_title.c_str());
environment->barrier(EnvironmentRole::WORKER);
}
//Step2. Run the training network //Step2. Run the training network
bool already_dump_inference_model = false; {
for (int i = 0; i < _executor_num; ++i) { for (int i = 0; i < _executor_num; ++i) {
std::vector<std::shared_ptr<std::thread>> train_threads(_train_thread_num); std::vector<std::shared_ptr<std::thread>> train_threads(_train_thread_num);
for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) { for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) {
train_threads[i].reset(new std::thread([this](int exe_idx, int thread_idx) { train_threads[i].reset(new std::thread([this](int exe_idx, int thread_idx) {
auto* executor = _threads_executor[thread_idx][exe_idx].get(); auto* executor = _threads_executor[thread_idx][exe_idx].get();
run_executor(executor); run_executor(executor);
}, i, thread_id)); }, i, thread_id));
} }
for (int i = 0; i < _train_thread_num; ++i) { for (int i = 0; i < _train_thread_num; ++i) {
train_threads[i]->join(); train_threads[i]->join();
}
environment->barrier(EnvironmentRole::WORKER);
if (_threads_executor[0][i]->is_dump_all_model()) {
already_dump_inference_model = true;
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
}
environment->barrier(EnvironmentRole::WORKER);
} }
environment->barrier(EnvironmentRole::WORKER); }
if (_threads_executor[0][i]->is_dump_all_model()) { //Step3. Dump Model For Delta&&Checkpoint
{
if (!already_dump_inference_model) {
already_dump_inference_model = true; already_dump_inference_model = true;
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta); wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
} }
wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
environment->barrier(EnvironmentRole::WORKER); environment->barrier(EnvironmentRole::WORKER);
} }
//Step3. Dump Model For Delta&&Checkpoint
if (!already_dump_inference_model) {
already_dump_inference_model = true;
wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta);
}
wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint);
environment->barrier(EnvironmentRole::WORKER);
//Step4. Output Monitor && RunStatus //Step4. Output Monitor && RunStatus
//TODO //TODO
} }
......
...@@ -5,8 +5,8 @@ ...@@ -5,8 +5,8 @@
namespace paddle { namespace paddle {
namespace custom_trainer { namespace custom_trainer {
namespace feed { namespace feed {
REGISTER_CLASS(Process, InitEnvProcess); REGIST_CLASS(Process, InitEnvProcess);
REGISTER_CLASS(Process, LearnerProcess); REGIST_CLASS(Process, LearnerProcess);
int Process::run() { int Process::run() {
return 0; return 0;
} }
......
...@@ -18,7 +18,7 @@ public: ...@@ -18,7 +18,7 @@ public:
protected: protected:
TrainerContext* _context_ptr = NULL; TrainerContext* _context_ptr = NULL;
}; };
REGISTER_REGISTERER(Process); REGIST_REGISTERER(Process);
} // namespace feed } // namespace feed
} // namespace custom_trainer } // namespace custom_trainer
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <vector> #include <vector>
#include <yaml-cpp/yaml.h> #include <yaml-cpp/yaml.h>
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h"
#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h" #include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h"
...@@ -38,6 +39,7 @@ public: ...@@ -38,6 +39,7 @@ public:
YAML::Node trainer_config; YAML::Node trainer_config;
paddle::platform::CPUPlace cpu_place; paddle::platform::CPUPlace cpu_place;
std::shared_ptr<PSlib> pslib;
std::shared_ptr<Dataset> dataset; //training samples std::shared_ptr<Dataset> dataset; //training samples
std::shared_ptr<FileSystem> file_system; //file-operation helper std::shared_ptr<FileSystem> file_system; //file-operation helper
std::vector<TableMeta> params_table_list; //parameter tables std::vector<TableMeta> params_table_list; //parameter tables
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <omp.h> #include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -36,7 +37,7 @@ const char test_data_dir[] = "test_data"; ...@@ -36,7 +37,7 @@ const char test_data_dir[] = "test_data";
class DataReaderTest : public testing::Test { class DataReaderTest : public testing::Test {
public: public:
static void SetUpTestCase() { static void SetUpTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
fs->mkdir(test_data_dir); fs->mkdir(test_data_dir);
shell_set_verbose(true); shell_set_verbose(true);
...@@ -56,14 +57,14 @@ public: ...@@ -56,14 +57,14 @@ public:
} }
static void TearDownTestCase() { static void TearDownTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
fs->remove(test_data_dir); fs->remove(test_data_dir);
} }
virtual void SetUp() { virtual void SetUp() {
thread_num = omp_get_max_threads(); thread_num = omp_get_max_threads();
omp_set_num_threads(1); omp_set_num_threads(1);
fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem")); fs.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
context_ptr.reset(new TrainerContext()); context_ptr.reset(new TrainerContext());
} }
...@@ -79,7 +80,7 @@ public: ...@@ -79,7 +80,7 @@ public:
}; };
TEST_F(DataReaderTest, LineDataParser) { TEST_F(DataReaderTest, LineDataParser) {
std::unique_ptr<DataParser> data_parser(CREATE_CLASS(DataParser, "LineDataParser")); std::unique_ptr<DataParser> data_parser(CREATE_INSTANCE(DataParser, "LineDataParser"));
ASSERT_NE(nullptr, data_parser); ASSERT_NE(nullptr, data_parser);
auto config = YAML::Load(""); auto config = YAML::Load("");
...@@ -108,7 +109,7 @@ TEST_F(DataReaderTest, LineDataParser) { ...@@ -108,7 +109,7 @@ TEST_F(DataReaderTest, LineDataParser) {
} }
TEST_F(DataReaderTest, LineDataReader) { TEST_F(DataReaderTest, LineDataReader) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader")); std::unique_ptr<DataReader> data_reader(CREATE_INSTANCE(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader); ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load( auto config = YAML::Load(
...@@ -161,7 +162,7 @@ TEST_F(DataReaderTest, LineDataReader) { ...@@ -161,7 +162,7 @@ TEST_F(DataReaderTest, LineDataReader) {
} }
TEST_F(DataReaderTest, LineDataReader_filename_prefix) { TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader")); std::unique_ptr<DataReader> data_reader(CREATE_INSTANCE(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader); ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load( auto config = YAML::Load(
"parser:\n" "parser:\n"
...@@ -196,7 +197,7 @@ TEST_F(DataReaderTest, LineDataReader_filename_prefix) { ...@@ -196,7 +197,7 @@ TEST_F(DataReaderTest, LineDataReader_filename_prefix) {
} }
TEST_F(DataReaderTest, LineDataReader_FileSystem) { TEST_F(DataReaderTest, LineDataReader_FileSystem) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader")); std::unique_ptr<DataReader> data_reader(CREATE_INSTANCE(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader); ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load( auto config = YAML::Load(
"parser:\n" "parser:\n"
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <omp.h> #include <omp.h>
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -37,7 +38,7 @@ const char test_data_dir[] = "test_data"; ...@@ -37,7 +38,7 @@ const char test_data_dir[] = "test_data";
class DataReaderOmpTest : public testing::Test { class DataReaderOmpTest : public testing::Test {
public: public:
static void SetUpTestCase() { static void SetUpTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
fs->mkdir(test_data_dir); fs->mkdir(test_data_dir);
shell_set_verbose(true); shell_set_verbose(true);
std_items.clear(); std_items.clear();
...@@ -61,14 +62,14 @@ public: ...@@ -61,14 +62,14 @@ public:
} }
static void TearDownTestCase() { static void TearDownTestCase() {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
fs->remove(test_data_dir); fs->remove(test_data_dir);
} }
virtual void SetUp() { virtual void SetUp() {
thread_num = omp_get_max_threads(); thread_num = omp_get_max_threads();
omp_set_num_threads(1); omp_set_num_threads(1);
fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem")); fs.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
context_ptr.reset(new TrainerContext()); context_ptr.reset(new TrainerContext());
} }
...@@ -111,7 +112,7 @@ std::vector<DataItem> DataReaderOmpTest::std_items; ...@@ -111,7 +112,7 @@ std::vector<DataItem> DataReaderOmpTest::std_items;
std::vector<DataItem> DataReaderOmpTest::sorted_std_items; std::vector<DataItem> DataReaderOmpTest::sorted_std_items;
TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) { TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader")); std::unique_ptr<DataReader> data_reader(CREATE_INSTANCE(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader); ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load( auto config = YAML::Load(
...@@ -148,7 +149,7 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) { ...@@ -148,7 +149,7 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) {
} }
TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) { TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) {
std::unique_ptr<DataReader> data_reader(CREATE_CLASS(DataReader, "LineDataReader")); std::unique_ptr<DataReader> data_reader(CREATE_INSTANCE(DataReader, "LineDataReader"));
ASSERT_NE(nullptr, data_reader); ASSERT_NE(nullptr, data_reader);
auto config = YAML::Load( auto config = YAML::Load(
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <fstream> #include <fstream>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h"
#include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h"
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -38,7 +39,7 @@ class SimpleExecutorTest : public testing::Test ...@@ -38,7 +39,7 @@ class SimpleExecutorTest : public testing::Test
public: public:
static void SetUpTestCase() static void SetUpTestCase()
{ {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
fs->mkdir(test_data_dir); fs->mkdir(test_data_dir);
shell_set_verbose(true); shell_set_verbose(true);
...@@ -70,7 +71,7 @@ public: ...@@ -70,7 +71,7 @@ public:
static void TearDownTestCase() static void TearDownTestCase()
{ {
std::unique_ptr<FileSystem> fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); std::unique_ptr<FileSystem> fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem"));
fs->remove(test_data_dir); fs->remove(test_data_dir);
} }
...@@ -88,7 +89,7 @@ public: ...@@ -88,7 +89,7 @@ public:
}; };
TEST_F(SimpleExecutorTest, initialize) { TEST_F(SimpleExecutorTest, initialize) {
std::unique_ptr<Executor> executor(CREATE_CLASS(Executor, "SimpleExecutor")); std::unique_ptr<Executor> executor(CREATE_INSTANCE(Executor, "SimpleExecutor"));
ASSERT_NE(nullptr, executor); ASSERT_NE(nullptr, executor);
YAML::Node config = YAML::Load("[1, 2, 3]"); YAML::Node config = YAML::Load("[1, 2, 3]");
ASSERT_NE(0, executor->initialize(config, context_ptr)); ASSERT_NE(0, executor->initialize(config, context_ptr));
...@@ -99,7 +100,7 @@ TEST_F(SimpleExecutorTest, initialize) { ...@@ -99,7 +100,7 @@ TEST_F(SimpleExecutorTest, initialize) {
} }
TEST_F(SimpleExecutorTest, run) { TEST_F(SimpleExecutorTest, run) {
std::unique_ptr<Executor> executor(CREATE_CLASS(Executor, "SimpleExecutor")); std::unique_ptr<Executor> executor(CREATE_INSTANCE(Executor, "SimpleExecutor"));
ASSERT_NE(nullptr, executor); ASSERT_NE(nullptr, executor);
auto config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path)); auto config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path));
......