diff --git a/paddle/fluid/framework/fleet/CMakeLists.txt b/paddle/fluid/framework/fleet/CMakeLists.txt
index 7d363d1afdc8ac72741e6e4fea02fb96fe9347fa..12fc454fd262cdcf30f64757a6199c6a9331e1a2 100644
--- a/paddle/fluid/framework/fleet/CMakeLists.txt
+++ b/paddle/fluid/framework/fleet/CMakeLists.txt
@@ -3,3 +3,5 @@ if(WITH_PSLIB)
 else()
   cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope)
 endif(WITH_PSLIB)
+
+cc_library(nccl_wrapper SRCS nccl_wrapper.cc DEPS framework_proto variable_helper scope)
diff --git a/paddle/fluid/framework/fleet/nccl_wrapper.cc b/paddle/fluid/framework/fleet/nccl_wrapper.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9a39da6dcae1e788d696c92f863e36f41e59ba4b
--- /dev/null
+++ b/paddle/fluid/framework/fleet/nccl_wrapper.cc
@@ -0,0 +1,66 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/fleet/nccl_wrapper.h"
+#include <utility>
+#include "paddle/fluid/framework/data_feed.h"
+#include "paddle/fluid/framework/scope.h"
+
+namespace paddle {
+namespace framework {
+
+std::shared_ptr<NCCLWrapper> NCCLWrapper::s_instance_ = NULL;
+bool NCCLWrapper::is_initialized_ = false;
+
+void NCCLWrapper::InitNCCL() {
+  platform::dynload::ncclCommInitRank(
+      &(nccl_info_.comm_), nccl_info_.global_ranks_, nccl_info_.nccl_id_,
+      nccl_info_.my_global_rank_);
+  return;
+}
+
+void NCCLWrapper::SetNCCLId(const NCCLInfo& nccl_info) {
+  nccl_info_.nccl_id_ = nccl_info.nccl_id_;
+}
+
+NCCLInfo NCCLWrapper::GetNCCLId() {
+  PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(&(nccl_info_.nccl_id_)));
+  return nccl_info_;
+}
+
+void NCCLWrapper::SetRankInfo(const int local_rank, const int global_rank,
+                              const int ranks) {
+  nccl_info_.local_rank_ = local_rank;
+  nccl_info_.my_global_rank_ = global_rank;
+  nccl_info_.global_ranks_ = ranks;
+  PADDLE_ENFORCE(cudaSetDevice(local_rank));
+  PADDLE_ENFORCE(cudaStreamCreate(&(nccl_info_.stream_)));
+  return;
+}
+
+void NCCLWrapper::SyncVar(const int root_rank, const Scope& scope,
+                          const std::vector<std::string>& var_names) {
+  for (auto& name : var_names) {
+    auto var = scope.FindVar(name);
+    LoDTensor* tensor = var->GetMutable<LoDTensor>();
+    int32_t total_size = tensor->numel();
+    platform::dynload::ncclBcast(
+        reinterpret_cast<void*>(tensor->data<float>()), total_size, ncclFloat,
+        root_rank, nccl_info_.comm_, nccl_info_.stream_);
+    cudaStreamSynchronize(nccl_info_.stream_);
+  }
+}
+
+}  // end namespace framework
+}  // end namespace paddle
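Read together, the methods above imply a fixed call order: `SetRankInfo` must run before `InitNCCL`, because it selects the CUDA device and creates the stream that `SyncVar` later reuses, and every rank must hold the same `ncclUniqueId` before `InitNCCL` is called. A minimal sketch of that flow (not part of this patch), assuming a hypothetical out-of-band channel to move the id from rank 0 to its peers and a hypothetical scope variable named `fc_0.w_0`:

```cpp
#include "paddle/fluid/framework/fleet/nccl_wrapper.h"
#include "paddle/fluid/framework/scope.h"

void RunRank(int rank, int nranks, paddle::framework::Scope* scope) {
  auto nccl = paddle::framework::NCCLWrapper::GetInstance();
  if (rank == 0) {
    paddle::framework::NCCLInfo info = nccl->GetNCCLId();
    // SendToPeers(info);  // hypothetical out-of-band exchange
  } else {
    // nccl->SetNCCLId(RecvFromRoot());  // hypothetical receive side
  }
  // Must precede InitNCCL: binds the device and creates nccl_info_.stream_.
  nccl->SetRankInfo(/*local_rank=*/rank, /*global_rank=*/rank,
                    /*ranks=*/nranks);
  nccl->InitNCCL();
  // Broadcast rank 0's copy so all ranks start from identical weights;
  // note SyncVar assumes the named variables are float LoDTensors.
  nccl->SyncVar(/*root_rank=*/0, *scope, {"fc_0.w_0"});
}
```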
diff --git a/paddle/fluid/framework/fleet/nccl_wrapper.h b/paddle/fluid/framework/fleet/nccl_wrapper.h
index 9eeb4ce6251719716b8f6f00e435a0818f440460..5d73cba892d4f0358141c2b882a461f45d313f2c 100644
--- a/paddle/fluid/framework/fleet/nccl_wrapper.h
+++ b/paddle/fluid/framework/fleet/nccl_wrapper.h
@@ -29,114 +29,49 @@ limitations under the License. */
 
 namespace paddle {
 namespace framework {
 
+class NCCLInfo {
+ public:
+  NCCLInfo() {}
+  virtual ~NCCLInfo() {}
+
+ public:
+  int local_rank_;
+  int global_ranks_;
+  int my_global_rank_;
+  ncclUniqueId nccl_id_;
+  ncclComm_t comm_;
+  cudaStream_t stream_;
+};
+
 class NCCLWrapper {
  public:
   virtual ~NCCLWrapper() {}
   NCCLWrapper() {}
 
-  // Pull sparse variables from server in Sync mode
-  // Param: scope, table_id, var_names, fea_keys
-  // Param: fea_values
-  void PullSparseVarsSync(const Scope& scope, const uint64_t table_id,
-                          const std::vector<std::string>& var_names,
-                          std::vector<uint64_t>* fea_keys,
-                          std::vector<std::vector<float>>* fea_values,
-                          int fea_dim);
-
-  void PullDenseVarsSync(const Scope& scope, const uint64_t table_id,
-                         const std::vector<std::string>& var_names);
-
-  void PullDenseVarsAsync(
-      const Scope& scope, const uint64_t table_id,
-      const std::vector<std::string>& var_names,
-      std::vector<::std::future<int32_t>>* pull_dense_status);
-
-  void PushDenseParamSync(const Scope& scope, const uint64_t table_id,
-                          const std::vector<std::string>& var_names);
+  void InitNCCL();
+  void SetNCCLId(const NCCLInfo& nccl_info);
+  NCCLInfo GetNCCLId();
+  void SetRankInfo(const int local_rank, const int global_rank,
+                   const int ranks);
+  void SyncVar(const int root_rank, const Scope& scope,
+               const std::vector<std::string>& var_names);
 
-  // Push dense variables to server in async mode
-  // Param: scope, table_id, var_names,
-  // Param: push_sparse_status
-  void PushDenseVarsAsync(
-      const Scope& scope, const uint64_t table_id,
-      const std::vector<std::string>& var_names,
-      std::vector<::std::future<int32_t>>* push_sparse_status);
-
-  void PushDenseVarsSync(Scope* scope, const uint64_t table_id,
-                         const std::vector<std::string>& var_names);
-
-  // Push sparse variables with labels to server in Async mode
-  // This is specially designed for click/show stats in server
-  // Param: scope, table_id, var_grad_names,
-  //        fea_keys, fea_labels, sparse_grad_names
-  // Param: push_values, push_sparse_status
-  void PushSparseVarsWithLabelAsync(
-      const Scope& scope, const uint64_t table_id,
-      const std::vector<uint64_t>& fea_keys,
-      const std::vector<float>& fea_labels,
-      const std::vector<std::string>& sparse_key_names,
-      const std::vector<std::string>& sparse_grad_names, const int emb_dim,
-      std::vector<std::vector<float>>* push_values,
-      std::vector<::std::future<int32_t>>* push_sparse_status);
-
-  // Push sparse variables to server in Async mode
-  // Param: scope, table_id, fea_keys, sparse_grad_names
-  // Param: push_values, push_sparse_status
-  /*
-  void PushSparseVarsAsync(
-          const Scope& scope,
-          const uint64_t table_id,
-          const std::vector<uint64_t>& fea_keys,
-          const std::vector<std::string>& sparse_grad_names,
-          std::vector<std::vector<float>>* push_values,
-          std::vector<::std::future<int32_t>>* push_sparse_status);
-  */
-
-  void InitServer(const std::string& dist_desc, int index);
-  void InitWorker(const std::string& dist_desc,
-                  const std::vector<uint64_t>& host_sign_list, int node_num,
-                  int index);
-  void StopServer();
-  uint64_t RunServer();
-  void GatherServers(const std::vector<uint64_t>& host_sign_list, int node_num);
-  // gather client ip
-  void GatherClients(const std::vector<uint64_t>& host_sign_list);
-  // get client info
-  std::vector<uint64_t> GetClientsInfo();
-  // create client to client connection
-  void CreateClient2ClientConnection();
-
-  // register client to client communication
-  typedef std::function<int32_t(int, int, const std::string&)> MsgHandlerFunc;
-  int RegisterClientToClientMsgHandler(int msg_type, MsgHandlerFunc handler);
-  // send client to client message
-  std::future<int32_t> SendClientToClientMsg(int msg_type, int to_client_id,
-                                             const std::string& msg);
-
-  template <typename T>
-  void Serialize(const std::vector<T*>& t, std::string* str);
-  template <typename T>
-  void Deserialize(std::vector<T>* t, const std::string& str);
-  static std::shared_ptr<FleetWrapper> GetInstance() {
+  static std::shared_ptr<NCCLWrapper> GetInstance() {
     if (NULL == s_instance_) {
-      s_instance_.reset(new paddle::framework::FleetWrapper());
+      s_instance_.reset(new paddle::framework::NCCLWrapper());
     }
     return s_instance_;
   }
 
-#ifdef PADDLE_WITH_PSLIB
-  static std::shared_ptr<paddle::distributed::PSlib> pslib_ptr_;
-#endif
+ public:
+  NCCLInfo nccl_info_;
 
  private:
-  static std::shared_ptr<FleetWrapper> s_instance_;
-#ifdef PADDLE_WITH_PSLIB
-  std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;
-#endif
+  static std::shared_ptr<NCCLWrapper> s_instance_;
 
  protected:
   static bool is_initialized_;
-  DISABLE_COPY_AND_ASSIGN(FleetWrapper);
+  DISABLE_COPY_AND_ASSIGN(NCCLWrapper);
 };
 
 }  // end namespace framework
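Note that `SetNCCLId` copies only `nccl_id_` out of the incoming `NCCLInfo`, so the id is the one field that must cross process boundaries before `InitNCCL`. When it travels over a socket rather than inside this struct, its raw bytes can be packed directly. A sketch (not part of this patch) under the assumption of a standard `nccl.h`, where `ncclUniqueId` is an opaque buffer of `NCCL_UNIQUE_ID_BYTES` chars:

```cpp
#include <nccl.h>
#include <cstring>
#include <string>

// Serialize the opaque id for transport, e.g. over an RPC layer.
std::string PackNCCLId(const ncclUniqueId& id) {
  return std::string(id.internal, NCCL_UNIQUE_ID_BYTES);
}

// Rebuild the id on the receiving rank before calling SetNCCLId.
ncclUniqueId UnpackNCCLId(const std::string& buf) {
  ncclUniqueId id;
  std::memcpy(id.internal, buf.data(), NCCL_UNIQUE_ID_BYTES);
  return id;
}
```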
diff --git a/paddle/fluid/pybind/nccl_wrapper_py.cc b/paddle/fluid/pybind/nccl_wrapper_py.cc
new file mode 100644
index 0000000000000000000000000000000000000000..aec4f4463540df294694aa7a6b1d702867815eac
--- /dev/null
+++ b/paddle/fluid/pybind/nccl_wrapper_py.cc
@@ -0,0 +1,53 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include <fcntl.h>
+
+#ifdef _POSIX_C_SOURCE
+#undef _POSIX_C_SOURCE
+#endif
+
+#ifdef _XOPEN_SOURCE
+#undef _XOPEN_SOURCE
+#endif
+
+#include <string>
+#include <vector>
+
+#include "google/protobuf/io/zero_copy_stream_impl.h"
+#include "google/protobuf/text_format.h"
+#include "paddle/fluid/framework/async_executor.h"
+#include "paddle/fluid/framework/data_feed.h"
+#include "paddle/fluid/framework/data_feed.pb.h"
+#include "paddle/fluid/framework/fleet/nccl_wrapper.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/inference/io.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/fluid/platform/variant.h"
+#include "paddle/fluid/pybind/nccl_wrapper_py.h"
+
+namespace py = pybind11;
+namespace pd = paddle::framework;
+
+namespace paddle {
+namespace pybind {
+void BindNCCLWrapper(py::module* m) {
+  py::class_<framework::NCCLWrapper>(*m, "Nccl")
+      .def(py::init())
+      .def("init_nccl", &framework::NCCLWrapper::InitNCCL)
+      .def("set_nccl_id", &framework::NCCLWrapper::SetNCCLId)
+      .def("set_rank_info", &framework::NCCLWrapper::SetRankInfo)
+      .def("sync_var", &framework::NCCLWrapper::SyncVar);
+}  // end BindNCCLWrapper
+}  // end namespace pybind
+}  // end namespace paddle
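`BindNCCLWrapper` follows the same binder convention as the existing `BindFleetWrapper`: it takes the module pointer and registers one class on it, and `pybind.cc` (below) calls it during the `core` module's initialization. For illustration only, the same function could back a standalone extension; a sketch with a hypothetical module name, assuming the `nccl_wrapper` library is linked in:

```cpp
#include "paddle/fluid/pybind/nccl_wrapper_py.h"

// Hypothetical standalone module exposing only the Nccl class; the real
// patch registers the binder inside the existing core module instead.
PYBIND11_MODULE(nccl_demo, m) {
  paddle::pybind::BindNCCLWrapper(&m);
}
```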
diff --git a/paddle/fluid/pybind/nccl_wrapper_py.h b/paddle/fluid/pybind/nccl_wrapper_py.h
new file mode 100644
index 0000000000000000000000000000000000000000..683eb4d61e00abf4e7192efb1d102ff73cb9e02e
--- /dev/null
+++ b/paddle/fluid/pybind/nccl_wrapper_py.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+
+namespace py = pybind11;
+
+namespace paddle {
+namespace pybind {
+
+void BindNCCLWrapper(py::module* m);
+
+}  // namespace pybind
+}  // namespace paddle
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index a8a2a94d473b18fdcd78771063ef4565c7fe0e42..6a5f5f60bca1974635730ce746869b95cf4e80ed 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -58,6 +58,7 @@ limitations under the License. */
 #include "paddle/fluid/pybind/imperative.h"
 #include "paddle/fluid/pybind/inference_api.h"
 #include "paddle/fluid/pybind/ir.h"
+#include "paddle/fluid/pybind/nccl_wrapper_py.h"
 #include "paddle/fluid/pybind/protobuf.h"
 #include "paddle/fluid/pybind/pybind.h"  // NOLINT
 #include "paddle/fluid/pybind/reader_py.h"
@@ -1405,6 +1406,7 @@ All parameter, weight, gradient are variables in Paddle.
   BindRecordIOWriter(&m);
   BindAsyncExecutor(&m);
   BindFleetWrapper(&m);
+  BindNCCLWrapper(&m);
   BindGraph(&m);
   BindNode(&m);
   BindInferenceApi(&m);
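One gap worth flagging: the bindings expose `set_nccl_id` but not `GetNCCLId`, so as of this patch the unique id still has to be produced and handed off on the C++ side. A hypothetical smoke test (not part of the patch) for that hand-off: the singleton wrapper generates the id, a second wrapper instance (standing in for a peer process) adopts it via `SetNCCLId`, and the two copies must match byte for byte.

```cpp
#include <cstring>
#include "paddle/fluid/framework/fleet/nccl_wrapper.h"

int main() {
  namespace fw = paddle::framework;
  // Root side: ncclGetUniqueId runs inside GetNCCLId.
  auto root = fw::NCCLWrapper::GetInstance();
  fw::NCCLInfo info = root->GetNCCLId();

  // Peer side: adopt the id produced by the root.
  fw::NCCLWrapper peer;
  peer.SetNCCLId(info);

  // SetNCCLId copies only nccl_id_, so the bytes must be identical.
  return std::memcmp(&info.nccl_id_, &peer.nccl_info_.nccl_id_,
                     sizeof(ncclUniqueId)) == 0
             ? 0
             : 1;
}
```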