From 24a80011425a30f29f86dbeffe153e84031aa0fe Mon Sep 17 00:00:00 2001
From: dongdaxiang
Date: Mon, 28 Jan 2019 17:46:39 +0800
Subject: [PATCH] make -DWITH_PSLIB=ON compilable

---
 paddle/fluid/framework/CMakeLists.txt       | 55 ++++++++++++---------
 paddle/fluid/framework/async_executor.cc    | 48 ++----------------
 paddle/fluid/framework/async_executor.h     |  7 +--
 paddle/fluid/framework/device_worker.cc     | 27 ++++++++++
 paddle/fluid/framework/fleet/CMakeLists.txt |  1 +
 paddle/fluid/framework/trainer.cc           | 25 ++++++++++
 6 files changed, 89 insertions(+), 74 deletions(-)
 create mode 100644 paddle/fluid/framework/device_worker.cc
 create mode 100644 paddle/fluid/framework/fleet/CMakeLists.txt
 create mode 100644 paddle/fluid/framework/trainer.cc

diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt
index 4d54754cec..11cf91f35a 100644
--- a/paddle/fluid/framework/CMakeLists.txt
+++ b/paddle/fluid/framework/CMakeLists.txt
@@ -1,3 +1,4 @@
+
 #windows treat symbolic file as a real file, which is different with unix
 #We create a hidden file and compile it instead of origin source file.
 function(windows_symbolic TARGET)
@@ -22,9 +23,11 @@ endfunction()
 
 add_subdirectory(ir)
 add_subdirectory(details)
+add_subdirectory(fleet)
 #ddim lib
 proto_library(framework_proto SRCS framework.proto)
 proto_library(async_executor_proto SRCS data_feed.proto)
+proto_library(trainer_desc_proto SRCS trainer_desc.proto)
 
 cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
 cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
@@ -129,9 +132,16 @@ cc_test(version_test SRCS version_test.cc DEPS version)
 cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version)
 
 cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc memory_optimize_helper)
+
+if(WITH_NGRAPH)
+  cc_library(ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto ngraph)
+  cc_library(ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
+          shape_inference data_transform lod_tensor profiler)
+endif(WITH_NGRAPH)
+
 nv_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
 
 py_proto_compile(framework_py_proto SRCS framework.proto data_feed.proto)
+py_proto_compile(trainer_py_proto SRCS trainer_desc.proto data_feed.proto)
 #Generate an empty \
 #__init__.py to make framework_py_proto as a valid python module.
 add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
@@ -172,7 +182,11 @@ if(WITH_DISTRIBUTE)
   set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
   set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
 else()
-  cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS})
+  if(WITH_NGRAPH)
+    cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator variable_helper)
+  else(WITH_NGRAPH)
+    cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass variable_helper)
+  endif(WITH_NGRAPH)
   cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
 endif()
@@ -184,9 +198,23 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
         fast_threaded_ssa_graph_executor variable_helper)
 
 if(WITH_PSLIB)
-  cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper pslib_brpc pslib timer)
+  cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
+             executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
+             trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
+             downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
+             DEPS op_registry device_context scope framework_proto
+             trainer_desc_proto glog lod_rank_table
+             feed_fetch_method graph_to_program_pass async_executor_proto
+             variable_helper pslib_brpc pslib timer)
 else()
-  cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper timer)
+  cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc
+             executor_thread_worker.cc multi_trainer.cc dist_multi_trainer.cc
+             trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
+             downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
+             DEPS op_registry device_context scope framework_proto
+             trainer_desc_proto glog lod_rank_table
+             feed_fetch_method graph_to_program_pass async_executor_proto
+             variable_helper timer)
 endif(WITH_PSLIB)
 
@@ -211,24 +239,3 @@ endif (NOT WIN32)
 
 cc_library(dlpack_tensor SRCS dlpack_tensor.cc DEPS tensor dlpack)
 cc_test(dlpack_tensor_test SRCS dlpack_tensor_test.cc DEPS dlpack_tensor glog)
-
-# Get the current working branch
-execute_process(
-  COMMAND git rev-parse --abbrev-ref HEAD
-  WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
-  OUTPUT_VARIABLE PADDLE_BRANCH
-  OUTPUT_STRIP_TRAILING_WHITESPACE
-)
-
-# Get the latest abbreviated commit hash of the working branch
-execute_process(
-  COMMAND git log -1 --format=%h
-  WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}
-  OUTPUT_VARIABLE PADDLE_COMMIT
-  OUTPUT_STRIP_TRAILING_WHITESPACE
-)
-
-message(STATUS "commit: ${PADDLE_COMMIT}")
-message(STATUS "branch: ${PADDLE_BRANCH}")
-
-configure_file(commit.h.in commit.h)
diff --git a/paddle/fluid/framework/async_executor.cc b/paddle/fluid/framework/async_executor.cc
index bfdb584833..b79df98b08 100644
--- a/paddle/fluid/framework/async_executor.cc
+++ b/paddle/fluid/framework/async_executor.cc
@@ -26,6 +26,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/reader.h"
+#include "paddle/fluid/framework/trainer_desc.pb.h"
 #include "paddle/fluid/inference/io.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/pybind/pybind.h"
@@ -56,52 +57,9 @@ void AsyncExecutor::GatherServers(const std::vector<uint64_t>& host_sign_list,
   fleet_ptr_->GatherServers(host_sign_list, node_num);
 }
 
-void AsyncExecutor::InitModel() {
-  for (auto table_id : _param_config.dense_table_id) {
-    std::vector<paddle::ps::Region> regions;
-    for (auto& t : _param_config.dense_variable_name[table_id]) {
-      Variable* var = root_scope_->FindVar(t);
-      CHECK(var != nullptr) << "var[" << t << "] not found";
-      LoDTensor* tensor = var->GetMutable<LoDTensor>();
+void AsyncExecutor::InitModel() {}
 
-      float* g = tensor->data<float>();
-      CHECK(g != nullptr) << "var[" << t << "] value not initialized";
-
-      float init_range = 0.2;
-      int rown = tensor->dims()[0];
-      init_range /= sqrt(rown);
-
-      std::normal_distribution<float> ndistr(0.0, 1.0);
-      for (auto i = 0u; i < tensor->numel(); ++i) {
-        g[i] = ndistr(local_random_engine()) * init_range;
-      }
-
-      paddle::ps::Region reg(g, tensor->numel());
-      regions.emplace_back(std::move(reg));
-    }
-
-    auto push_status = _pslib_ptr->_worker_ptr->push_dense_param(
-        regions.data(), regions.size(), table_id);
-    push_status.wait();
-    auto status = push_status.get();
-    if (status != 0) {
-      LOG(FATAL) << "push dense param failed, status[" << status << "]";
-      exit(-1);
-    }
-  }
-}
-
-void AsyncExecutor::SaveModel(const std::string& path) {
-  auto ret = _pslib_ptr->_worker_ptr->flush();
-  ret.wait();
-  ret = _pslib_ptr->_worker_ptr->save(path, 0);
-  ret.wait();
-  int32_t feasign_cnt = ret.get();
-  if (feasign_cnt == -1) {  // (colourful-tree) TODO should be feasign_cnt < 0
-    LOG(FATAL) << "save model failed";
-    exit(-1);
-  }
-}
+void AsyncExecutor::SaveModel(const std::string& path) {}
 
 void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
diff --git a/paddle/fluid/framework/async_executor.h b/paddle/fluid/framework/async_executor.h
index f05106b61f..4623672279 100644
--- a/paddle/fluid/framework/async_executor.h
+++ b/paddle/fluid/framework/async_executor.h
@@ -27,6 +27,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_feed.pb.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/executor_thread_worker.h"
+#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/scope.h"
@@ -62,11 +63,7 @@ class AsyncExecutor {
   AsyncExecutor(Scope* scope, const platform::Place& place);
   virtual ~AsyncExecutor() {}
   void RunFromFile(const ProgramDesc& main_program,
-                   const std::string& data_feed_desc_str,
-                   const std::vector<std::string>& filelist,
-                   const int thread_num,
-                   const std::vector<std::string>& fetch_names,
-                   const std::string& mode, const bool debug = false);
+                   const std::string& trainer_desc_str, const bool debug);
 
   void InitServer(const std::string& dist_desc, int index);
   void InitWorker(const std::string& dist_desc,
diff --git a/paddle/fluid/framework/device_worker.cc b/paddle/fluid/framework/device_worker.cc
new file mode 100644
index 0000000000..443acf0a16
--- /dev/null
+++ b/paddle/fluid/framework/device_worker.cc
@@ -0,0 +1,27 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/device_worker.h"
+
+namespace paddle {
+namespace framework {
+
+void DeviceWorker::SetRootScope(Scope* root_scope) { root_scope_ = root_scope; }
+
+void DeviceWorker::SetDataFeed(const std::shared_ptr<DataFeed>& data_feed) {
+  device_reader_ = data_feed;
+}
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/fleet/CMakeLists.txt b/paddle/fluid/framework/fleet/CMakeLists.txt
new file mode 100644
index 0000000000..1457ac5d7f
--- /dev/null
+++ b/paddle/fluid/framework/fleet/CMakeLists.txt
@@ -0,0 +1 @@
+cc_library(fleet_wrapper SRCS fleet_wrapper.cc)
diff --git a/paddle/fluid/framework/trainer.cc b/paddle/fluid/framework/trainer.cc
new file mode 100644
index 0000000000..d3bdceffff
--- /dev/null
+++ b/paddle/fluid/framework/trainer.cc
@@ -0,0 +1,25 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/trainer.h"
+
+namespace paddle {
+namespace framework {
+
+void TrainerBase::SetScope(Scope* root_scope) { root_scope_ = root_scope; }
+
+void TrainerBase::Initialize(const TrainerDesc& trainer_desc) { return; }
+
+}  // end namespace framework
+}  // end namespace paddle
-- 
GitLab
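
Usage note (not part of the patch): the refactor narrows RunFromFile() from a
(data_feed_desc, filelist, thread_num, fetch_names, mode) argument list down to
a single serialized TrainerDesc string plus a debug flag. Below is a minimal
caller sketch. The TrainerDesc field names (class_name, thread_num) are
assumptions about trainer_desc.proto, whose contents this patch does not show;
treat the sketch as illustrative rather than as the confirmed API.

#include <string>

#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/platform/place.h"

void RunWithTrainerDesc(paddle::framework::Scope* scope,
                        const paddle::framework::ProgramDesc& main_program) {
  // Construct the executor on CPU; the scope holds the model parameters.
  paddle::framework::AsyncExecutor exe(scope, paddle::platform::CPUPlace());

  // Describe the run in a TrainerDesc proto instead of passing
  // filelist/thread_num/fetch_names individually.
  paddle::framework::TrainerDesc desc;
  desc.set_class_name("MultiTrainer");  // assumed field: trainer to instantiate
  desc.set_thread_num(4);               // assumed field: worker thread count

  std::string desc_str;
  desc.SerializeToString(&desc_str);

  // New entry point: one serialized TrainerDesc plus a debug flag.
  exe.RunFromFile(main_program, desc_str, /*debug=*/false);
}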