diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index cf15f9933ab3bc881add3d45b7ca17194a70e0f1..8cdb07e6770c14c54e7bcc38ed6e7085a14708d6 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -289,6 +289,15 @@ class ExecutionContext { return device_context_; } + //! Get a input which has multiple variables. + const std::vector& Inputs(const std::string& name) const { + return op_.Inputs(name); + } + //! Get an output which has multiple variables. + const std::vector& Outputs(const std::string& name) const { + return op_.Outputs(name); + } + #ifdef PADDLE_WITH_CUDA const platform::CUDADeviceContext& cuda_device_context() const { PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace())); diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc index f0f7b205b68851f09b662d41271639931f268309..89dedfc1581a9c2437a92fa6ba964b25e4e6cd45 100644 --- a/paddle/operators/nccl_op.cc +++ b/paddle/operators/nccl_op.cc @@ -81,9 +81,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel { " Input(Communicator) of Reduce op input should not be NULL"); PADDLE_ENFORCE(ctx->HasOutput("Out"), " Input(X) of Reduce op input should not be NULL"); - - ctx->SetOutputsDim("Out", x_dims); - ctx->ShareLoD("X", /*->*/ "Out"); } }; @@ -137,8 +134,8 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { // BcastSendOp class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker { public: - NCCLAllBcastSendOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + NCCLBcastSendOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The input of BcastSend op"); AddInput("Communicator", "Communicator for communicating between gpus"); @@ -152,8 +149,8 @@ class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker { // BcastOp class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker { public: - NCCLAllBcastRecvOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) + NCCLBcastRecvOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("Communicator", "Communicator for communicating between gpus"); AddAttr("root", "root gpu of BcastRecv"); diff --git a/paddle/operators/nccl_op.cu b/paddle/operators/nccl_op.cu index 4d91a3055faca47cea9b5ea5640675abee8015d6..5f8e0a886b8f0b7b3e0b45cf4a4e9393eb6e98b4 100644 --- a/paddle/operators/nccl_op.cu +++ b/paddle/operators/nccl_op.cu @@ -2,8 +2,8 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software +http://www.apache.org/licenseshashernless required by applicable law or agreed +to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and @@ -27,25 +27,12 @@ class NCCLAllReduceKernel : public framework::OpKernel { auto ins = ctx.MultiInput("X"); auto outs = ctx.MultiOutput("Out"); std::string reduction = ctx.Attr("reduction"); - ncclRedOp_t op_type; - if (reduction == "ncclSum") { - op_type = ncclSum; - } else if (reduction == "ncclProd") { - op_type = ncclProd; - } else if (reduction == "ncclMin") { - op_type = ncclMin; - } else if (reduction == "ncclMax") { - op_type = ncclMax; - } else { - PADDLE_ENFORCE(false, "reduction error."); - } auto* comm = ctx.Input("Communicator"); auto stream = reinterpret_cast( ctx.device_context()) .stream(); - // device id int device_id = boost::get(ctx.GetPlace()).GetDeviceId(); @@ -54,7 +41,7 @@ class NCCLAllReduceKernel : public framework::OpKernel { for (size_t i = 0; i < ins.size(); ++i) { PADDLE_ENFORCE(ncclAllReduce( ins[i]->data(), outs[i]->mutable_data(ctx.GetPlace()), - outs[i]->numel() * sizeof(T), NCCLTypeWrapper::type, op_type, + outs[i]->numel() * sizeof(T), NCCLTypeWrapper::type, ncclSum, comm->comms_[idx], stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream)); } @@ -68,7 +55,7 @@ class NCCLReduceKernel : public framework::OpKernel { PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), "This kernel only runs on GPU device."); - auto ins = ctx.MultiInput("X"); + auto ins = ctx.MultiInput("X"); // x0, x1, x2 auto outs = ctx.MultiOutput("Out"); auto* comm = ctx.Input("Communicator"); @@ -81,14 +68,16 @@ class NCCLReduceKernel : public framework::OpKernel { boost::get(ctx.GetPlace()).GetDeviceId(); int idx = comm->GetCommId(device_id); + auto ins_names = ctx.Inputs("X"); + std::hash hasher; for (size_t i = 0; i < ins.size(); ++i) { - int root = std::hash() % comm->comms_.size(); + int root = hasher(ins_names[i]) % comm->comms_.size(); T* recvbuffer = nullptr; if (root == device_id) { recvbuffer = outs[i]->mutable_data(ctx.GetPlace()); } PADDLE_ENFORCE(ncclReduce(ins[i]->data(), recvbuffer, ins[i]->numel(), - NCCLTypeWrapper::type, root, ncclSum, + NCCLTypeWrapper::type, ncclSum, root, comm->comms_[idx], stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream)); } @@ -124,7 +113,7 @@ class NCCLBcastKernel : public framework::OpKernel { } else { auto outs = ctx.MultiOutput("Out"); for (size_t i = 0; i < outs.size(); ++i) { - PADDLE_ENFORCE(ncclBcast((void*)outs[i]->mutable_data(), + PADDLE_ENFORCE(ncclBcast(outs[i]->mutable_data(ctx.GetPlace()), outs[i]->numel(), NCCLTypeWrapper::type, root, comm->comms_[idx], stream)); PADDLE_ENFORCE(cudaStreamSynchronize(stream));