diff --git a/paddle/operators/nccl/nccl_gpu_common.cc b/paddle/operators/nccl/nccl_gpu_common.cc
index 80cb66300e98bbe4c30ac6cacb8ea7bb8c2ec44b..934f79f2457107957b1605f5344038df7f4f261e 100644
--- a/paddle/operators/nccl/nccl_gpu_common.cc
+++ b/paddle/operators/nccl/nccl_gpu_common.cc
@@ -18,7 +18,7 @@ NCCLManager::~NCCLManager() {
     int idx = gid % gpus_.size();
     // wait finish
     PADDLE_ENFORCE(
-        cudaStreamWaitEvent(*comm->streams_[idx], comm->events_[idx], 0));
+        cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
 
     PADDLE_ENFORCE(cudaEventDestroy(comm->events_[idx]));
 
diff --git a/paddle/operators/nccl/nccl_gpu_common.h b/paddle/operators/nccl/nccl_gpu_common.h
index 4a375fcc36a1f07287332aed985de132f94d5929..5ca6a9e05efd833044618a06267ce9ff906e92dd 100644
--- a/paddle/operators/nccl/nccl_gpu_common.h
+++ b/paddle/operators/nccl/nccl_gpu_common.h
@@ -65,20 +65,10 @@ class WaitGroup {
   std::condition_variable cv_;
 };
 
-// class NCCLContext : public DeviceContext {
-// public:
-//   explicit NCCLContext(GPUPlace place);
-//   virtual ~NCCLContext();
-
-// private:
-//   std::vector<int> gpu_ids_;
-//   std::vector<cudaStream_t> streams_;
-// };
-
 // TODO(dzh) : make resources managed unified with framework
 struct Communicator {
   std::vector<ncclComm_t> comms_;
-  std::vector<cudaStream_t*> streams_;
+  std::vector<cudaStream_t> streams_;
   std::vector<cudaEvent_t> events_;
   std::vector<int> gpus_;
   WaitGroup wg_;
diff --git a/paddle/operators/nccl/nccl_ops.cc b/paddle/operators/nccl/nccl_ops.cc
index ccb22f3052f1994af250e61ccbca191387a5c6a7..f1a83c1e1e344841e2d95984400c2c7d4c5c41f7 100644
--- a/paddle/operators/nccl/nccl_ops.cc
+++ b/paddle/operators/nccl/nccl_ops.cc
@@ -1,3 +1,14 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
 #include "paddle/operators/nccl/nccl_ops.h"
 
 namespace paddle {
@@ -9,54 +20,27 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  // allreduce do nothing in infershape
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.InputVar("X"),
-        " Input(X) of AllReduce op input should not be NULL");
-    auto ins = ctx.MultiInput<framework::Tensor>("X");
-    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
-    PADDLE_ENFORCE(ins.size() == outs.size(),
-                   "Input(X) and Output(Out) must have same size");
-    for (size_t i = 0; i < ins.size(); ++i) {
-      outs[i]->Resize(ins[i]->dims());
-    }
-    std::string reduction = ctx.Attr<std::string>("reduction");
-    PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
-                    reduction == "ncclMin" || reduction == "ncclMax"),
-                   "invalid reduction!");
-  }
-};
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   " Input(X) of AllReduce op input should not be NULL");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   " Output(Out) of AllReduce op output should not be NULL");
 
-// BcastSendOp
-template <typename T>
-class NCCLBcastSendOp final : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.InputVar("X"),
-        " Input(X) of BcastSend op input should not be NULL");
-  }
-};
+    auto x_dims = ctx->GetInputsDim("X");
 
-// BcastRecvOp
-template <typename T>
-class NCCLBcastRecvOp final : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
+    std::string reduction = ctx->Attrs().Get<std::string>("reduction");
+    PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
+                    reduction == "ncclMin" || reduction == "ncclMax"),
+                   "invalid reduction.");
 
- protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(
-        ctx.OutputVar("Out"),
-        " Input(X) of BcastRecv op input should not be NULL");
+    ctx->SetOutputsDim("Out", x_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
   }
 };
 
+// AllreduceOp
 class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
   NCCLAllReduceOpMaker(framework::OpProto *proto,
                        framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
@@ -71,7 +55,9 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+// BcastSendOp
 class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
   NCCLAllReduceOpMaker(framework::OpProto *proto,
                        framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
@@ -82,7 +68,9 @@ class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
+// BcastRecvOp
 class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
   NCCLAllReduceOpMaker(framework::OpProto *proto,
                        framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
@@ -93,5 +81,9 @@ class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-}  // operators
-}  // paddle
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
+                             ops::NCCLAllReduceOpMaker);
diff --git a/paddle/operators/nccl/nccl_ops.cu b/paddle/operators/nccl/nccl_ops.cu
new file mode 100644
index 0000000000000000000000000000000000000000..eabe5f172926b68b1e3ebf9ec01b8b1722a79582
--- /dev/null
+++ b/paddle/operators/nccl/nccl_ops.cu
@@ -0,0 +1,16 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/nccl/nccl_ops.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(ncclAllReduce, ops::NCCLAllReduceKernel<float>);
\ No newline at end of file
diff --git a/paddle/operators/nccl/nccl_ops.h b/paddle/operators/nccl/nccl_ops.h
index f56b89d2ad87e88c2ef3e37e22dbd4ebab3afe0d..c46fdd7d44f789ca178ea5671b7126fe8eaa8d67 100644
--- a/paddle/operators/nccl/nccl_ops.h
+++ b/paddle/operators/nccl/nccl_ops.h
@@ -1,3 +1,14 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+   http://www.apache.org/licenses/LICENSE-2.0
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
 #pragma once
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/nccl/nccl_gpu_common.h"
@@ -14,11 +25,13 @@ class NCCLTypeWrapper;
 
 template <>
 class NCCLTypeWrapper<float> {
+ public:
   static const ncclDataType_t type = ncclFloat;
 };
 
 template <>
 class NCCLTypeWrapper<double> {
+ public:
   static const ncclDataType_t type = ncclDouble;
 };
 
@@ -49,10 +62,10 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
 
     auto* comm = m->GetCommunicator(gpus);
     comm->wg_.Add(1);
 
-    auto* stream = &dev_ctx.stream();
+    auto stream = dev_ctx.stream();
 
     // device id
-    int gid = ctx.GetPlace().GetDeviceId();
+    int gid = static_cast<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
     int idx = gid % gpus.size();
     comm->streams_[idx] = stream;
@@ -60,9 +73,8 @@
       PADDLE_ENFORCE(
           ncclAllReduce(ins[i]->data<T>(), outs[i]->mutable_data<T>(),
                         outs[i]->numel() * sizeof(T), NCCLTypeWrapper<T>::type,
-                        op_type, &comm->comms_[idx], comm->streams_[idx]));
-      PADDLE_ENFORCE(
-          cudaEventRecord(comm->events_[idx], *comms_->streams_[idx]));
+                        op_type, comm->comms_[idx], comm->streams_[idx]));
+      PADDLE_ENFORCE(cudaEventRecord(comm->events_[idx], comm->streams_[idx]));
 
       // wait finish
       PADDLE_ENFORCE(
@@ -71,8 +83,9 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
 
     comm->wg_.Done();
 
-    wg.Wait();
+    comm->wg_.Wait();
   }
 };
-}
-}
+
+}  // namespace operators
+}  // namespace paddle
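
Reviewer note: the AllReduce kernel above coordinates the participating devices through the WaitGroup member of Communicator (wg_.Add(1) before the collective, wg_.Done() after recording the event, wg_.Wait() until every device has checked in). For reference, a minimal sketch of those semantics follows. It is an illustration only, not the implementation in nccl_gpu_common.h; every member name except cv_ is an assumption.

// Minimal sketch of the Add/Done/Wait semantics assumed by the kernel.
// Only cv_ is taken from nccl_gpu_common.h; counter_ and mu_ are made up
// here for illustration.
#include <condition_variable>
#include <mutex>

class WaitGroup {
 public:
  // Register n pending participants.
  void Add(int n) {
    std::lock_guard<std::mutex> lock(mu_);
    counter_ += n;
  }

  // Mark one participant as finished; wake waiters when none remain.
  void Done() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--counter_ <= 0) {
      cv_.notify_all();
    }
  }

  // Block until every registered participant has called Done().
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return counter_ <= 0; });
  }

 private:
  int counter_{0};
  std::mutex mu_;
  std::condition_variable cv_;
};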