diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 2f461e7b2a1b971264009a086cfcb27e18fd1fed..19a9fc3802a2f2348ad7d50a267615ed70bbc4fe 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -228,6 +228,10 @@ class OpKernelRegistrar : public Registrar { USE_OP_ITSELF(op_type); \ USE_OP_DEVICE_KERNEL(op_type, CPU); +#define USE_GPU_ONLY_OP(op_type) \ + USE_OP_ITSELF(op_type); \ + USE_OP_DEVICE_KERNEL(op_type, GPU) + #define USE_OP(op_type) \ USE_OP_ITSELF(op_type); \ USE_OP_KERNEL(op_type) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 3a9c7a732885b489c23c7a7961066a6e87849203..1294e06fb1c5a569d7b4813449e9731d95472755 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -122,7 +122,7 @@ class OperatorBase { protected: std::string type_; // NOTE: in case of OpGrad, inputs_ contains: - // I (Inputs)opear + // I (Inputs) // O (Outputs) // OG (Output Gradients) VariableNameMap inputs_; @@ -287,6 +287,16 @@ class ExecutionContext { return device_context_; } + //! Get actual name vector for this input. + const std::vector& Inputs(const std::string& name) const { + return op_.Inputs(name); + } + + //! Get actual name vector for this output. + const std::vector& Outputs(const std::string& name) const { + return op_.Outputs(name); + } + #ifdef PADDLE_WITH_CUDA const platform::CUDADeviceContext& cuda_device_context() const { PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace())); diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index c72261710173a0f3af199646d6800bf9d7c27b67..60dc55a32f5f05875e4f3ce77431556e14adc74a 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -90,6 +90,13 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(sigmoid);\n") endif() + # nccl_op contains several operators + if ("${TARGET}" STREQUAL "nccl_op") + set(pybind_flag 1) + # It's enough to just adding one operator to pybind + file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclAllReduce);\n") + endif() + # reduce_op contains several operators if ("${TARGET}" STREQUAL "reduce_op") set(pybind_flag 1) @@ -121,6 +128,7 @@ function(op_library TARGET) endfunction() add_subdirectory(math) +add_subdirectory(nccl) set(DEPS_OPS recurrent_op @@ -130,6 +138,7 @@ set(DEPS_OPS sum_op pool_op pool_with_index_op + nccl_op sequence_conv_op lstm_op) @@ -142,6 +151,9 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op selected_rows_functor) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) +if(WITH_GPU) +op_library(nccl_op DEPS nccl_common) +endif() op_library(sequence_conv_op DEPS context_project) op_library(lstm_op DEPS sequence2batch lstm_compute) @@ -157,4 +169,8 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory) cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array) + +if(WITH_GPU) + nv_test(nccl_op_test SRCS nccl_op_test.cu DEPS nccl_op gpu_info device_context) +endif() cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) diff --git a/paddle/operators/nccl/CMakeLists.txt b/paddle/operators/nccl/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..ce0ddd89bfb0d73e237a6f9a777376624d8ef2d4 --- /dev/null +++ b/paddle/operators/nccl/CMakeLists.txt @@ -0,0 +1,3 @@ +if(WITH_GPU) + nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator ) +endif() diff --git a/paddle/operators/nccl/nccl_gpu_common.cc b/paddle/operators/nccl/nccl_gpu_common.cc new file mode 100644 index 0000000000000000000000000000000000000000..6be735e4c731f79684e0bdac3d69a30b328fed84 --- /dev/null +++ b/paddle/operators/nccl/nccl_gpu_common.cc @@ -0,0 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/nccl/nccl_gpu_common.h" +#include "paddle/platform/gpu_info.h" + +namespace paddle { +namespace platform {} // namespace platform +} // namespace paddle diff --git a/paddle/operators/nccl/nccl_gpu_common.h b/paddle/operators/nccl/nccl_gpu_common.h new file mode 100644 index 0000000000000000000000000000000000000000..5858cd4839d367bb888b2b98cde2225751391162 --- /dev/null +++ b/paddle/operators/nccl/nccl_gpu_common.h @@ -0,0 +1,63 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/platform/device_context.h" +#include "paddle/platform/dynload/nccl.h" +#include "paddle/platform/enforce.h" +#include "paddle/platform/macros.h" + +namespace paddle { +namespace platform { + +constexpr int kInvalidGPUId = -1; + +struct Communicator { + std::vector comms_; + std::unordered_map comm_id_map_; + + Communicator() {} + + int GetCommId(int device_id) const { return comm_id_map_.at(device_id); } + + void InitAll(const std::vector& gpus) { + comms_.resize(gpus.size()); + for (size_t i = 0; i < gpus.size(); ++i) { + comm_id_map_[gpus[i]] = i; + } + PADDLE_ENFORCE( + dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data())); + } + + ~Communicator() { + for (size_t i = 0; i < comms_.size(); ++i) { + // FIXME(dzh) : PADDLE_ENFORCE return void + dynload::ncclCommDestroy(comms_[i]); + } + } + + DISABLE_COPY_AND_ASSIGN(Communicator); +}; + +} // namespace platform +} // namespace paddle diff --git a/paddle/operators/nccl_op.cc b/paddle/operators/nccl_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d39cb2fcf9cc205edf86f8ab1d5e04b5672e00f6 --- /dev/null +++ b/paddle/operators/nccl_op.cc @@ -0,0 +1,206 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/op_registry.h" +#include "paddle/operators/nccl/nccl_gpu_common.h" + +namespace paddle { +namespace operators { + +// NCCLinitOp +class NCCLInitOp : public framework::OperatorBase { + public: + NCCLInitOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void Run(const framework::Scope &scope, + const platform::DeviceContext &dev_ctx) const override { + const auto &name = Output("Communicator"); + PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name), + "Can not find variable '%s' in the scope.", name); + std::vector gpus = Attr>("gpus"); + PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty."); + + if (scope.FindVar(name) == nullptr) { + PADDLE_THROW("Output(Communicator) is needed for ncclInit operator."); + } + + platform::Communicator *comm = + scope.FindVar(name)->GetMutable(); + comm->InitAll(gpus); + } +}; + +class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NCCLInitOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddOutput("Communicator", + "Create Communicator for communicating between gpus"); + AddAttr>("gpus", "gpu id lists"); + AddAttr("data_type", "output data type") + .SetDefault(framework::DataType::FP32); + AddComment(R"DOC( + create communicator. + )DOC"); + } +}; + +// AllReduceOp +class NCCLAllReduceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + " Input(X) of AllReduce op input should not be NULL"); + PADDLE_ENFORCE( + ctx->HasInput("Communicator"), + " Input(Communicator) of AllReduce op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + " Input(X) of AllReduce op input should not be NULL"); + + auto x_dims = ctx->GetInputsDim("X"); + + std::string reduction = ctx->Attrs().Get("reduction"); + PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" || + reduction == "ncclMin" || reduction == "ncclMax"), + "invalid reduction."); + + ctx->SetOutputsDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +// ReduceOp +class NCCLReduceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + " Input(X) of Reduce op input should not be NULL"); + PADDLE_ENFORCE( + ctx->HasInput("Communicator"), + " Input(Communicator) of Reduce op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + " Input(X) of Reduce op input should not be NULL"); + + std::string reduction = ctx->Attrs().Get("reduction"); + PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" || + reduction == "ncclMin" || reduction == "ncclMax"), + "invalid reduction."); + + auto x_dims = ctx->GetInputsDim("X"); + ctx->SetOutputsDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +// BcastOp +class NCCLBcastOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + " Input(X) of Bcast op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasInput("Communicator"), + " Input(Communicator) of Bcast op input should not be NULL"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + " Output(Out) of Bcast op output should not be NULL"); + + int root = ctx->Attrs().Get("root"); + PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set."); + + auto x_dims = ctx->GetInputsDim("X"); + ctx->SetOutputsDim("Out", x_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +// AllreduceOp +class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NCCLAllReduceOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of AllReduce op"); + AddInput("Communicator", "Communicator for communicating between gpus"); + AddOutput("Out", "The output of AllReduce op"); + AddAttr("reduction", + "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.") + .SetDefault("ncclSum"); + AddComment(R"DOC( + AllReduce the input tensors. + )DOC"); + } +}; + +// ReduceOp +class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NCCLReduceOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of Reduce op"); + AddInput("Communicator", "Communicator for communicating between gpus"); + AddOutput("Out", "The output of Reduce op"); + AddAttr("reduction", + "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.") + .SetDefault("ncclSum"); + AddAttr("root", + "root gpu of the parameter. if not " + "set(platform::kInvalidGPUId). hashed by name.") + .SetDefault(platform::kInvalidGPUId); + AddComment(R"DOC( + Reduce the tensors)DOC"); + } +}; + +// BcastOp +class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NCCLBcastOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of BcastSend op"); + AddInput("Communicator", "Communicator for communicating between gpus"); + AddOutput("Out", "The output of Bcast"); + AddAttr("root", + "root gpu of the parameter. if not " + "set(platform::kInvalidGPUId). hashed by name.") + .SetDefault(platform::kInvalidGPUId); + AddComment(R"DOC( + Bcast the tensors. + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp, + paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker); + +REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp, + ops::NCCLAllReduceOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(ncclBcast, ops::NCCLBcastOp, + ops::NCCLBcastOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(ncclReduce, ops::NCCLReduceOp, + ops::NCCLReduceOpMaker); diff --git a/paddle/operators/nccl_op.cu b/paddle/operators/nccl_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..86dee8ee8e1c1a1041d6bc9fa515d669a9c4e466 --- /dev/null +++ b/paddle/operators/nccl_op.cu @@ -0,0 +1,211 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenseshashernless required by applicable law or agreed +to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/nccl/nccl_gpu_common.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; +using platform::Communicator; +using framework::LoDTensor; + +template +class NCCLTypeWrapper; + +template <> +class NCCLTypeWrapper { + public: + static const ncclDataType_t type = ncclFloat; +}; + +template <> +class NCCLTypeWrapper { + public: + static const ncclDataType_t type = ncclDouble; +}; + +template +class NCCLAllReduceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + + auto ins = ctx.MultiInput("X"); + auto outs = ctx.MultiOutput("Out"); + + std::string reduction = ctx.Attr("reduction"); + ncclRedOp_t reduction_op_ = ncclSum; + + if (reduction == "ncclMin") { + reduction_op_ = ncclMin; + } else if (reduction == "ncclMax") { + reduction_op_ = ncclMax; + } else if (reduction == "ncclSum") { + reduction_op_ = ncclSum; + } else if (reduction == "ncclProd") { + reduction_op_ = ncclProd; + } else { + PADDLE_THROW("Invalid reduction. default ncclSum."); + } + + auto* comm = ctx.Input("Communicator"); + + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + + // device id + int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); + int idx = comm->GetCommId(gpu_id); + + for (size_t i = 0; i < ins.size(); ++i) { + VLOG(1) << "gpu : " + << " invoke allreduce. send " << ins[i]->numel() << " recv " + << outs[i]->numel(); + + PADDLE_ENFORCE(platform::dynload::ncclAllReduce( + ins[i]->data(), outs[i]->mutable_data(ctx.GetPlace()), + outs[i]->numel(), NCCLTypeWrapper::type, reduction_op_, + comm->comms_[idx], stream)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + + VLOG(1) << "gpu : " + << " finished allreduce. send " << ins[i]->numel() << " recv " + << outs[i]->numel(); + } + } +}; + +template +class NCCLReduceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + + auto ins = ctx.MultiInput("X"); // x0, x1, x2 + auto outs = ctx.MultiOutput("Out"); + + std::string reduction = ctx.Attr("reduction"); + ncclRedOp_t reduction_op_ = ncclSum; + + if (reduction == "ncclMin") { + reduction_op_ = ncclMin; + } else if (reduction == "ncclMax") { + reduction_op_ = ncclMax; + } else if (reduction == "ncclSum") { + reduction_op_ = ncclSum; + } else if (reduction == "ncclProd") { + reduction_op_ = ncclProd; + } else { + PADDLE_THROW("Invalid reduction. default ncclSum."); + } + + int root = ctx.Attr("root"); + auto* comm = ctx.Input("Communicator"); + + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + // device id + int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); + int idx = comm->GetCommId(gpu_id); + + auto ins_names = ctx.Inputs("X"); + std::hash hasher; + for (size_t i = 0; i < ins.size(); ++i) { + if (root == platform::kInvalidGPUId) { + root = hasher(ins_names[i]) % comm->comms_.size(); + } + T* recvbuffer = nullptr; + if (root == gpu_id) { + recvbuffer = outs[i]->mutable_data(ctx.GetPlace()); + } + + VLOG(1) << "gpu : " << gpu_id << " invoke reduce. send " + << ins[i]->numel() << " recv " << outs[i]->numel(); + + PADDLE_ENFORCE(platform::dynload::ncclReduce( + ins[i]->data(), recvbuffer, ins[i]->numel(), + NCCLTypeWrapper::type, reduction_op_, root, comm->comms_[idx], + stream)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + + VLOG(1) << "gpu : " << gpu_id << " finished reduce. send " + << ins[i]->numel() << " recv " << outs[i]->numel(); + } + } +}; + +template +class NCCLBcastKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), + "This kernel only runs on GPU device."); + + int root = ctx.Attr("root"); + + auto* comm = ctx.Input("Communicator"); + + auto stream = reinterpret_cast( + ctx.device_context()) + .stream(); + // device id + int gpu_id = boost::get(ctx.GetPlace()).GetDeviceId(); + int idx = comm->GetCommId(gpu_id); + + if (idx == root) { + auto ins = ctx.MultiInput("X"); + for (size_t i = 0; i < ins.size(); ++i) { + VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. send " + << ins[i]->numel(); + + VLOG(1) << " before ncclBcast"; + PADDLE_ENFORCE(platform::dynload::ncclBcast( + (void*)ins[i]->data(), ins[i]->numel(), NCCLTypeWrapper::type, + root, comm->comms_[idx], stream)); + VLOG(1) << " after ncclBcast"; + PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + + VLOG(1) << "gpu : " << gpu_id << " finished Bcast."; + } + } else { + auto outs = ctx.MultiOutput("Out"); + for (size_t i = 0; i < outs.size(); ++i) { + VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. recv buffer " + << framework::product(outs[i]->dims()); + + PADDLE_ENFORCE(platform::dynload::ncclBcast( + outs[i]->mutable_data(ctx.GetPlace()), outs[i]->numel(), + NCCLTypeWrapper::type, root, comm->comms_[idx], stream)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream)); + + VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv " + << outs[i]->numel(); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(ncclAllReduce, ops::NCCLAllReduceKernel); +REGISTER_OP_GPU_KERNEL(ncclBcast, ops::NCCLBcastKernel); +REGISTER_OP_GPU_KERNEL(ncclReduce, ops::NCCLReduceKernel); diff --git a/paddle/operators/nccl_op_test.cu b/paddle/operators/nccl_op_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..80c50a28a9e5d560fc693c518b9e62091ddc5724 --- /dev/null +++ b/paddle/operators/nccl_op_test.cu @@ -0,0 +1,307 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/framework/block_desc.h" +#include "paddle/framework/op_desc.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/program_desc.h" +#include "paddle/framework/var_desc.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/nccl/nccl_gpu_common.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/enforce.h" +#include "paddle/platform/gpu_info.h" +#include "paddle/platform/place.h" + +USE_NO_KERNEL_OP(ncclInit); +USE_GPU_ONLY_OP(ncclAllReduce); +USE_GPU_ONLY_OP(ncclReduce); +USE_GPU_ONLY_OP(ncclBcast); + +namespace f = paddle::framework; +namespace p = paddle::platform; + +static std::vector gpu_list; + +// test data amount +const f::DDim kDims = {100, 100}; + +// nccl op common tester, init communicator. +class NCCLTester : public ::testing::Test { + public: + virtual void SetUp() override { + cpu_ctx = new p::CPUDeviceContext(p::CPUPlace()); + for (size_t i = 0; i < gpu_list.size(); ++i) { + p::GPUPlace place(i); + dev_ctxs.emplace_back(new p::CUDADeviceContext(place)); + } + + NCCLInitOp(); + } + + virtual void TearDown() override { + for (auto &device_context : dev_ctxs) { + delete device_context; + } + } + + void NCCLInitOp() { + std::unique_ptr op1(new f::OpDescBind); + + op1->SetType("ncclInit"); + op1->SetOutput("Communicator", {"comm"}); + op1->SetAttr("gpus", {gpu_list}); + + auto *var = g_scope.Var("comm"); + var->GetMutable(); + + auto op = f::OpRegistry::CreateOp(*op1); + VLOG(1) << "invoke NCCLInitOp."; + op->Run(g_scope, *cpu_ctx); + VLOG(1) << "NCCLInitOp finished."; + } + + template + void PerThreadProgram(int gpu_id, const f::OpDescBind &op_desc, + f::Scope *scope) { + std::unique_lock lk(mu); + const f::OpDescBind *op1 = &op_desc; + + p::GPUPlace place(gpu_id); + auto &ctx = dev_ctxs.at(gpu_id); + + auto *send_tensor = scope->Var("st")->GetMutable(); + auto *recv_tensor = scope->Var("rt")->GetMutable(); + + if (!send_tensor->numel()) { + send_tensor->Resize(kDims); + send_tensor->mutable_data(kDims, place); + + std::vector send_vector(f::product(kDims), gpu_id); + send_tensor->CopyFromVector(send_vector, *ctx); + ctx->Wait(); + VLOG(1) << "Send Tensor filled with elements " << send_tensor->numel(); + } + + lk.unlock(); + + PADDLE_ENFORCE(send_tensor->numel() == f::product(kDims), + "Tensor numel not match!"); + + auto op = f::OpRegistry::CreateOp(*op1); + + VLOG(1) << "Device : " << gpu_id << " invoke " << op_desc.Type(); + VLOG(1) << " send_tensor : " << send_tensor->numel() + << " recv_tensor : " << recv_tensor->numel(); + op->Run(*scope, *ctx); + VLOG(1) << "Device : " << gpu_id << " finished " << op_desc.Type(); + } + + public: + std::vector dev_ctxs; + p::DeviceContext *cpu_ctx; + f::Scope g_scope; + std::mutex mu; +}; + +// ncclInitOp with desc +TEST(NCCL, ncclInitOp) { + std::unique_ptr op_desc(new f::OpDescBind); + + op_desc->SetType("ncclInit"); + op_desc->SetOutput("Communicator", {"x1"}); + op_desc->SetAttr("gpus", {gpu_list}); + + f::Scope g_scope; + std::unique_ptr ctx(new p::CPUDeviceContext(p::CPUPlace())); + + auto *var = g_scope.Var("x1"); + var->GetMutable(); + + auto op = f::OpRegistry::CreateOp(*op_desc); + VLOG(1) << "invoke NCCLInitOp."; + op->Run(g_scope, *ctx.get()); + VLOG(1) << "NCCLInitOp finished."; +} + +// ncclAllReduceOp with desc +TEST_F(NCCLTester, ncclAllReduceOp) { + std::unique_ptr op2(new f::OpDescBind); + op2->SetType("ncclAllReduce"); + op2->SetInput("X", {"st"}); + op2->SetInput("Communicator", {"comm"}); + op2->SetOutput("Out", {"rt"}); + + std::vector dev_scopes; + + std::vector ths; + + for (size_t i = 0; i < gpu_list.size(); ++i) { + dev_scopes.emplace_back(&g_scope.NewScope()); + std::thread th(&NCCLTester::PerThreadProgram, this, gpu_list[i], + *op2.get(), dev_scopes[i]); + ths.emplace_back(std::move(th)); + } + + for (size_t i = 0; i < gpu_list.size(); ++i) { + ths[i].join(); + } + + // check results + float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0); + + for (size_t i = 0; i < dev_scopes.size(); ++i) { + p::CPUPlace cpu_place; + p::GPUPlace gpu_place(gpu_list[i]); + + auto &recv_tensor = dev_scopes[i]->FindVar("rt")->Get(); + auto *rt = recv_tensor.data(); + auto *result_tensor = dev_scopes[i]->Var("ct")->GetMutable(); + result_tensor->Resize(kDims); + auto *ct = result_tensor->mutable_data(cpu_place); + + paddle::memory::Copy( + cpu_place, ct, p::GPUPlace(gpu_list[i]), rt, + recv_tensor.numel() * sizeof(float), + static_cast(dev_ctxs[i])->stream()); + + for (size_t j = 0; j < f::product(kDims); ++j) { + ASSERT_NEAR(ct[j], result, 1e-5); + } + } +} + +// ncclReduceOp with desc +TEST_F(NCCLTester, ncclReduceOp) { + std::unique_ptr op2(new f::OpDescBind); + const int kRoot = 0; + op2->SetType("ncclReduce"); + op2->SetInput("X", {"st"}); + op2->SetInput("Communicator", {"comm"}); + op2->SetOutput("Out", {"rt"}); + op2->SetAttr("root", kRoot); + + std::vector dev_scopes; + + std::vector ths; + + for (size_t i = 0; i < gpu_list.size(); ++i) { + dev_scopes.emplace_back(&g_scope.NewScope()); + std::thread th(&NCCLTester::PerThreadProgram, this, gpu_list[i], + *op2.get(), dev_scopes[i]); + ths.emplace_back(std::move(th)); + } + + for (size_t i = 0; i < gpu_list.size(); ++i) { + ths[i].join(); + } + + // check results on + float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0); + + p::CPUPlace cpu_place; + p::GPUPlace gpu_place(gpu_list[kRoot]); + + auto &recv_tensor = dev_scopes[kRoot]->FindVar("rt")->Get(); + auto *rt = recv_tensor.data(); + auto *result_tensor = + dev_scopes[kRoot]->Var("ct")->GetMutable(); + result_tensor->Resize(kDims); + auto *ct = result_tensor->mutable_data(cpu_place); + + paddle::memory::Copy( + cpu_place, ct, p::GPUPlace(gpu_list[kRoot]), rt, + recv_tensor.numel() * sizeof(float), + static_cast(dev_ctxs[kRoot])->stream()); + + for (int j = 0; j < f::product(kDims); ++j) { + ASSERT_NEAR(ct[j], result, 1e-5); + } +} + +// ncclBcastOp with desc +TEST_F(NCCLTester, ncclBcastOp) { + std::unique_ptr op2(new f::OpDescBind); + const int kRoot = 5; + op2->SetType("ncclBcast"); + op2->SetInput("X", {"st"}); + op2->SetInput("Communicator", {"comm"}); + op2->SetOutput("Out", {"rt"}); + op2->SetAttr("root", kRoot); + + std::vector dev_scopes; + + std::vector ths; + + for (size_t i = 0; i < gpu_list.size(); ++i) { + dev_scopes.emplace_back(&g_scope.NewScope()); + std::thread th(&NCCLTester::PerThreadProgram, this, gpu_list[i], + *op2.get(), dev_scopes[i]); + ths.emplace_back(std::move(th)); + } + + for (size_t i = 0; i < gpu_list.size(); ++i) { + ths[i].join(); + } + + const int idx = 1; + // check results on + float result = kRoot; + + p::CPUPlace cpu_place; + p::GPUPlace gpu_place(gpu_list[idx]); + + auto &recv_tensor = dev_scopes[idx]->FindVar("rt")->Get(); + auto *rt = recv_tensor.data(); + auto *result_tensor = dev_scopes[idx]->Var("ct")->GetMutable(); + result_tensor->Resize(kDims); + auto *ct = result_tensor->mutable_data(cpu_place); + + paddle::memory::Copy( + cpu_place, ct, p::GPUPlace(gpu_list[idx]), rt, + recv_tensor.numel() * sizeof(float), + static_cast(dev_ctxs[idx])->stream()); + + for (size_t j = 0; j < f::product(kDims); ++j) { + ASSERT_NEAR(ct[j], result, 1e-5); + } +} + +int main(int argc, char **argv) { + const int dev_count = p::GetCUDADeviceCount(); + if (dev_count <= 1) { + LOG(WARNING) + << "Cannot test multi-gpu nccl, because the CUDA device count is " + << dev_count; + return 0; + } + + for (int i = 0; i < dev_count; ++i) { + gpu_list.emplace_back(i); + } + testing::InitGoogleTest(&argc, argv); + + // device context should be release before scope. + // otherwise driver will down. + return RUN_ALL_TESTS(); +} diff --git a/paddle/platform/nccl_test.cu b/paddle/platform/nccl_test.cu index ab8b96f7263aed83407866fedf9e529ce0affe3f..c99dae68bef67c58d3efea42fef45e84bb3d9255 100644 --- a/paddle/platform/nccl_test.cu +++ b/paddle/platform/nccl_test.cu @@ -31,9 +31,7 @@ namespace platform { TEST(NCCL, init) { std::vector comms; comms.resize(dev_count); - - auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr); - PADDLE_ENFORCE(status); + PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr)); for (int i = 0; i < dev_count; ++i) { dynload::ncclCommDestroy(comms[i]); } @@ -64,8 +62,7 @@ TEST(NCCL, all_reduce) { std::vector comms; comms.resize(dev_count); VLOG(1) << "Initializing ncclComm"; - auto status = dynload::ncclCommInitAll(comms.data(), dev_count, nullptr); - PADDLE_ENFORCE(status); + PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr)); VLOG(1) << "ncclComm initialized"; VLOG(1) << "Creating thread data"; std::vector>> data; diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index e9c1d40de185768048159c10ee278ffde12335f7..bf6e12264269c7603484e0acf502adab25645856 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -33,6 +33,11 @@ limitations under the License. */ #include "paddle/pybind/tensor_py.h" #include "paddle/string/to_string.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/operators/nccl/nccl_gpu_common.h" +#include "paddle/platform/gpu_info.h" +#endif + namespace paddle { namespace pybind { static size_t UniqueIntegerGenerator() { @@ -204,6 +209,13 @@ All parameter, weight, gradient are variables in Paddle. return self.GetMutable(); }, py::return_value_policy::reference) +#ifdef PADDLE_WITH_CUDA + .def("get_communicator", + [](Variable &self) -> platform::Communicator * { + return self.GetMutable(); + }, + py::return_value_policy::reference) +#endif .def("get_net", [](Variable &self) -> operators::NetOp * { return self.GetMutable(); @@ -269,8 +281,11 @@ All parameter, weight, gradient are variables in Paddle. return new paddle::platform::CUDADeviceContext(place); #endif }); - // clang-format on +// clang-format on +#ifdef PADDLE_WITH_CUDA + py::class_(m, "Communicator").def(py::init<>()); +#endif py::class_(m, "GPUPlace") .def(py::init()) .def("__str__", string::to_string); @@ -479,6 +494,9 @@ All parameter, weight, gradient are variables in Paddle. BindOpDesc(m); m.def("op_support_gpu", OpSupportGPU); +#ifdef PADDLE_WITH_CUDA + m.def("get_cuda_device_count", platform::GetCUDADeviceCount); +#endif return m.ptr(); } diff --git a/python/paddle/v2/framework/tests/test_nccl_init_op.py b/python/paddle/v2/framework/tests/test_nccl_init_op.py new file mode 100644 index 0000000000000000000000000000000000000000..054909fdf5517a68c6a07971c65a1d5bdc20d4fa --- /dev/null +++ b/python/paddle/v2/framework/tests/test_nccl_init_op.py @@ -0,0 +1,39 @@ +import unittest, os +import numpy as np +import paddle.v2 as paddle +from paddle.v2.framework.op import Operator +import paddle.v2.framework.core as core +from op_test import OpTest, create_op, set_input + +if not core.is_compile_gpu(): + exit(0) + +gpu_count = core.get_cuda_device_count() + +if gpu_count <= 1: + exit(0) + +g_scope = core.Scope() +g_ctx = core.DeviceContext.create(core.CPUPlace()) + + +class TestNCCLInit(unittest.TestCase): + def test_init(self): + self.op_type = "ncclInit" + self.gpus = range(gpu_count) + + self.inputs = {} + self.attrs = {"gpus": self.gpus} + g_scope.var("Communicator").get_communicator() + self.outputs = {"Communicator": g_scope.find_var("Communicator")} + nccl_init = create_op( + g_scope, + op_type=self.op_type, + inputs=self.inputs, + outputs=self.outputs, + attrs=self.attrs) + nccl_init.run(g_scope, g_ctx) + + +if __name__ == "__main__": + unittest.main()