From ec47565c23f872d5f8c1607b7c44c5e3d155c676 Mon Sep 17 00:00:00 2001 From: Dong Zhihong Date: Mon, 23 Oct 2017 14:53:17 -0700 Subject: [PATCH] "add reduce hash function" --- paddle/framework/operator.h | 9 +++++++++ paddle/operators/nccl_op.cc | 11 ++++------- paddle/operators/nccl_op.cu | 29 +++++++++-------------------- 3 files changed, 22 insertions(+), 27 deletions(-) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index cf15f9933ab..8cdb07e6770 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -289,6 +289,15 @@ class ExecutionContext { return device_context_; } + //! Get a input which has multiple variables. + const std::vector& Inputs(const std::string& name) const { + return op_.Inputs(name); + } + //! Get an output which has multiple variables. + const std::vector& Outputs(const std::string& name) const { + return op_.Outputs(name); + } + #ifdef PADDLE_WITH_CUDA const platform::CUDADeviceContext& cuda_device_context() const { PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace())); diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc index f0f7b205b68..89dedfc1581 100644 --- a/paddle/operators/nccl_op.cc +++ b/paddle/operators/nccl_op.cc @@ -81,9 +81,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel { " Input(Communicator) of Reduce op input should not be NULL"); PADDLE_ENFORCE(ctx->HasOutput("Out"), " Input(X) of Reduce op input should not be NULL"); - - ctx->SetOutputsDim("Out", x_dims); - ctx->ShareLoD("X", /*->*/ "Out"); } }; @@ -137,8 +134,8 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { // BcastSendOp class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker { public: - NCCLAllBcastSendOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + NCCLBcastSendOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of BcastSend op"); AddInput("Communicator", "Communicator for communicating between gpus"); @@ -152,8 +149,8 @@ class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker { // BcastOp class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker { public: - NCCLAllBcastRecvOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + NCCLBcastRecvOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Communicator", "Communicator for communicating between gpus"); AddAttr("root", "root gpu of BcastRecv"); diff --git a/paddle/operators/nccl_op.cu b/paddle/operators/nccl_op.cu index 4d91a3055fa..5f8e0a886b8 100644 --- a/paddle/operators/nccl_op.cu +++ b/paddle/operators/nccl_op.cu @@ -2,8 +2,8 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software +http://www.apache.org/licenseshashernless required by applicable law or agreed +to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and @@ -27,25 +27,12 @@ class NCCLAllReduceKernel : public framework::OpKernel { auto ins = ctx.MultiInput("X"); auto outs = ctx.MultiOutput("Out"); std::string reduction = ctx.Attr("reduction"); - ncclRedOp_t op_type; - if (reduction == "ncclSum") { - op_type = ncclSum; - } else if (reduction == "ncclProd") { - op_type = ncclProd; - } else if (reduction == "ncclMin") { - op_type = ncclMin; - } else if (reduction == "ncclMax") { - op_type = ncclMax; - } else { - PADDLE_ENFORCE(false, "reduction error."); - } auto* comm = ctx.Input("Communicator"); auto stream = reinterpret_cast( ctx.device_context()) .stream(); - // device id int device_id = boost::get(ctx.GetPlace()).GetDeviceId(); @@ -54,7 +41,7 @@ class NCCLAllReduceKernel : public framework::OpKernel { for (size_t i = 0; i < ins.size(); ++i) { PADDLE_ENFORCE(ncclAllReduce( ins[i]->data(), outs[i]->mutable_data(ctx.GetPlace()), - outs[i]->numel() * sizeof(T), NCCLTypeWrapper::type, op_type, + outs[i]->numel() * sizeof(T), NCCLTypeWrapper::type, ncclSum, comm->comms_[idx], stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream)); } @@ -68,7 +55,7 @@ class NCCLReduceKernel : public framework::OpKernel { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - auto ins = ctx.MultiInput("X"); + auto ins = ctx.MultiInput("X"); // x0, x1, x2 auto outs = ctx.MultiOutput("Out"); auto* comm = ctx.Input("Communicator"); @@ -81,14 +68,16 @@ class NCCLReduceKernel : public framework::OpKernel { boost::get(ctx.GetPlace()).GetDeviceId(); int idx = comm->GetCommId(device_id); + auto ins_names = ctx.Inputs("X"); + std::hash hasher; for (size_t i = 0; i < ins.size(); ++i) { - int root = std::hash() % comm->comms_.size(); + int root = hasher(ins_names[i]) % comm->comms_.size(); T* recvbuffer = nullptr; if (root == device_id) { recvbuffer = outs[i]->mutable_data(ctx.GetPlace()); } PADDLE_ENFORCE(ncclReduce(ins[i]->data(), recvbuffer, ins[i]->numel(), - NCCLTypeWrapper::type, root, ncclSum, + NCCLTypeWrapper::type, ncclSum, root, comm->comms_[idx], stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream)); } @@ -124,7 +113,7 @@ class NCCLBcastKernel : public framework::OpKernel { } else { auto outs = ctx.MultiOutput("Out"); for (size_t i = 0; i < outs.size(); ++i) { - PADDLE_ENFORCE(ncclBcast((void*)outs[i]->mutable_data(), + PADDLE_ENFORCE(ncclBcast(outs[i]->mutable_data(ctx.GetPlace()), outs[i]->numel(), NCCLTypeWrapper::type, root, comm->comms_[idx], stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream)); -- GitLab