From f6bd8cfd3c79fa582decf396e059c128c5e2c88a Mon Sep 17 00:00:00 2001 From: xiexionghang Date: Fri, 16 Aug 2019 21:34:57 +0800 Subject: [PATCH] for mpi trainer --- BCLOUD | 1 + .../feed/accessor/epoch_accessor.cc | 2 +- .../feed/accessor/epoch_accessor.h | 2 +- .../feed/accessor/input_data_accessor.h | 31 +++++++ .../feed/accessor/sparse_input_accessor.cc | 51 ++++++++++++ .../feed/common/pslib_warpper.cc | 80 +++++++++++++++++++ .../feed/common/pslib_warpper.h | 49 ++++++++++++ .../custom_trainer/feed/common/registerer.cc | 6 +- .../custom_trainer/feed/common/registerer.h | 18 ++--- .../feed/common/runtime_environment.cc | 13 ++- .../feed/common/runtime_environment.h | 33 ++++---- .../feed/dataset/data_reader.cc | 10 +-- .../custom_trainer/feed/dataset/data_reader.h | 4 +- .../feed/dataset/dataset_container.cc | 2 +- .../custom_trainer/feed/executor/executor.cc | 2 +- .../custom_trainer/feed/executor/executor.h | 2 +- .../feed/io/auto_file_system.cc | 6 +- .../custom_trainer/feed/io/file_system.h | 2 +- .../feed/io/hadoop_file_system.cc | 2 +- .../feed/io/local_file_system.cc | 2 +- .../fluid/train/custom_trainer/feed/main.cc | 61 ++++++++++---- .../custom_trainer/feed/monitor/monitor.h | 2 +- .../feed/process/init_env_process.cc | 20 ++--- .../feed/process/learner_process.cc | 76 ++++++++++-------- .../custom_trainer/feed/process/process.cc | 4 +- .../custom_trainer/feed/process/process.h | 2 +- .../custom_trainer/feed/trainer_context.h | 2 + .../feed/unit_test/test_datareader.cc | 15 ++-- .../feed/unit_test/test_datareader_omp.cc | 11 +-- .../feed/unit_test/test_executor.cc | 9 ++- 30 files changed, 391 insertions(+), 129 deletions(-) create mode 100644 paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h create mode 100644 paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc create mode 100644 paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc create mode 100644 paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h diff --git a/BCLOUD b/BCLOUD index 151cbed3..e797fdf1 100644 --- a/BCLOUD +++ b/BCLOUD @@ -36,6 +36,7 @@ CONFIGS('baidu/third-party/pybind11@v2.2.4@git_branch') CONFIGS('baidu/third-party/python@gcc482output@git_branch') CONFIGS('baidu/third-party/yaml-cpp@yaml-cpp_0-6-2-0_GEN_PD_BL@git_tag') CONFIGS('baidu/third-party/openmpi@openmpi_1-4-5-0-feed_mlarch@git_branch') +CONFIGS('baidu/paddlepaddle/pslib@master@git_branch') CONFIGS('third-64/gtest@base') HEADERS('paddle/fluid/memory/*.h', '$INC/paddle/fluid/memory/') diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc index 7e474eef..36e0b9fd 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc +++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.cc @@ -88,7 +88,7 @@ namespace feed { return ""; } - REGISTER_CLASS(EpochAccessor, HourlyEpochAccessor); + REGIST_CLASS(EpochAccessor, HourlyEpochAccessor); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h index 247a1fe3..069abff3 100644 --- a/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h +++ b/paddle/fluid/train/custom_trainer/feed/accessor/epoch_accessor.h @@ -60,7 +60,7 @@ protected: std::vector _done_status; //当前完成状态,统一存成string }; -REGISTER_REGISTERER(EpochAccessor); +REGIST_REGISTERER(EpochAccessor); class HourlyEpochAccessor : public EpochAccessor { public: diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h b/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h new file mode 100644 index 00000000..a91a5f46 --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h @@ -0,0 +1,31 @@ +#pragma once +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h" +#include "paddle/fluid/train/custom_trainer/feed/accessor/accessor.h" + +namespace paddle { +namespace custom_trainer { +namespace feed { + +class DataInputAccessor : public Accessor { +public: + DataInputAccessor() {} + virtual ~DataInputAccessor() {} + + virtual int initialize(const YAML::Node& config, + std::shared_ptr context_ptr); + + // 前向, 一般用于填充输入,在训练网络执行前调用 + virtual int32_t forward(const SampleInstance* samples, + ::paddle::framework::Scope* scope, size_t table_id, size_t num) = 0; + + // 后向,一般用于更新梯度,在训练网络执行后调用 + virtual int32_t backward(const SampleInstance* samples, + ::paddle::framework::Scope* scope, size_t table_id, size_t num) = 0; +protected: +}; +REGIST_REGISTERER(DataInputAccessor); + +} // namespace feed +} // namespace custom_trainer +} // namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc b/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc new file mode 100644 index 00000000..63b896b8 --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/accessor/sparse_input_accessor.cc @@ -0,0 +1,51 @@ +#include +#include +#include "paddle/fluid/string/string_helper.h" +#include "paddle/fluid/train/custom_trainer/feed/accessor/input_data_accessor.h" + +namespace paddle { +namespace custom_trainer { +namespace feed { + +class CommonSparseInputAccessor : public DataInputAccessor { +public: + CommonSparseInputAccessor() {} + virtual ~CommonSparseInputAccessor() {} + virtual int initialize(const YAML::Node& config, + std::shared_ptr context_ptr) { + CHECK(config["sparse_input"] && config["sparse_input"].Type() == YAML::NodeType::Map); + for (auto& input : config["sparse_input"]) { + std::pair> sparse_slots; + sparse_slots.first = input.first.as(); + std::string slots_str = input.second["slots"].as(); + std::vector slots = paddle::string::split_string(slots_str, ","); + for (int i = 0; i < slots.size(); ++i) { + sparse_slots.second.push_back((uint16_t)atoi(slots[i].c_str())); + } + } + return 0; + } + + // 取sparse数据 + virtual int32_t forward(const SampleInstance* samples, + ::paddle::framework::Scope* scope, size_t table_id, size_t num) { + // pull + return 0; + } + + // 更新spare数据 + virtual int32_t backward(const SampleInstance* samples, + ::paddle::framework::Scope* scope, size_t table_id, size_t num) { + return 0; + } + +protected: + + // 输入层列表 + // + std::vector > > _x_variables; +}; + +} // namespace feed +} // namespace custom_trainer +} // namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc b/paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc new file mode 100644 index 00000000..7f8ee5e3 --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.cc @@ -0,0 +1,80 @@ +#include +#include +#include +#include "json2pb/json_to_pb.h" +#include +#include +#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h" +#include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h" + +namespace paddle { +namespace custom_trainer { +namespace feed { + +int PSlib::initialize(const std::string& conf_path, + RuntimeEnvironment* environment, EnvironmentRole role) { + init_gflag(); + int file_descriptor = open(conf_path.c_str(), O_RDONLY); + if (file_descriptor == -1){ + LOG(ERROR) << "FATAL: cant open " << conf_path; + return -1; + } + google::protobuf::io::FileInputStream fileInput(file_descriptor); + if (!google::protobuf::TextFormat::Parse(&fileInput, &_ps_param)) { + LOG(ERROR) << "FATAL: fail to parse " << conf_path; + return -1; + } + close(file_descriptor); + init_server(role); + init_client(EnvironmentRole::ALL); + return 0; +} + +int PSlib::init_server(EnvironmentRole role) { + if (role == EnvironmentRole::PSERVER) { + _server_ptr.reset(paddle::ps::PSServerFactory::create(_ps_param)); + _server_ptr->configure(_ps_param, *(_environment->ps_environment()), + _environment->rank_id(role)); + _server_ptr->start(); + } + _environment->ps_environment()->gather_ps_servers(); + return 0; +} + +int PSlib::init_client(EnvironmentRole role) { + _client_ptr.reset(paddle::ps::PSClientFactory::create(_ps_param)); + _client_ptr->configure(_ps_param, *(_environment->ps_environment()), + _environment->rank_id(role)); + return 0; +} + +paddle::ps::PSServer* PSlib::ps_server() { + return _server_ptr.get(); +} + +paddle::ps::PSClient* PSlib::ps_client() { + return _client_ptr.get(); +} + +paddle::PSParameter* PSlib::get_param() { + return &_ps_param; +} + +void PSlib::init_gflag() { + int cnt = 4; + std::shared_ptr params(new char*[cnt]); + char** params_ptr = params.get(); + char p0[] = "exe default"; + char p1[] = "-max_body_size=314217728"; + char p2[] = "-bthread_concurrency=40"; + char p3[] = "-socket_max_unwritten_bytes=2048000000"; + params_ptr[0] = p0; + params_ptr[1] = p1; + params_ptr[2] = p2; + params_ptr[3] = p3; + ::google::ParseCommandLineFlags(&cnt, ¶ms_ptr, true); +} + +} // namespace feed +} // namespace custom_trainer +} // namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h b/paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h new file mode 100644 index 00000000..0180a8e2 --- /dev/null +++ b/paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h @@ -0,0 +1,49 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "communicate/ps_server.h" +#include "communicate/ps_client.h" + +namespace paddle { +namespace custom_trainer { +namespace feed { + +class RuntimeEnvironment; +enum class EnvironmentRole; +class PSlib { +public: + PSlib() {} + virtual ~PSlib() {} + int initialize(const std::string& conf_path, + RuntimeEnvironment* environment, EnvironmentRole role); + + virtual paddle::ps::PSServer* ps_server(); + virtual paddle::ps::PSClient* ps_client(); + virtual paddle::PSParameter* get_param(); +private: + void init_gflag(); + virtual int init_server(EnvironmentRole role); + virtual int init_client(EnvironmentRole role); + + paddle::PSParameter _ps_param; + RuntimeEnvironment* _environment; + std::shared_ptr _server_ptr; + std::shared_ptr _client_ptr; +}; + +} // namespace feed +} // namespace custom_trainer +} // namespace paddle diff --git a/paddle/fluid/train/custom_trainer/feed/common/registerer.cc b/paddle/fluid/train/custom_trainer/feed/common/registerer.cc index 04382b47..c2dff151 100644 --- a/paddle/fluid/train/custom_trainer/feed/common/registerer.cc +++ b/paddle/fluid/train/custom_trainer/feed/common/registerer.cc @@ -3,12 +3,12 @@ namespace paddle { namespace custom_trainer { namespace feed { -BaseClassMap& global_factory_map() { +BaseClassMap& global_reg_factory_map() { static BaseClassMap *base_class = new BaseClassMap(); return *base_class; } -BaseClassMap& global_factory_map_cpp() { - return global_factory_map(); +BaseClassMap& global_reg_factory_map_cpp() { + return global_reg_factory_map(); } }// feed diff --git a/paddle/fluid/train/custom_trainer/feed/common/registerer.h b/paddle/fluid/train/custom_trainer/feed/common/registerer.h index eb57cabe..b5399fdc 100644 --- a/paddle/fluid/train/custom_trainer/feed/common/registerer.h +++ b/paddle/fluid/train/custom_trainer/feed/common/registerer.h @@ -63,23 +63,23 @@ typedef std::map BaseClassMap; #ifdef __cplusplus extern "C" { #endif -BaseClassMap& global_factory_map(); +BaseClassMap& global_reg_factory_map(); #ifdef __cplusplus } #endif -BaseClassMap& global_factory_map_cpp(); +BaseClassMap& global_reg_factory_map_cpp(); -#define REGISTER_REGISTERER(base_class) \ +#define REGIST_REGISTERER(base_class) \ class base_class ## Registerer { \ public: \ static base_class *CreateInstanceByName(const ::std::string &name) { \ - if (global_factory_map_cpp().find(#base_class) \ - == global_factory_map_cpp().end()) { \ + if (global_reg_factory_map_cpp().find(#base_class) \ + == global_reg_factory_map_cpp().end()) { \ LOG(ERROR) << "Can't Find BaseClass For CreateClass with:" << #base_class; \ return NULL; \ } \ - FactoryMap &map = global_factory_map_cpp()[#base_class]; \ + FactoryMap &map = global_reg_factory_map_cpp()[#base_class]; \ FactoryMap::iterator iter = map.find(name); \ if (iter == map.end()) { \ LOG(ERROR) << "Can't Find Class For Create with:" << name; \ @@ -90,7 +90,7 @@ BaseClassMap& global_factory_map_cpp(); } \ }; -#define REGISTER_CLASS(clazz, name) \ +#define REGIST_CLASS(clazz, name) \ class ObjectFactory##name : public ObjectFactory { \ public: \ Any NewInstance() { \ @@ -98,14 +98,14 @@ BaseClassMap& global_factory_map_cpp(); } \ }; \ void register_factory_##name() { \ - FactoryMap &map = global_factory_map_cpp()[#clazz]; \ + FactoryMap &map = global_reg_factory_map_cpp()[#clazz]; \ if (map.find(#name) == map.end()) { \ map[#name] = new ObjectFactory##name(); \ } \ } \ void register_factory_##name() __attribute__((constructor)); -#define CREATE_CLASS(base_class, name) \ +#define CREATE_INSTANCE(base_class, name) \ base_class##Registerer::CreateInstanceByName(name) }//namespace feed diff --git a/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc b/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc index 83435853..66f59b9c 100644 --- a/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc +++ b/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.cc @@ -44,6 +44,11 @@ public: set_role(EnvironmentRole::ALL); return 0; } + + virtual paddle::ps::PSEnvironment* ps_environment() { + static paddle::ps::MpiPSEnvironment ps_environment; + return &ps_environment; + } virtual uint32_t rank_id(EnvironmentRole role) { return mpi_node_info(role).rank_id; @@ -95,7 +100,7 @@ protected: private: std::vector _roles_node_info; }; -REGISTER_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment); +REGIST_CLASS(RuntimeEnvironment, MPIRuntimeEnvironment); //用于本地模式单机训练 class LocalRuntimeEnvironment : public RuntimeEnvironment { @@ -108,6 +113,10 @@ public: virtual int wireup() { return 0; } + virtual paddle::ps::PSEnvironment* ps_environment() { + static paddle::ps::LocalPSEnvironment ps_environment; + return &ps_environment; + } virtual uint32_t rank_id(EnvironmentRole role) { return 0; } @@ -129,7 +138,7 @@ protected: VLOG(static_cast(level)) << log_str; } }; -REGISTER_CLASS(RuntimeEnvironment, LocalRuntimeEnvironment); +REGIST_CLASS(RuntimeEnvironment, LocalRuntimeEnvironment); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h b/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h index d0277483..5e567c78 100644 --- a/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h +++ b/paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h @@ -6,6 +6,7 @@ */ #pragma once #include +#include "communicate/ps_env.h" #include "paddle/fluid/framework/archive.h" #include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/train/custom_trainer/feed/common/registerer.h" @@ -14,6 +15,8 @@ namespace paddle { namespace custom_trainer { namespace feed { +class paddle::ps::PSEnvironment; + enum class EnvironmentLogLevel { FATAL = 0, ERROR = 1, @@ -38,41 +41,43 @@ class RuntimeEnvironment { public: RuntimeEnvironment(); virtual ~RuntimeEnvironment(); - //配置初始化 + // 配置初始化 virtual int initialize(YAML::Node config) = 0; - //设置role + // 设置role virtual int set_role(EnvironmentRole role) = 0; - //环境初始化,会在所有依赖模块initialize后调用 + // 环境初始化,会在所有依赖模块initialize后调用 virtual int wireup() = 0; - //多线程可调用接口 Start - //当前环境rank_idx + // 多线程可调用接口 Start + // 当前环境rank_idx virtual uint32_t rank_id(EnvironmentRole role) = 0; - //运行环境节点数 + // 运行环境节点数 virtual uint32_t node_num(EnvironmentRole role) = 0; - //环境内主节点 + // 环境内主节点 virtual bool is_master_node(EnvironmentRole role); + //For PS + virtual paddle::ps::PSEnvironment* ps_environment() = 0; - //环境定制化log + // 环境定制化log template void log(EnvironmentRole role, EnvironmentLogType type, EnvironmentLogLevel level, const char* fmt, ARGS && ... args) { print_log(role, type, level, paddle::string::format_string(fmt, args...)); } - //多线程可调用接口 End + // 多线程可调用接口 End - //接口只允许在主线程调用 Start - //barrier 指定role的节点 + // 接口只允许在主线程调用 Start + // barrier 指定role的节点 virtual void barrier(EnvironmentRole role) = 0; - //bcast 广播 + // bcast 广播 virtual void bcast(paddle::framework::BinaryArchive& ar, int root_id, EnvironmentRole role) = 0; - //接口只允许在主线程调用 End + // 接口只允许在主线程调用 End protected: virtual void print_log(EnvironmentRole role, EnvironmentLogType type, EnvironmentLogLevel level, const std::string& log_str) = 0; }; -REGISTER_REGISTERER(RuntimeEnvironment); +REGIST_REGISTERER(RuntimeEnvironment); std::string format_timestamp(time_t time, const char* format); inline std::string format_timestamp(time_t time, const std::string& format) { diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc index 31ed1870..7e036770 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.cc @@ -56,10 +56,10 @@ public: return 0; } }; -REGISTER_CLASS(DataParser, LineDataParser); +REGIST_CLASS(DataParser, LineDataParser); int DataReader::initialize(const YAML::Node& config, std::shared_ptr context) { - _parser.reset(CREATE_CLASS(DataParser, config["parser"]["class"].as())); + _parser.reset(CREATE_INSTANCE(DataParser, config["parser"]["class"].as())); if (_parser == nullptr) { VLOG(2) << "fail to get parser: " << config["parser"]["class"].as(); return -1; @@ -85,7 +85,7 @@ public: if (config["file_system"] && config["file_system"]["class"]) { _file_system.reset( - CREATE_CLASS(FileSystem, config["file_system"]["class"].as())); + CREATE_INSTANCE(FileSystem, config["file_system"]["class"].as())); if (_file_system == nullptr || _file_system->initialize(config["file_system"], context) != 0) { VLOG(2) << "fail to create class: " @@ -95,7 +95,7 @@ public: } else if (context->file_system != nullptr) { _file_system = context->file_system; } else { - _file_system.reset(CREATE_CLASS(FileSystem, "LocalFileSystem")); + _file_system.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); if (_file_system == nullptr || _file_system->initialize(YAML::Load(""), context) != 0) { VLOG(2) << "fail to init file system"; return -1; @@ -203,7 +203,7 @@ private: std::string _filename_prefix; std::shared_ptr _file_system; }; -REGISTER_CLASS(DataReader, LineDataReader); +REGIST_CLASS(DataReader, LineDataReader); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h index c24db38a..7b6d1c3f 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h +++ b/paddle/fluid/train/custom_trainer/feed/dataset/data_reader.h @@ -54,7 +54,7 @@ public: virtual int parse(const char* str, DataItem& data) const = 0; virtual int parse_to_sample(const DataItem& data, SampleInstance& instance) const = 0; }; -REGISTER_REGISTERER(DataParser); +REGIST_REGISTERER(DataParser); class DataReader { public: @@ -76,7 +76,7 @@ protected: std::shared_ptr _parser;//数据格式转换 std::string _pipeline_cmd; //将文件流,重定向到pipeline_cmd,再读入 }; -REGISTER_REGISTERER(DataReader); +REGIST_REGISTERER(DataReader); }//namespace feed }//namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc b/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc index 2f43e735..5fcfb9d6 100644 --- a/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc +++ b/paddle/fluid/train/custom_trainer/feed/dataset/dataset_container.cc @@ -32,7 +32,7 @@ int DatasetContainer::initialize( _data_split_interval = config["data_spit_interval"].as(); _data_path_formater = config["data_path_formater"].as(); std::string data_reader_class = config["data_reader"].as(); - DataReader* data_reader = CREATE_CLASS(DataReader, data_reader_class); + DataReader* data_reader = CREATE_INSTANCE(DataReader, data_reader_class); _data_reader.reset(data_reader); return _data_reader->initialize(config, context); } diff --git a/paddle/fluid/train/custom_trainer/feed/executor/executor.cc b/paddle/fluid/train/custom_trainer/feed/executor/executor.cc index a7d5bdf2..73f8a601 100644 --- a/paddle/fluid/train/custom_trainer/feed/executor/executor.cc +++ b/paddle/fluid/train/custom_trainer/feed/executor/executor.cc @@ -121,7 +121,7 @@ protected: std::unique_ptr _context; }; -REGISTER_CLASS(Executor, SimpleExecutor); +REGIST_CLASS(Executor, SimpleExecutor); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/executor/executor.h b/paddle/fluid/train/custom_trainer/feed/executor/executor.h index f51f3f11..3ae40817 100644 --- a/paddle/fluid/train/custom_trainer/feed/executor/executor.h +++ b/paddle/fluid/train/custom_trainer/feed/executor/executor.h @@ -40,7 +40,7 @@ public: protected: ::paddle::framework::Scope _scope; }; -REGISTER_REGISTERER(Executor); +REGIST_REGISTERER(Executor); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc b/paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc index 16bbfed5..a588b4d1 100644 --- a/paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc +++ b/paddle/fluid/train/custom_trainer/feed/io/auto_file_system.cc @@ -31,7 +31,7 @@ public: _file_system.clear(); if (config && config["file_systems"] && config["file_systems"].Type() == YAML::NodeType::Map) { for (auto& prefix_fs: config["file_systems"]) { - std::unique_ptr fs(CREATE_CLASS(FileSystem, prefix_fs.second["class"].as(""))); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, prefix_fs.second["class"].as(""))); if (fs == nullptr) { VLOG(2) << "fail to create class: " << prefix_fs.second["class"].as(""); return -1; @@ -44,7 +44,7 @@ public: } } if (_file_system.find("default") == _file_system.end()) { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); if (fs == nullptr || fs->initialize(YAML::Load(""), context) != 0) { return -1; } @@ -122,7 +122,7 @@ public: private: std::unordered_map> _file_system; }; -REGISTER_CLASS(FileSystem, AutoFileSystem); +REGIST_CLASS(FileSystem, AutoFileSystem); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/io/file_system.h b/paddle/fluid/train/custom_trainer/feed/io/file_system.h index d7aa1cc2..4a157d69 100644 --- a/paddle/fluid/train/custom_trainer/feed/io/file_system.h +++ b/paddle/fluid/train/custom_trainer/feed/io/file_system.h @@ -52,7 +52,7 @@ public: protected: int _err_no = 0; }; -REGISTER_REGISTERER(FileSystem); +REGIST_REGISTERER(FileSystem); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc b/paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc index d05affc5..d61e1fe1 100644 --- a/paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc +++ b/paddle/fluid/train/custom_trainer/feed/io/hadoop_file_system.cc @@ -203,7 +203,7 @@ private: std::string _hdfs_command; std::unordered_map _ugi; }; -REGISTER_CLASS(FileSystem, HadoopFileSystem); +REGIST_CLASS(FileSystem, HadoopFileSystem); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/io/local_file_system.cc b/paddle/fluid/train/custom_trainer/feed/io/local_file_system.cc index 287d3e0a..78cd8357 100644 --- a/paddle/fluid/train/custom_trainer/feed/io/local_file_system.cc +++ b/paddle/fluid/train/custom_trainer/feed/io/local_file_system.cc @@ -129,7 +129,7 @@ public: private: size_t _buffer_size = 0; }; -REGISTER_CLASS(FileSystem, LocalFileSystem); +REGIST_CLASS(FileSystem, LocalFileSystem); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/main.cc b/paddle/fluid/train/custom_trainer/feed/main.cc index ea3140c6..f1498521 100644 --- a/paddle/fluid/train/custom_trainer/feed/main.cc +++ b/paddle/fluid/train/custom_trainer/feed/main.cc @@ -1,8 +1,8 @@ #include #include #include -#include "paddle/fluid/platform/init.h" #include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" +#include "paddle/fluid/platform/init.h" #include "paddle/fluid/train/custom_trainer/feed/process/process.h" #include "paddle/fluid/train/custom_trainer/feed/process/init_env_process.h" #include "paddle/fluid/framework/op_registry.h" @@ -21,27 +21,56 @@ int main(int argc, char* argv[]) { //load trainer config auto trainer_context_ptr = std::make_shared(); trainer_context_ptr->trainer_config = YAML::LoadFile(FLAGS_feed_trainer_conf_path); - + + //environment + auto& config = trainer_context_ptr->trainer_config; + std::string env_class = config["environment"]["environment_class"].as(); + trainer_context_ptr->environment.reset(CREATE_INSTANCE(RuntimeEnvironment, env_class)); + if (trainer_context_ptr->environment->initialize(config["environment"]) != 0) { + return -1; + } + EnvironmentRole role; + auto* environment = trainer_context_ptr->environment.get(); + environment->wireup(); + if (environment->rank_id(EnvironmentRole::ALL) % 2 == 0) { + role = EnvironmentRole::WORKER; + } else { + role = EnvironmentRole::PSERVER; + } + environment->set_role(role); + trainer_context_ptr->pslib.reset(new PSlib()); + std::string ps_config = config["environment"]["ps"].as(); + trainer_context_ptr->pslib->initialize(ps_config, environment, role); + //VLOG(3) << "Node Start With Role:" << role; + std::vector process_name_list = { "InitEnvProcess", "LearnerProcess" }; - - for (const auto& process_name : process_name_list) { - Process* process = CREATE_CLASS(Process, process_name); - if (process == NULL) { - VLOG(1) << "Process:" << process_name << " does not exist"; - return -1; + switch (role) { + case EnvironmentRole::WORKER: + for (const auto& process_name : process_name_list) { + Process* process = CREATE_INSTANCE(Process, process_name); + if (process == NULL) { + VLOG(1) << "Process:" << process_name << " does not exist"; + return -1; + } + if (process->initialize(trainer_context_ptr) != 0) { + VLOG(1) << "Process:" << process_name << " initialize failed"; + return -1; + } + trainer_context_ptr->process_list.push_back(std::shared_ptr(process)); + } + for (auto& process : trainer_context_ptr->process_list) { + process->run(); } - if (process->initialize(trainer_context_ptr) != 0) { - VLOG(1) << "Process:" << process_name << " initialize failed"; - return -1; + break; + case EnvironmentRole::PSERVER: + //wait server done + while (true) { + sleep(10000); } - trainer_context_ptr->process_list.push_back(std::shared_ptr(process)); - } - - for (auto& process : trainer_context_ptr->process_list) { - process->run(); + break; } return 0; diff --git a/paddle/fluid/train/custom_trainer/feed/monitor/monitor.h b/paddle/fluid/train/custom_trainer/feed/monitor/monitor.h index a7b9186c..0bffde6c 100644 --- a/paddle/fluid/train/custom_trainer/feed/monitor/monitor.h +++ b/paddle/fluid/train/custom_trainer/feed/monitor/monitor.h @@ -39,7 +39,7 @@ protected: std::string _name; }; -REGISTER_REGISTERER(Monitor); +REGIST_REGISTERER(Monitor); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc b/paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc index 5c6a8cf7..a45320a6 100644 --- a/paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/init_env_process.cc @@ -20,22 +20,16 @@ int InitEnvProcess::initialize(std::shared_ptr context_ptr) { context_ptr->cpu_place = paddle::platform::CPUPlace(); YAML::Node config = _context_ptr->trainer_config; - //environment - std::string env_class = config["environment"]["environment_class"].as(); - context_ptr->environment.reset(CREATE_CLASS(RuntimeEnvironment, env_class)); - if (context_ptr->environment->initialize(config["environment"]) != 0) { - return -1; - } //file_system - context_ptr->file_system.reset(CREATE_CLASS(FileSystem, "AutoFileSystem")); + context_ptr->file_system.reset(CREATE_INSTANCE(FileSystem, "AutoFileSystem")); if (context_ptr->file_system->initialize(config["io"], context_ptr) != 0) { return -1; } //epoch std::string epoch_class = config["epoch"]["epoch_class"].as(); - context_ptr->epoch_accessor.reset(CREATE_CLASS(EpochAccessor, epoch_class)); + context_ptr->epoch_accessor.reset(CREATE_INSTANCE(EpochAccessor, epoch_class)); if (context_ptr->epoch_accessor->initialize(config["epoch"], context_ptr) != 0) { return -1; } @@ -55,10 +49,12 @@ int InitEnvProcess::run() { VLOG(3) << "Trainer Resume From epoch:" << epoch_accessor->current_epoch_id(); auto next_epoch_id = epoch_accessor->next_epoch_id(epoch_accessor->current_epoch_id()); _context_ptr->dataset->pre_detect_data(next_epoch_id); - //step 1. psserver init - //step2. psserver load - VLOG(3) << "Psserver Start Success"; - + + if (epoch_accessor->checkpoint_path().size() > 0) { + //Load Model + } else { + //Random Init Model + } //context_ptr->pslib_client()->load_model(); VLOG(3) << "Psserver Load Model Success"; return 0; diff --git a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc index 146e0277..2b11f61b 100644 --- a/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/learner_process.cc @@ -25,7 +25,7 @@ int LearnerProcess::initialize(std::shared_ptr context_ptr) { _threads_executor[i].resize(_executor_num); for (int e = 0; e < _executor_num; ++e) { auto e_class = config["executor"][e]["class"].as(); - auto* e_ptr = CREATE_CLASS(Executor, e_class); + auto* e_ptr = CREATE_INSTANCE(Executor, e_class); _threads_executor[i][e].reset(e_ptr); if (e_ptr->initialize(config["executor"][e], context_ptr) != 0) { ret = -1; @@ -84,53 +84,59 @@ int LearnerProcess::run() { while (true) { epoch_accessor->next_epoch(); + bool already_dump_inference_model = false; epoch_id = epoch_accessor->current_epoch_id(); std::string epoch_log_title= paddle::string::format_string( "train epoch_id:%d label:%s", epoch_id, epoch_accessor->text(epoch_id).c_str()); //Step1. 等待样本ready - environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, - "Start %s, wait data ready", epoch_log_title.c_str()); - while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) { - sleep(30); - dataset->pre_detect_data(epoch_id); + { environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, - "%s, data not ready, wait 30s", epoch_log_title.c_str()); - } - environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, - "%s, data is ready, start traning", epoch_log_title.c_str()); - environment->barrier(EnvironmentRole::WORKER); - + "Start %s, wait data ready", epoch_log_title.c_str()); + while (dataset->epoch_data_status(epoch_id) != DatasetStatus::Ready) { + sleep(30); + dataset->pre_detect_data(epoch_id); + environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, + "%s, data not ready, wait 30s", epoch_log_title.c_str()); + } + environment->log(EnvironmentRole::WORKER, EnvironmentLogType::MASTER_LOG, EnvironmentLogLevel::NOTICE, + "%s, data is ready, start traning", epoch_log_title.c_str()); + environment->barrier(EnvironmentRole::WORKER); + } + //Step2. 运行训练网络 - bool already_dump_inference_model = false; - for (int i = 0; i < _executor_num; ++i) { - std::vector> train_threads(_train_thread_num); - for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) { - train_threads[i].reset(new std::thread([this](int exe_idx, int thread_idx) { - auto* executor = _threads_executor[thread_idx][exe_idx].get(); - run_executor(executor); - }, i, thread_id)); - } - for (int i = 0; i < _train_thread_num; ++i) { - train_threads[i]->join(); + { + for (int i = 0; i < _executor_num; ++i) { + std::vector> train_threads(_train_thread_num); + for (int thread_id = 0; thread_id < _train_thread_num; ++thread_id) { + train_threads[i].reset(new std::thread([this](int exe_idx, int thread_idx) { + auto* executor = _threads_executor[thread_idx][exe_idx].get(); + run_executor(executor); + }, i, thread_id)); + } + for (int i = 0; i < _train_thread_num; ++i) { + train_threads[i]->join(); + } + environment->barrier(EnvironmentRole::WORKER); + + if (_threads_executor[0][i]->is_dump_all_model()) { + already_dump_inference_model = true; + wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta); + } + environment->barrier(EnvironmentRole::WORKER); } - environment->barrier(EnvironmentRole::WORKER); + } - if (_threads_executor[0][i]->is_dump_all_model()) { + //Step3. Dump Model For Delta&&Checkpoint + { + if (!already_dump_inference_model) { already_dump_inference_model = true; wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta); - } + } + wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint); environment->barrier(EnvironmentRole::WORKER); } - - //Step3. Dump Model For Delta&&Checkpoint - if (!already_dump_inference_model) { - already_dump_inference_model = true; - wait_save_model(epoch_id, ModelSaveWay::ModelSaveInferenceDelta); - } - wait_save_model(epoch_id, ModelSaveWay::ModelSaveTrainCheckpoint); - environment->barrier(EnvironmentRole::WORKER); - + //Step4. Output Monitor && RunStatus //TODO } diff --git a/paddle/fluid/train/custom_trainer/feed/process/process.cc b/paddle/fluid/train/custom_trainer/feed/process/process.cc index 5226c8c5..0e1cd5fc 100644 --- a/paddle/fluid/train/custom_trainer/feed/process/process.cc +++ b/paddle/fluid/train/custom_trainer/feed/process/process.cc @@ -5,8 +5,8 @@ namespace paddle { namespace custom_trainer { namespace feed { -REGISTER_CLASS(Process, InitEnvProcess); -REGISTER_CLASS(Process, LearnerProcess); +REGIST_CLASS(Process, InitEnvProcess); +REGIST_CLASS(Process, LearnerProcess); int Process::run() { return 0; } diff --git a/paddle/fluid/train/custom_trainer/feed/process/process.h b/paddle/fluid/train/custom_trainer/feed/process/process.h index 2e83e63c..127481e9 100644 --- a/paddle/fluid/train/custom_trainer/feed/process/process.h +++ b/paddle/fluid/train/custom_trainer/feed/process/process.h @@ -18,7 +18,7 @@ public: protected: TrainerContext* _context_ptr = NULL; }; -REGISTER_REGISTERER(Process); +REGIST_REGISTERER(Process); } // namespace feed } // namespace custom_trainer diff --git a/paddle/fluid/train/custom_trainer/feed/trainer_context.h b/paddle/fluid/train/custom_trainer/feed/trainer_context.h index 01b5e0f8..212f32d3 100644 --- a/paddle/fluid/train/custom_trainer/feed/trainer_context.h +++ b/paddle/fluid/train/custom_trainer/feed/trainer_context.h @@ -4,6 +4,7 @@ #include #include #include "paddle/fluid/platform/place.h" +#include "paddle/fluid/train/custom_trainer/feed/common/pslib_warpper.h" #include "paddle/fluid/train/custom_trainer/feed/common/runtime_environment.h" @@ -38,6 +39,7 @@ public: YAML::Node trainer_config; paddle::platform::CPUPlace cpu_place; +std::shared_ptr pslib; std::shared_ptr dataset; //训练样本 std::shared_ptr file_system; //文件操作辅助类 std::vector params_table_list; //参数表 diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc index 5fed50cb..0e183d93 100644 --- a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc +++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/program_desc.h" @@ -36,7 +37,7 @@ const char test_data_dir[] = "test_data"; class DataReaderTest : public testing::Test { public: static void SetUpTestCase() { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); fs->mkdir(test_data_dir); shell_set_verbose(true); @@ -56,14 +57,14 @@ public: } static void TearDownTestCase() { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); fs->remove(test_data_dir); } virtual void SetUp() { thread_num = omp_get_max_threads(); omp_set_num_threads(1); - fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem")); + fs.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); context_ptr.reset(new TrainerContext()); } @@ -79,7 +80,7 @@ public: }; TEST_F(DataReaderTest, LineDataParser) { - std::unique_ptr data_parser(CREATE_CLASS(DataParser, "LineDataParser")); + std::unique_ptr data_parser(CREATE_INSTANCE(DataParser, "LineDataParser")); ASSERT_NE(nullptr, data_parser); auto config = YAML::Load(""); @@ -108,7 +109,7 @@ TEST_F(DataReaderTest, LineDataParser) { } TEST_F(DataReaderTest, LineDataReader) { - std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); + std::unique_ptr data_reader(CREATE_INSTANCE(DataReader, "LineDataReader")); ASSERT_NE(nullptr, data_reader); auto config = YAML::Load( @@ -161,7 +162,7 @@ TEST_F(DataReaderTest, LineDataReader) { } TEST_F(DataReaderTest, LineDataReader_filename_prefix) { - std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); + std::unique_ptr data_reader(CREATE_INSTANCE(DataReader, "LineDataReader")); ASSERT_NE(nullptr, data_reader); auto config = YAML::Load( "parser:\n" @@ -196,7 +197,7 @@ TEST_F(DataReaderTest, LineDataReader_filename_prefix) { } TEST_F(DataReaderTest, LineDataReader_FileSystem) { - std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); + std::unique_ptr data_reader(CREATE_INSTANCE(DataReader, "LineDataReader")); ASSERT_NE(nullptr, data_reader); auto config = YAML::Load( "parser:\n" diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc index 59e6e2cf..b181f7b7 100644 --- a/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc +++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_datareader_omp.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/program_desc.h" @@ -37,7 +38,7 @@ const char test_data_dir[] = "test_data"; class DataReaderOmpTest : public testing::Test { public: static void SetUpTestCase() { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); fs->mkdir(test_data_dir); shell_set_verbose(true); std_items.clear(); @@ -61,14 +62,14 @@ public: } static void TearDownTestCase() { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); fs->remove(test_data_dir); } virtual void SetUp() { thread_num = omp_get_max_threads(); omp_set_num_threads(1); - fs.reset(CREATE_CLASS(FileSystem, "LocalFileSystem")); + fs.reset(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); context_ptr.reset(new TrainerContext()); } @@ -111,7 +112,7 @@ std::vector DataReaderOmpTest::std_items; std::vector DataReaderOmpTest::sorted_std_items; TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) { - std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); + std::unique_ptr data_reader(CREATE_INSTANCE(DataReader, "LineDataReader")); ASSERT_NE(nullptr, data_reader); auto config = YAML::Load( @@ -148,7 +149,7 @@ TEST_F(DataReaderOmpTest, LineDataReaderSingleThread) { } TEST_F(DataReaderOmpTest, LineDataReaderMuiltThread) { - std::unique_ptr data_reader(CREATE_CLASS(DataReader, "LineDataReader")); + std::unique_ptr data_reader(CREATE_INSTANCE(DataReader, "LineDataReader")); ASSERT_NE(nullptr, data_reader); auto config = YAML::Load( diff --git a/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc b/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc index 5866b34e..385d9f95 100644 --- a/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc +++ b/paddle/fluid/train/custom_trainer/feed/unit_test/test_executor.cc @@ -16,6 +16,7 @@ limitations under the License. */ #include #include +#include "paddle/fluid/train/custom_trainer/feed/trainer_context.h" #include "paddle/fluid/train/custom_trainer/feed/executor/executor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/program_desc.h" @@ -38,7 +39,7 @@ class SimpleExecutorTest : public testing::Test public: static void SetUpTestCase() { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); fs->mkdir(test_data_dir); shell_set_verbose(true); @@ -70,7 +71,7 @@ public: static void TearDownTestCase() { - std::unique_ptr fs(CREATE_CLASS(FileSystem, "LocalFileSystem")); + std::unique_ptr fs(CREATE_INSTANCE(FileSystem, "LocalFileSystem")); fs->remove(test_data_dir); } @@ -88,7 +89,7 @@ public: }; TEST_F(SimpleExecutorTest, initialize) { - std::unique_ptr executor(CREATE_CLASS(Executor, "SimpleExecutor")); + std::unique_ptr executor(CREATE_INSTANCE(Executor, "SimpleExecutor")); ASSERT_NE(nullptr, executor); YAML::Node config = YAML::Load("[1, 2, 3]"); ASSERT_NE(0, executor->initialize(config, context_ptr)); @@ -99,7 +100,7 @@ TEST_F(SimpleExecutorTest, initialize) { } TEST_F(SimpleExecutorTest, run) { - std::unique_ptr executor(CREATE_CLASS(Executor, "SimpleExecutor")); + std::unique_ptr executor(CREATE_INSTANCE(Executor, "SimpleExecutor")); ASSERT_NE(nullptr, executor); auto config = YAML::Load(string::format_string("{thread_num: 2, startup_program: %s, main_program: %s}", startup_program_path, main_program_path)); -- GitLab