From eeed7af5c3b6d51399412ba3cd0cab2125b33e90 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 2 May 2018 20:19:58 +0800 Subject: [PATCH] add gen_nccl_id_op --- paddle/fluid/operators/CMakeLists.txt | 7 +- paddle/fluid/operators/gen_nccl_id_op.cc | 123 ++++++++++++++++++ .../fluid/operators/lookup_sparse_table_op.cc | 2 +- 3 files changed, 130 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/gen_nccl_id_op.cc diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 256aded8c..ad0732131 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -184,6 +184,11 @@ endif() add_subdirectory(detail) if(WITH_DISTRIBUTE) + if(WITH_GPU) + op_library(gen_nccl_id_op DEPS nccl_common) + else() + set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) + endif() set(DISTRIBUTE_DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") op_library(send_op DEPS ${DISTRIBUTE_DEPS}) @@ -201,7 +206,7 @@ if(WITH_DISTRIBUTE) set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op listen_and_serv_op sum_op executor) else() - set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op) + set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op) endif() op_library(cross_entropy_op DEPS cross_entropy) diff --git a/paddle/fluid/operators/gen_nccl_id_op.cc b/paddle/fluid/operators/gen_nccl_id_op.cc new file mode 100644 index 000000000..e75e045fc --- /dev/null +++ b/paddle/fluid/operators/gen_nccl_id_op.cc @@ -0,0 +1,123 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include + +#include "paddle/fluid/framework/executor.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/detail/grpc_client.h" +#include "paddle/fluid/operators/detail/grpc_server.h" + +namespace paddle { +namespace operators { + +class GenNCCLIdOp : public framework::OperatorBase { + public: + GenNCCLIdOp(const std::string& type, const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + auto& dev_ctx = *pool.Get(dev_place); + int trainer_id = Attr("trainer_id"); + framework::Scope& local_scope = scope.NewScope(); + + if (trainer_id == 0) { + GenerateAndSend(&local_scope, dev_ctx); + } else { + GetIdByServer(&local_scope, dev_ctx); + } + } + + private: + void GenerateAndSend(framework::Scope* scope, + const platform::DeviceContext& dev_ctx) const { + auto var = scope->FindVar("NCCLID"); + PADDLE_ENFORCE_NOT_NULL(var); + auto id = var->GetMutable(); + ncclGetUniqueId(id); + + std::vector endpoint_list = + Attr>("endpoint_list"); + detail::RPCClient client; + for (auto& ep : endpoint_list) { + client.AsyncSendVariable(ep, dev_ctx, *scope, "NCCLID"); + } + client.Wait(); + } + + void GetIdByServer(framework::Scope* scope, + const platform::DeviceContext& dev_ctx) const { + std::string endpoint = Attr("endpoint"); + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true)); + framework::ProgramDesc empty_program; + framework::Executor executor(dev_ctx.GetPlace()); + rpc_service_->SetScope(scope); + rpc_service_->SetDevCtx(&dev_ctx); + rpc_service_->SetProgram(&empty_program); + rpc_service_->SetExecutor(&executor); + + server_thread_.reset(new std::thread(std::bind( + &detail::AsyncGRPCServer::RunSyncUpdate, rpc_service_.get()))); + + auto recv = rpc_service_->Get(); + rpc_service_->ShutDown(); + // TODO(wuyi): reinit nccl communicators + } + + protected: + mutable std::shared_ptr rpc_service_; + mutable std::shared_ptr server_thread_; +}; + +class GenNCCLIdOpMaker : public framework::OpProtoAndCheckerMaker { + public: + GenNCCLIdOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : GenNCCLIdOpMaker(proto, op_checker) { + AddOutput("NCCLID", "Raw variable contains a NCCL UniqueId instaces."); + AddComment(R"DOC( +GenNCCLId operator + +For trainer 0: generate a new UniqueId and send it to all the other trainers. +For trainer 1~n: start a gRPC server to get the UniqueId, once got, stop the server. +)DOC"); + AddAttr("endpoint", + "(string), e.g. 127.0.0.1:6175 " + "current listen endpoint"); + AddAttr>( + "endpoint_list", + "['trainer1_ip:port', 'trainer2_ip:port', ...] " + "list of trainer endpoints start from trainer 1") + .SetDefault({}); + AddAttr("trainer_id", + "(int default 0) " + "The index of the trainer in distributed training.") + .SetDefault(0); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(gen_nccl_id_op, ops::GenNCCLIdOp, ops::GenNCCLIdOpMaker); diff --git a/paddle/fluid/operators/lookup_sparse_table_op.cc b/paddle/fluid/operators/lookup_sparse_table_op.cc index f1839e456..66b626ed7 100644 --- a/paddle/fluid/operators/lookup_sparse_table_op.cc +++ b/paddle/fluid/operators/lookup_sparse_table_op.cc @@ -62,7 +62,7 @@ class LookupSparseTableOp : public framework::OperatorBase { auto w_t = w_var->GetMutable(); std::vector keys; keys.resize(ids_t.numel()); - for (size_t i = 0; i < ids_t.numel(); ++i) { + for (int64_t i = 0; i < ids_t.numel(); ++i) { keys[i] = ids_t.data()[i]; } -- GitLab