diff --git a/BCLOUD b/BCLOUD index 151cbed3b41dcb3bc78bbf8ef41a6641f9215ad6..e797fdf1d7c69313543cd2d06c0c5766e81e346d 100644 --- a/BCLOUD +++ b/BCLOUD @@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch') CONFIGS('baidu/third-party/python@gcc482output@git_branch') CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag') CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch') +CONFIGS('baidu/paddlepaddle/pslib@master@git_branch') CONFIGS('third-64/gtest@base') HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/') diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc index 7e474eefd6bd9ec48d2f0393e64d088d60d5cfc4..36e0b9fd7920ad44d3cf4b3f0e284276339528ce 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc +++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc @@ -88,7 +88,7 @@ namespace feed { return ""; } - REGISTER_CLASS(EpochAccessor, HourlyEpochAccessor); + REGIST_CLASS(EpochAccessor, HourlyEpochAccessor); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h index 247a1fe3d1fab24fb5faae7f00def1adb4eda4e5..069abff31d162fbf09f0793530bb2322a6ac3915 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h +++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h @@ -60,7 +60,7 @@ protected: std::vector _done_status; //当前完成状态,统一存成string }; -REGISTER_REGISTERER(EpochAccessor); +REGIST_REGISTERER(EpochAccessor); class HourlyEpochAccessor : public EpochAccessor { public: diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h b/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h new file mode 100644 index 0000000000000000000000000000000000000000..a91a5f4620e20755089a2fdb53a19bfd7184ef80 --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h @@ -0,0 +1,31 @@ +#pragma once +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h" +#include "paddle/fluid/train/custom_trainer/feed/accessor/accessor.h" + +namespace paddle { +namespace custom_trainer { +namespace feed { + +class DataInputAccessor : public Accessor { +public: + DataInputAccessor() {} + virtual ~DataInputAccessor() {} + + virtual int initialize(const YAML::Node& config, + std::shared_ptr context_ptr); + + // 前向, 一般用于填充输入,在训练网络执行前调用 + virtual int32_t forward(const SampleInstance* samples, + ::paddle::framework::Scope* scope, size_t table_id, size_t num) = 0; + + // 后向,一般用于更新梯度,在训练网络执行后调用 + virtual int32_t backward(const SampleInstance* samples, + ::paddle::framework::Scope* scope, size_t table_id, size_t num) = 0; +protected: +}; +REGIST_REGISTERER(DataInputAccessor); + +} // namespace feed +} // namespace custom_trainer +} // namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc new file mode 100644 index 0000000000000000000000000000000000000000..63b896b8e3977ef15695fcddb8b5ee697db24f89 --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc @@ -0,0 +1,51 @@ +#include +#include +#include "paddle/fluid/string/string_helper.h" +#include "paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h" + +namespace paddle { +namespace custom_trainer { +namespace feed { + +class CommonSparseInputAccessor : public DataInputAccessor { +public: + CommonSparseInputAccessor() {} + virtual ~CommonSparseInputAccessor() {} + virtual int initialize(const YAML::Node& config, + std::shared_ptr context_ptr) { + CHECK(config["sparse_input"] && config["sparse_input"].Type() == YAML::NodeType::Map); + for (auto& input : config["sparse_input"]) { + std::pair> sparse_slots; + sparse_slots.first = input.first.as(); + std::string slots_str = input.second["slots"].as(); + std::vector slots = paddle::string::split_string(slots_str, ","); + for (int i = 0; i < slots.size(); ++i) { + sparse_slots.second.push_back((uint16_t)atoi(slots[i].c_str())); + } + } + return 0; + } + + // 取sparse数据 + virtual int32_t forward(const SampleInstance* samples, + ::paddle::framework::Scope* scope, size_t table_id, size_t num) { + // pull + return 0; + } + + // 更新spare数据 + virtual int32_t backward(const SampleInstance* samples, + ::paddle::framework::Scope* scope, size_t table_id, size_t num) { + return 0; + } + +protected: + + // 输入层列表 + // + std::vector > > _x_variables; +}; + +} // namespace feed +} // namespace custom_trainer +} // namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc b/paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f8ee5e3aeabb6acf78a0d62b4702ff2d5d6d710 --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc @@ -0,0 +1,80 @@ +#include +#include +#include +#include "json2pb/json_to_pb.h" +#include +#include +#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h" +#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h" + +namespace paddle { +namespace custom_trainer { +namespace feed { + +int PSlib::initialize(const std::string& conf_path, + RuntimeEnvironment* environment, EnvironmentRole role) { + init_gflag(); + int file_descriptor = open(conf_path.c_str(), O_RDONLY); + if (file_descriptor == -1){ + LOG(ERROR) << "FATAL: cant open " << conf_path; + return -1; + } + google::protobuf::io::FileInputStream fileInput(file_descriptor); + if (!google::protobuf::TextFormat::Parse(&fileInput, &_ps_param)) { + LOG(ERROR) << "FATAL: fail to parse " << conf_path; + return -1; + } + close(file_descriptor); + init_server(role); + init_client(EnvironmentRole::ALL); + return 0; +} + +int PSlib::init_server(EnvironmentRole role) { + if (role == EnvironmentRole::PSERVER) { + _server_ptr.reset(paddle::ps::PSServerFactory::create(_ps_param)); + _server_ptr->configure(_ps_param, *(_environment->ps_environment()), + _environment->rank_id(role)); + _server_ptr->start(); + } + _environment->ps_environment()->gather_ps_servers(); + return 0; +} + +int PSlib::init_client(EnvironmentRole role) { + _client_ptr.reset(paddle::ps::PSClientFactory::create(_ps_param)); + _client_ptr->configure(_ps_param, *(_environment->ps_environment()), + _environment->rank_id(role)); + return 0; +} + +paddle::ps::PSServer* PSlib::ps_server() { + return _server_ptr.get(); +} + +paddle::ps::PSClient* PSlib::ps_client() { + return _client_ptr.get(); +} + +paddle::PSParameter* PSlib::get_param() { + return &_ps_param; +} + +void PSlib::init_gflag() { + int cnt = 4; + std::shared_ptr params(new char*[cnt]); + char** params_ptr = params.get(); + char p0[] = "exe default"; + char p1[] = "-max_body_size=314217728"; + char p2[] = "-bthread_concurrency=40"; + char p3[] = "-socket_max_unwritten_bytes=2048000000"; + params_ptr[0] = p0; + params_ptr[1] = p1; + params_ptr[2] = p2; + params_ptr[3] = p3; + ::google::ParseCommandLineFlags(&cnt, ¶ms_ptr, true); +} + +} // namespace feed +} // namespace custom_trainer +} // namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h b/paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h new file mode 100644 index 0000000000000000000000000000000000000000..0180a8e2ec479491948ed8a7b3e6cd304f74e168 --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "communicate/ps_server.h" +#include "communicate/ps_client.h" + +namespace paddle { +namespace custom_trainer { +namespace feed { + +class RuntimeEnvironment; +enum class EnvironmentRole; +class PSlib { +public: + PSlib() {} + virtual ~PSlib() {} + int initialize(const std::string& conf_path, + RuntimeEnvironment* environment, EnvironmentRole role); + + virtual paddle::ps::PSServer* ps_server(); + virtual paddle::ps::PSClient* ps_client(); + virtual paddle::PSParameter* get_param(); +private: + void init_gflag(); + virtual int init_server(EnvironmentRole role); + virtual int init_client(EnvironmentRole role); + + paddle::PSParameter _ps_param; + RuntimeEnvironment* _environment; + std::shared_ptr _server_ptr; + std::shared_ptr _client_ptr; +}; + +} // namespace feed +} // namespace custom_trainer +} // namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/common/registerer.cc b/paddle/fluid/train/custom_trainer/feed/common/registerer.cc index 04382b47eecf8437828e845137a0bf23d485c638..c2dff1517dfbe88634e16f2bd1068b0688d6113d 100644 --- a/paddle/fluid/train/custom_trainer/feed/common/registerer.cc +++ b/paddle/fluid/train/custom_trainer/feed/common/registerer.cc @@ -3,12 +3,12 @@ namespace paddle { namespace custom_trainer { namespace feed { -BaseClassMap& global_factory_map() { +BaseClassMap& global_reg_factory_map() { static BaseClassMap *base_class = new BaseClassMap(); return *base_class; } -BaseClassMap& global_factory_map_cpp() { - return global_factory_map(); +BaseClassMap& global_reg_factory_map_cpp() { + return global_reg_factory_map(); } }// feed diff --git a/paddle/fluid/train/custom_trainer/feed/common/registerer.h b/paddle/fluid/train/custom_trainer/feed/common/registerer.h index eb57cabea97398b620e94c03fc38975947fd60dd..b5399fdc9df1dafa87fea896a91f3855ff5605af 100644 --- a/paddle/fluid/train/custom_trainer/feed/common/registerer.h +++ b/paddle/fluid/train/custom_trainer/feed/common/registerer.h @@ -63,23 +63,23 @@ typedef std::map BaseClassMap; #ifdef __cplusplus extern "C" { #endif -BaseClassMap& global_factory_map(); +BaseClassMap& global_reg_factory_map(); #ifdef __cplusplus } #endif -BaseClassMap& global_factory_map_cpp(); +BaseClassMap& global_reg_factory_map_cpp(); -#define REGISTER_REGISTERER(base_class) \ +#define REGIST_REGISTERER(base_class) \ class base_class ## Registerer { \ public: \ static base_class *CreateInstanceByName(const ::std::string &name) { \ - if (global_factory_map_cpp().find(#base_class) \ - == global_factory_map_cpp().end()) { \ + if (global_reg_factory_map_cpp().find(#base_class) \ + == global_reg_factory_map_cpp().end()) { \ LOG(ERROR) << "Can't Find BaseClass For CreateClass with:" << #base_class; \ return NULL; \ } \ - FactoryMap &map = global_factory_map_cpp()[#base_class]; \ + FactoryMap &map = global_reg_factory_map_cpp()[#base_class]; \ FactoryMap::iterator iter = map.find(name); \ if (iter == map.end()) { \ LOG(ERROR) << "Can't Find Class For Create with:" << name; \ @@ -90,7 +90,7 @@ BaseClassMap& global_factory_map_cpp(); } \ }; -#define REGISTER_CLASS(clazz, name) \ +#define REGIST_CLASS(clazz, name) \ class ObjectFactory##name : public ObjectFactory { \ public: \ Any NewInstance() { \ @@ -98,14 +98,14 @@ BaseClassMap& global_factory_map_cpp(); } \ }; \ void register_factory_##name() { \ - FactoryMap &map = global_factory_map_cpp()[#clazz]; \ + FactoryMap &map = global_reg_factory_map_cpp()[#clazz]; \ if (map.find(#name) == map.end()) { \ map[#name] = new ObjectFactory##name(); \ } \ } \ void register_factory_##name() __attribute__((constructor)); -#define CREATE_CLASS(base_class, name) \ +#define CREATE_INSTANCE(base_class, name) \ base_class##Registerer::CreateInstanceByName(name) }//namespace feed diff --git a/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc b/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc index 83435853b580f24e01cc26b5d57c9b8258714a34..66f59b9cf5e0a97cdf95d4aeae1b701373290ce6 100644 --- a/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc +++ b/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc @@ -44,6 +44,11 @@ public: set_role(EnvironmentRole::ALL); return 0; } + + virtual paddle::ps::PSEnvironment* ps_environment() { + static paddle::ps::MpiPSEnvironment ps_environment; + return &ps_environment; + } virtual uint32_t rank_id(EnvironmentRole role) { return mpi_node_info(role).rank_id; @@ -95,7 +100,7 @@ protected: private: std::vector _roles_node_info; }; -REGISTER_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment); +REGIST_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment); //用于本地模式单机训练 class LocalRuntimeEnvironment : public RuntimeEnvironment { @@ -108,6 +113,10 @@ public: virtual int wireup() { return 0; } + virtual paddle::ps::PSEnvironment* ps_environment() { + static paddle::ps::LocalPSEnvironment ps_environment; + return &ps_environment; + } virtual uint32_t rank_id(EnvironmentRole role) { return 0; } @@ -129,7 +138,7 @@ protected: VLOG(static_cast(level)) << log_str; } }; -REGISTER_CLASS(RuntimeEnvironment, LocalRuntimeEnvironment); +REGIST_CLASS(RuntimeEnvironment, LocalRuntimeEnvironment); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h b/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h index d0277483a95c68dc790ad85527f87767218b02fb..5e567c787d7db0fde69fe9ecdfaef21de9061f74 100644 --- a/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h +++ b/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h @@ -6,6 +6,7 @@ */ #pragma once #include +#include "communicate/ps_env.h" #include "paddle/fluid/framework/archive.h" #include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/train/custom_trainer/feed/common/registerer.h" @@ -14,6 +15,8 @@ namespace paddle { namespace custom_trainer { namespace feed { +class paddle::ps::PSEnvironment; + enum class EnvironmentLogLevel { FATAL = 0, ERROR = 1, @@ -38,41 +41,43 @@ class RuntimeEnvironment { public: RuntimeEnvironment(); virtual ~RuntimeEnvironment(); - //配置初始化 + // 配置初始化 virtual int initialize(YAML::Node config) = 0; - //设置role + // 设置role virtual int set_role(EnvironmentRole role) = 0; - //环境初始化,会在所有依赖模块initialize后调用 + // 环境初始化,会在所有依赖模块initialize后调用 virtual int wireup() = 0; - //多线程可调用接口 Start - //当前环境rank_idx + // 多线程可调用接口 Start + // 当前环境rank_idx virtual uint32_t rank_id(EnvironmentRole role) = 0; - //运行环境节点数 + // 运行环境节点数 virtual uint32_t node_num(EnvironmentRole role) = 0; - //环境内主节点 + // 环境内主节点 virtual bool is_master_node(EnvironmentRole role); + //For PS + virtual paddle::ps::PSEnvironment* ps_environment() = 0; - //环境定制化log + // 环境定制化log template void log(EnvironmentRole role, EnvironmentLogType type, EnvironmentLogLevel level, const char* fmt, ARGS && ... args) { print_log(role, type, level, paddle::string::format_string(fmt, args...)); } - //多线程可调用接口 End + // 多线程可调用接口 End - //接口只允许在主线程调用 Start - //barrier 指定role的节点 + // 接口只允许在主线程调用 Start + // barrier 指定role的节点 virtual void barrier(EnvironmentRole role) = 0; - //bcast 广播 + // bcast 广播 virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) = 0; - //接口只允许在主线程调用 End + // 接口只允许在主线程调用 End protected: virtual void print_log(EnvironmentRole role, EnvironmentLogType type, EnvironmentLogLevel level, const std::string& log_str) = 0; }; -REGISTER_REGISTERER(RuntimeEnvironment); +REGIST_REGISTERER(RuntimeEnvironment); std::string format_timestamp(time_t time, const char* format); inline std::string format_timestamp(time_t time, const std::string& format) { diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc index 31ed1870e2ec8b8c11fc36eb55baa679a2286483..7e03677054395e63dc15090b7a64fd21cf93d48c 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc @@ -56,10 +56,10 @@ public: return 0; } }; -REGISTER_CLASS(DataParser, LineDataParser); +REGIST_CLASS(DataParser, LineDataParser); int DataReader::initialize(const YAML::Node& config, std::shared_ptr context) { - _parser.reset(CREATE_CLASS(DataParser, config["parser"]["class"].as())); + _parser.reset(CREATE_INSTANCE(DataParser, config["parser"]["class"].as())); if (_parser == nullptr) { VLOG(2) << "fail to get parser: " << config["parser"]["class"].as(); return -1; @@ -85,7 +85,7 @@ public: if (config["file_system"] && config["file_system"]["class"]) { _file_system.reset( - CREATE_CLASS(FileSystem, config["file_system"]["class"].as())); + CREATE_INSTANCE(FileSystem, config["file_system"]["class"].as())); if (_file_system == nullptr || _file_system->initialize(config["file_system"], context) != 0) { VLOG(2) << "fail to create class: " @@ -95,7 +95,7 @@ public: } else if (context->file_system != nullptr) { _file_system = context->file_system; } else { - _file_system.reset(CREATE_CLASS(FileSystem, "LocalFileSystem")); + _file_system.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) { VLOG(2) << "fail to init file system"; return -1; @@ -203,7 +203,7 @@ private: std::string _filename_prefix; std::shared_ptr _file_system; }; -REGISTER_CLASS(DataReader, LineDataReader); +REGIST_CLASS(DataReader, LineDataReader); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h index c24db38ae59312a2c9406c7c4ece98d56b91eea1..7b6d1c3f679d6e4295cbb8acd0d81a192f9293a4 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h @@ -54,7 +54,7 @@ public: virtual int parse(const char* str, DataItem& data) const = 0; virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0; }; -REGISTER_REGISTERER(DataParser); +REGIST_REGISTERER(DataParser); class DataReader { public: @@ -76,7 +76,7 @@ protected: std::shared_ptr _parser;//数据格式转换 std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入 }; -REGISTER_REGISTERER(DataReader); +REGIST_REGISTERER(DataReader); }//namespace feed }//namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc b/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc index 2f43e735ac525defa58139a6e0e0666c16eaa864..5fcfb9d616627511a16f315cae16be19a3b67ee6 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc +++ b/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc @@ -32,7 +32,7 @@ int DatasetContainer::initialize( _data_split_interval = config["data_spit_interval"].as(); _data_path_formater = config["data_path_formater"].as(); std::string data_reader_class = config["data_reader"].as(); - DataReader* data_reader = CREATE_CLASS(DataReader, data_reader_class); + DataReader* data_reader = CREATE_INSTANCE(DataReader, data_reader_class); _data_reader.reset(data_reader); return _data_reader->initialize(config, context); } diff --git a/paddle/fluid/train/custom_trainer/feed/executor/executor.cc b/paddle/fluid/train/custom_trainer/feed/executor/executor.cc index a7d5bdf2463894255f9d1e59c7ad1fef1349e1f2..73f8a601d10f160c764f76db94c9196624c05f30 100644 --- a/paddle/fluid/train/custom_trainer/feed/executor/executor.cc +++ b/paddle/fluid/train/custom_trainer/feed/executor/executor.cc @@ -121,7 +121,7 @@ protected: std::unique_ptr _context; }; -REGISTER_CLASS(Executor, SimpleExecutor); +REGIST_CLASS(Executor, SimpleExecutor); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/executor/executor.h b/paddle/fluid/train/custom_trainer/feed/executor/executor.h index f51f3f114390a6c9c9179b224706f71217761f76..3ae40817eb56ca8738fbec13b92d8be86f5b6094 100644 --- a/paddle/fluid/train/custom_trainer/feed/executor/executor.h +++ b/paddle/fluid/train/custom_trainer/feed/executor/executor.h @@ -40,7 +40,7 @@ public: protected: ::paddle::framework::Scope _scope; }; -REGISTER_REGISTERER(Executor); +REGIST_REGISTERER(Executor); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc b/paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc index 16bbfed5f4a801ff58dcf049f3161ff1a6afa107..a588b4d16cb3598799872bb3b8441b395295d307 100644 --- a/paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc +++ b/paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc @@ -31,7 +31,7 @@ public: _file_system.clear(); if (config && config["file_systems"] && config["file_systems"].Type() == YAML::NodeType::Map) { for (auto& prefix_fs: config["file_systems"]) { - std::unique_ptr fs(CREATE_CLASS(FileSystem, prefix_fs.second["class"].as(""))); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, prefix_fs.second["class"].as(""))); if (fs == nullptr) { VLOG(2) << "fail to create class: " << prefix_fs.second["class"].as(""); return -1; @@ -44,7 +44,7 @@ public: } } if (_file_system.find("default") == _file_system.end()) { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); if (fs == nullptr || fs->initialize(YAML::Load(""), context) != 0) { return -1; } @@ -122,7 +122,7 @@ public: private: std::unordered_map> _file_system; }; -REGISTER_CLASS(FileSystem, AutoFileSystem); +REGIST_CLASS(FileSystem, AutoFileSystem); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/io/file_system.h b/paddle/fluid/train/custom_trainer/feed/io/file_system.h index d7aa1cc2d67df2c4f76c7873335f1818eeacd736..4a157d697e0f0dcfa3459172eae17b8c042af55b 100644 --- a/paddle/fluid/train/custom_trainer/feed/io/file_system.h +++ b/paddle/fluid/train/custom_trainer/feed/io/file_system.h @@ -52,7 +52,7 @@ public: protected: int _err_no = 0; }; -REGISTER_REGISTERER(FileSystem); +REGIST_REGISTERER(FileSystem); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc b/paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc index d05affc5606b3a5eaf3a48adbd213d22efe2ac78..d61e1fe1be08bef601ef695256d610013d8dd457 100644 --- a/paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc +++ b/paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc @@ -203,7 +203,7 @@ private: std::string _hdfs_command; std::unordered_map _ugi; }; -REGISTER_CLASS(FileSystem, HadoopFileSystem); +REGIST_CLASS(FileSystem, HadoopFileSystem); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/io/local_file_system.cc b/paddle/fluid/train/custom_trainer/feed/io/local_file_system.cc index 287d3e0ae5545a9a6ec288b392a37ab4de586388..78cd8357c979ee58587e0261802236ab84fddce2 100644 --- a/paddle/fluid/train/custom_trainer/feed/io/local_file_system.cc +++ b/paddle/fluid/train/custom_trainer/feed/io/local_file_system.cc @@ -129,7 +129,7 @@ public: private: size_t _buffer_size = 0; }; -REGISTER_CLASS(FileSystem, LocalFileSystem); +REGIST_CLASS(FileSystem, LocalFileSystem); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/main.cc b/paddle/fluid/train/custom_trainer/feed/main.cc index ea3140c62d348fcbaaf34b650b6323c1a790a0d1..f14985215369f5a48ba433b07a83e0204ed7cd28 100644 --- a/paddle/fluid/train/custom_trainer/feed/main.cc +++ b/paddle/fluid/train/custom_trainer/feed/main.cc @@ -1,8 +1,8 @@ #include #include #include -#include "paddle/fluid/platform/init.h" #include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" +#include "paddle/fluid/platform/init.h" #include "paddle/fluid/train/custom_trainer/feed/process/process.h" #include "paddle/fluid/train/custom_trainer/feed/process/init_env_process.h" #include "paddle/fluid/framework/op_registry.h" @@ -21,27 +21,56 @@ int main(int argc, char* argv[]) { //load trainer config auto trainer_context_ptr = std::make_shared(); trainer_context_ptr->trainer_config = YAML::LoadFile(FLAGS_feed_trainer_conf_path); - + + //environment + auto& config = trainer_context_ptr->trainer_config; + std::string env_class = config["environment"]["environment_class"].as(); + trainer_context_ptr->environment.reset(CREATE_INSTANCE(RuntimeEnvironment, env_class)); + if (trainer_context_ptr->environment->initialize(config["environment"]) != 0) { + return -1; + } + EnvironmentRole role; + auto* environment = trainer_context_ptr->environment.get(); + environment->wireup(); + if (environment->rank_id(EnvironmentRole::ALL) % 2 == 0) { + role = EnvironmentRole::WORKER; + } else { + role = EnvironmentRole::PSERVER; + } + environment->set_role(role); + trainer_context_ptr->pslib.reset(new PSlib()); + std::string ps_config = config["environment"]["ps"].as(); + trainer_context_ptr->pslib->initialize(ps_config, environment, role); + //VLOG(3) << "Node Start With Role:" << role; + std::vector process_name_list = { "InitEnvProcess", "LearnerProcess" }; - - for (const auto& process_name : process_name_list) { - Process* process = CREATE_CLASS(Process, process_name); - if (process == NULL) { - VLOG(1) << "Process:" << process_name << " does not exist"; - return -1; + switch (role) { + case EnvironmentRole::WORKER: + for (const auto& process_name : process_name_list) { + Process* process = CREATE_INSTANCE(Process, process_name); + if (process == NULL) { + VLOG(1) << "Process:" << process_name << " does not exist"; + return -1; + } + if (process->initialize(trainer_context_ptr) != 0) { + VLOG(1) << "Process:" << process_name << " initialize failed"; + return -1; + } + trainer_context_ptr->process_list.push_back(std::shared_ptr(process)); + } + for (auto& process : trainer_context_ptr->process_list) { + process->run(); } - if (process->initialize(trainer_context_ptr) != 0) { - VLOG(1) << "Process:" << process_name << " initialize failed"; - return -1; + break; + case EnvironmentRole::PSERVER: + //wait server done + while (true) { + sleep(10000); } - trainer_context_ptr->process_list.push_back(std::shared_ptr(process)); - } - - for (auto& process : trainer_context_ptr->process_list) { - process->run(); + break; } return 0; diff --git a/paddle/fluid/train/custom_trainer/feed/monitor/monitor.h b/paddle/fluid/train/custom_trainer/feed/monitor/monitor.h index a7b9186c92fed694d67549d22d7cb197f285dfcb..0bffde6ced90562141309f0c794cf9a14abc4383 100644 --- a/paddle/fluid/train/custom_trainer/feed/monitor/monitor.h +++ b/paddle/fluid/train/custom_trainer/feed/monitor/monitor.h @@ -39,7 +39,7 @@ protected: std::string _name; }; -REGISTER_REGISTERER(Monitor); +REGIST_REGISTERER(Monitor); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc b/paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc index 5c6a8cf7bdb0410b6261802505d8d317fb912115..a45320a685fbeffc9c68b4a7e99ddf8a3a3bf123 100644 --- a/paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc @@ -20,22 +20,16 @@ int InitEnvProcess::initialize(std::shared_ptr context_ptr) { context_ptr->cpu_place = paddle::platform::CPUPlace(); YAML::Node config = _context_ptr->trainer_config; - //environment - std::string env_class = config["environment"]["environment_class"].as(); - context_ptr->environment.reset(CREATE_CLASS(RuntimeEnvironment, env_class)); - if (context_ptr->environment->initialize(config["environment"]) != 0) { - return -1; - } //file_system - context_ptr->file_system.reset(CREATE_CLASS(FileSystem, "AutoFileSystem")); + context_ptr->file_system.reset(CREATE_INSTANCE(FileSystem, "AutoFileSystem")); if (context_ptr->file_system->initialize(config["io"], context_ptr) != 0) { return -1; } //epoch std::string epoch_class = config["epoch"]["epoch_class"].as(); - context_ptr->epoch_accessor.reset(CREATE_CLASS(EpochAccessor, epoch_class)); + context_ptr->epoch_accessor.reset(CREATE_INSTANCE(EpochAccessor, epoch_class)); if (context_ptr->epoch_accessor->initialize(config["epoch"], context_ptr) != 0) { return -1; } @@ -55,10 +49,12 @@ int InitEnvProcess::run() { VLOG(3) << "Trainer Resume From epoch:" << epoch_accessor->current_epoch_id(); auto next_epoch_id = epoch_accessor->next_epoch_id(epoch_accessor->current_epoch_id()); _context_ptr->dataset->pre_detect_data(next_epoch_id); - //step 1. psserver init - //step2. psserver load - VLOG(3) << "Psserver Start Success"; - + + if (epoch_accessor->checkpoint_path().size() > 0) { + //Load Model + } else { + //Random Init Model + } //context_ptr->pslib_client()->load_model(); VLOG(3) << "Psserver Load Model Success"; return 0; diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc index 146e0277ae7656a7bb7de947406eedb64b2a8888..2b11f61b3fa15ceb1c1cba9727635cfa1004f0c4 100644 --- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc @@ -25,7 +25,7 @@ int LearnerProcess::initialize(std::shared_ptr context_ptr) { _threads_executor[i].resize(_executor_num); for (int e = 0; e < _executor_num; ++e) { auto e_class = config["executor"][e]["class"].as(); - auto* e_ptr = CREATE_CLASS(Executor, e_class); + auto* e_ptr = CREATE_INSTANCE(Executor, e_class); _threads_executor[i][e].reset(e_ptr); if (e_ptr->initialize(config["executor"][e], context_ptr) != 0) { ret = -1; @@ -84,53 +84,59 @@ int LearnerProcess::run() { while (true) { epoch_accessor->next_epoch(); + bool already_dump_inference_model = false; epoch_id = epoch_accessor->current_epoch_id(); std::string epoch_log_title= paddle::string::format_string( "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str()); //Step1. 等待样本ready - environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, - "Start %s, wait data ready", epoch_log_title.c_str()); - while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) { - sleep(30); - dataset->pre_detect_data(epoch_id); + { environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, - "%s, data not ready, wait 30s", epoch_log_title.c_str()); - } - environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, - "%s, data is ready, start traning", epoch_log_title.c_str()); - environment->barrier(EnvironmentRole::WORKER); - + "Start %s, wait data ready", epoch_log_title.c_str()); + while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) { + sleep(30); + dataset->pre_detect_data(epoch_id); + environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, + "%s, data not ready, wait 30s", epoch_log_title.c_str()); + } + environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, + "%s, data is ready, start traning", epoch_log_title.c_str()); + environment->barrier(EnvironmentRole::WORKER); + } + //Step2. 运行训练网络 - bool already_dump_inference_model = false; - for (int i = 0; i < _executor_num; ++i) { - std::vector> train_threads(_train_thread_num); - for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) { - train_threads[i].reset(new std::thread([this](int exe_idx, int thread_idx) { - auto* executor = _threads_executor[thread_idx][exe_idx].get(); - run_executor(executor); - }, i, thread_id)); - } - for (int i = 0; i < _train_thread_num; ++i) { - train_threads[i]->join(); + { + for (int i = 0; i < _executor_num; ++i) { + std::vector> train_threads(_train_thread_num); + for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) { + train_threads[i].reset(new std::thread([this](int exe_idx, int thread_idx) { + auto* executor = _threads_executor[thread_idx][exe_idx].get(); + run_executor(executor); + }, i, thread_id)); + } + for (int i = 0; i < _train_thread_num; ++i) { + train_threads[i]->join(); + } + environment->barrier(EnvironmentRole::WORKER); + + if (_threads_executor[0][i]->is_dump_all_model()) { + already_dump_inference_model = true; + wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta); + } + environment->barrier(EnvironmentRole::WORKER); } - environment->barrier(EnvironmentRole::WORKER); + } - if (_threads_executor[0][i]->is_dump_all_model()) { + //Step3. Dump Model For Delta&&Checkpoint + { + if (!already_dump_inference_model) { already_dump_inference_model = true; wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta); - } + } + wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint); environment->barrier(EnvironmentRole::WORKER); } - - //Step3. Dump Model For Delta&&Checkpoint - if (!already_dump_inference_model) { - already_dump_inference_model = true; - wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta); - } - wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint); - environment->barrier(EnvironmentRole::WORKER); - + //Step4. Output Monitor && RunStatus //TODO } diff --git a/paddle/fluid/train/custom_trainer/feed/process/process.cc b/paddle/fluid/train/custom_trainer/feed/process/process.cc index 5226c8c59228d89aaa5347e64402f3731f0ae102..0e1cd5fcbeb9bbca14a5822431347cd05a7f2dfd 100644 --- a/paddle/fluid/train/custom_trainer/feed/process/process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/process.cc @@ -5,8 +5,8 @@ namespace paddle { namespace custom_trainer { namespace feed { -REGISTER_CLASS(Process, InitEnvProcess); -REGISTER_CLASS(Process, LearnerProcess); +REGIST_CLASS(Process, InitEnvProcess); +REGIST_CLASS(Process, LearnerProcess); int Process::run() { return 0; } diff --git a/paddle/fluid/train/custom_trainer/feed/process/process.h b/paddle/fluid/train/custom_trainer/feed/process/process.h index 2e83e63cbb1e0907e6e62d227d27047dbd350444..127481e9371ca704c2b3cef991241c6d542b4be3 100644 --- a/paddle/fluid/train/custom_trainer/feed/process/process.h +++ b/paddle/fluid/train/custom_trainer/feed/process/process.h @@ -18,7 +18,7 @@ public: protected: TrainerContext* _context_ptr = NULL; }; -REGISTER_REGISTERER(Process); +REGIST_REGISTERER(Process); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/trainer_context.h b/paddle/fluid/train/custom_trainer/feed/trainer_context.h index 01b5e0f845a94ae5c9b8eeeb0e4b9c22b38a414a..212f32d370a5399eb3566b9ef1e741854c8de03d 100644 --- a/paddle/fluid/train/custom_trainer/feed/trainer_context.h +++ b/paddle/fluid/train/custom_trainer/feed/trainer_context.h @@ -4,6 +4,7 @@ #include #include #include "paddle/fluid/platform/place.h" +#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h" #include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h" @@ -38,6 +39,7 @@ public: YAML::Node trainer_config; paddle::platform::CPUPlace cpu_place; +std::shared_ptr pslib; std::shared_ptr dataset; //训练样本 std::shared_ptr file_system; //文件操作辅助类 std::vector params_table_list; //参数表 diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc index 5fed50cb61001c2ff2cb4b1ed16e6897aed897f8..0e183d93f7afb7daf59b4c46dbc3fa6659bf2952 100644 --- a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc +++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/program_desc.h" @@ -36,7 +37,7 @@ const char test_data_dir[] = "test_data"; class DataReaderTest : public testing::Test { public: static void SetUpTestCase() { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); fs->mkdir(test_data_dir); shell_set_verbose(true); @@ -56,14 +57,14 @@ public: } static void TearDownTestCase() { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); fs->remove(test_data_dir); } virtual void SetUp() { thread_num = omp_get_max_threads(); omp_set_num_threads(1); - fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem")); + fs.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); context_ptr.reset(new TrainerContext()); } @@ -79,7 +80,7 @@ public: }; TEST_F(DataReaderTest, LineDataParser) { - std::unique_ptr data_parser(CREATE_CLASS(DataParser, "LineDataParser")); + std::unique_ptr data_parser(CREATE_INSTANCE(DataParser, "LineDataParser")); ASSERT_NE(nullptr, data_parser); auto config = YAML::Load(""); @@ -108,7 +109,7 @@ TEST_F(DataReaderTest, LineDataParser) { } TEST_F(DataReaderTest, LineDataReader) { - std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); + std::unique_ptr data_reader(CREATE_INSTANCE(DataReader, "LineDataReader")); ASSERT_NE(nullptr, data_reader); auto config = YAML::Load( @@ -161,7 +162,7 @@ TEST_F(DataReaderTest, LineDataReader) { } TEST_F(DataReaderTest, LineDataReader_filename_prefix) { - std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); + std::unique_ptr data_reader(CREATE_INSTANCE(DataReader, "LineDataReader")); ASSERT_NE(nullptr, data_reader); auto config = YAML::Load( "parser:\n" @@ -196,7 +197,7 @@ TEST_F(DataReaderTest, LineDataReader_filename_prefix) { } TEST_F(DataReaderTest, LineDataReader_FileSystem) { - std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); + std::unique_ptr data_reader(CREATE_INSTANCE(DataReader, "LineDataReader")); ASSERT_NE(nullptr, data_reader); auto config = YAML::Load( "parser:\n" diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc index 59e6e2cfa21b2d8cc1f8d009703e2e7a1bb111b3..b181f7b73faa1a4ab0e03bd63b1ed9cb9584c438 100644 --- a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc +++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/program_desc.h" @@ -37,7 +38,7 @@ const char test_data_dir[] = "test_data"; class DataReaderOmpTest : public testing::Test { public: static void SetUpTestCase() { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); fs->mkdir(test_data_dir); shell_set_verbose(true); std_items.clear(); @@ -61,14 +62,14 @@ public: } static void TearDownTestCase() { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); fs->remove(test_data_dir); } virtual void SetUp() { thread_num = omp_get_max_threads(); omp_set_num_threads(1); - fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem")); + fs.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); context_ptr.reset(new TrainerContext()); } @@ -111,7 +112,7 @@ std::vector DataReaderOmpTest::std_items; std::vector DataReaderOmpTest::sorted_std_items; TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) { - std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); + std::unique_ptr data_reader(CREATE_INSTANCE(DataReader, "LineDataReader")); ASSERT_NE(nullptr, data_reader); auto config = YAML::Load( @@ -148,7 +149,7 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) { } TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) { - std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); + std::unique_ptr data_reader(CREATE_INSTANCE(DataReader, "LineDataReader")); ASSERT_NE(nullptr, data_reader); auto config = YAML::Load( diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc index 5866b34ee05cfdd0ef0b55e5cb96c8263ddb8c94..385d9f95cb9c34b52e7ff568d9f29ac49fa60f56 100644 --- a/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc +++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/program_desc.h" @@ -38,7 +39,7 @@ class SimpleExecutorTest : public testing::Test public: static void SetUpTestCase() { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); fs->mkdir(test_data_dir); shell_set_verbose(true); @@ -70,7 +71,7 @@ public: static void TearDownTestCase() { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); fs->remove(test_data_dir); } @@ -88,7 +89,7 @@ public: }; TEST_F(SimpleExecutorTest, initialize) { - std::unique_ptr executor(CREATE_CLASS(Executor, "SimpleExecutor")); + std::unique_ptr executor(CREATE_INSTANCE(Executor, "SimpleExecutor")); ASSERT_NE(nullptr, executor); YAML::Node config = YAML::Load("[1, 2, 3]"); ASSERT_NE(0, executor->initialize(config, context_ptr)); @@ -99,7 +100,7 @@ TEST_F(SimpleExecutorTest, initialize) { } TEST_F(SimpleExecutorTest, run) { - std::unique_ptr executor(CREATE_CLASS(Executor, "SimpleExecutor")); + std::unique_ptr executor(CREATE_INSTANCE(Executor, "SimpleExecutor")); ASSERT_NE(nullptr, executor); auto config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path));