From e04f8d4a6de786ea010d6cc1d6273c6c50cbf237 Mon Sep 17 00:00:00 2001
From: ronnywang
Date: Mon, 15 May 2023 19:23:48 +0800
Subject: [PATCH] [CustomDevice] add inference MP support, PART2 (#53701)

---
 .../operators/collective/c_comm_init_op.cc    |  80 +++---
 .../operators/collective/c_gen_xccl_id_op.cc  | 132 ++++++++++
 .../custom_device_common_op_registry.cc       | 227 ++++++++++++++++++
 paddle/fluid/pybind/CMakeLists.txt            |  11 +
 paddle/fluid/pybind/imperative.cc             |  21 +-
 paddle/fluid/pybind/pybind.cc                 |   2 +
 6 files changed, 442 insertions(+), 31 deletions(-)
 create mode 100644 paddle/fluid/operators/collective/c_gen_xccl_id_op.cc

diff --git a/paddle/fluid/operators/collective/c_comm_init_op.cc b/paddle/fluid/operators/collective/c_comm_init_op.cc
index b32857a27b2..5a22ad716e1 100644
--- a/paddle/fluid/operators/collective/c_comm_init_op.cc
+++ b/paddle/fluid/operators/collective/c_comm_init_op.cc
@@ -25,7 +25,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL)
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_CUSTOM_DEVICE)
 #include "paddle/fluid/platform/collective_helper.h"
 #endif
 
@@ -48,43 +48,67 @@ class CCommInitOp : public framework::OperatorBase {
 
   void RunImpl(const framework::Scope& scope,
                const platform::Place& place) const override {
+    if (platform::is_custom_place(place)) {
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+      auto var = scope.FindVar(Input("X"));
+      PADDLE_ENFORCE_NOT_NULL(
+          var, platform::errors::InvalidArgument("Input can not be empty."));
+
+      phi::ccl::CCLRootId* comm_id = var->GetMutable<phi::ccl::CCLRootId>();
+
+      int nranks = Attr<int>("nranks");
+      int rid = Attr<int>("ring_id");
+
+      int device_id = place.device;
+      if (Attr<int>("device_id") >= 0) {
+        device_id = Attr<int>("device_id");
+      }
+      int rank_id = Attr<int>("rank");
+      platform::XCCLCommContext::Instance(place.GetDeviceType())
+          .CreateComm(comm_id, nranks, rank_id, device_id, rid);
+#else
+      PADDLE_THROW(platform::errors::PreconditionNotMet(
+          "PaddlePaddle should be compiled with custom device."));
+#endif
+    } else {
 // TODO(wangxi): Put this in the unified header file
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-    using UniqueId = ncclUniqueId;
-    using CommContext = platform::NCCLCommContext;
+      using UniqueId = ncclUniqueId;
+      using CommContext = platform::NCCLCommContext;
 #elif defined(PADDLE_WITH_XPU_BKCL)
-    using UniqueId = BKCLUniqueId;
-    using CommContext = platform::BKCLCommContext;
+      using UniqueId = BKCLUniqueId;
+      using CommContext = platform::BKCLCommContext;
 #else
-    PADDLE_THROW(platform::errors::PreconditionNotMet(
-        "PaddlePaddle should be compiled with GPU or XPU."));
+      PADDLE_THROW(platform::errors::PreconditionNotMet(
+          "PaddlePaddle should be compiled with GPU or XPU."));
 #endif
 
-    PADDLE_ENFORCE_EQ(
-        platform::is_gpu_place(place) || platform::is_xpu_place(place),
-        true,
-        platform::errors::PreconditionNotMet(
-            "CCommInitOp can run on gpu or xpu place only."));
+      PADDLE_ENFORCE_EQ(
+          platform::is_gpu_place(place) || platform::is_xpu_place(place),
+          true,
+          platform::errors::PreconditionNotMet(
+              "CCommInitOp can run on gpu or xpu place only."));
 
 #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
     defined(PADDLE_WITH_XPU_BKCL)
-    auto var = scope.FindVar(Input("X"));
-    PADDLE_ENFORCE_NOT_NULL(
-        var, platform::errors::InvalidArgument("Input con not be empty."));
-
-    UniqueId* comm_id = var->GetMutable<UniqueId>();
-
-    int nranks = Attr<int>("nranks");
-    int rid = Attr<int>("ring_id");
-
-    int device_id = place.device;
-    if (Attr<int>("device_id") >= 0) {
-      device_id = Attr<int>("device_id");
-    }
-    int rank_id = Attr<int>("rank");
-    CommContext::Instance().CreateComm(
-        comm_id, nranks, rank_id, device_id, rid);
+      auto var = scope.FindVar(Input("X"));
+      PADDLE_ENFORCE_NOT_NULL(
+          var, platform::errors::InvalidArgument("Input can not be empty."));
+
+      UniqueId* comm_id = var->GetMutable<UniqueId>();
+
+      int nranks = Attr<int>("nranks");
+      int rid = Attr<int>("ring_id");
+
+      int device_id = place.device;
+      if (Attr<int>("device_id") >= 0) {
+        device_id = Attr<int>("device_id");
+      }
+      int rank_id = Attr<int>("rank");
+      CommContext::Instance().CreateComm(
+          comm_id, nranks, rank_id, device_id, rid);
 #endif
+    }
   }
 };
diff --git a/paddle/fluid/operators/collective/c_gen_xccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_xccl_id_op.cc
new file mode 100644
index 00000000000..effe7021b0d
--- /dev/null
+++ b/paddle/fluid/operators/collective/c_gen_xccl_id_op.cc
@@ -0,0 +1,132 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include <string>
+
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/var_type_traits.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/gen_comm_id_helper.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/phi/backends/device_manager.h"
+
+namespace paddle {
+namespace operators {
+
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+static void CopyXCCLIDToVar(const std::vector<phi::ccl::CCLRootId>& xccl_ids,
+                            std::function<std::string(size_t)> func,
+                            const framework::Scope& scope) {
+  for (size_t i = 0; i < xccl_ids.size(); ++i) {
+    std::string var_name = func(i);
+    auto var = scope.FindVar(var_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        var,
+        platform::errors::NotFound("Variable with name %s is not found",
+                                   var_name.c_str()));
+    auto xccl_id = var->GetMutable<phi::ccl::CCLRootId>();
+    *xccl_id = xccl_ids[i];
+  }
+}
+
+class CGenXCCLIdOp : public framework::OperatorBase {
+ public:
+  CGenXCCLIdOp(const std::string& type,
+               const framework::VariableNameMap& inputs,
+               const framework::VariableNameMap& outputs,
+               const framework::AttributeMap& attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {
+    int rank = Attr<int>("rank");
+    int ring_id = Attr<int>("ring_id");
+
+    std::function<std::string(size_t)> func = [&](size_t i) -> std::string {
+      return Output("Out");
+    };
+
+    std::string endpoint = Attr<std::string>("endpoint");
+    int server_fd = platform::SocketServer::GetInstance(endpoint).socket();
+
+    std::vector<phi::ccl::CCLRootId> xccl_ids;
+    xccl_ids.resize(1);
+
+    if (rank == 0) {
+      for (size_t i = 0; i < xccl_ids.size(); ++i) {
+        phi::DeviceManager::CCLGetUniqueId(dev_place.GetDeviceType(),
+                                           &xccl_ids[i]);
+      }
+      std::vector<std::string> endpoint_list =
+          Attr<std::vector<std::string>>("other_endpoints");
+      platform::SendBroadCastCommID(endpoint_list, &xccl_ids, ring_id);
+    } else {
+      platform::RecvBroadCastCommID(server_fd, endpoint, &xccl_ids, ring_id);
+    }
+
+    CopyXCCLIDToVar(xccl_ids, func, scope);
+  }
+};
+
+#else
+class CGenXCCLIdOp : public framework::OperatorBase {
+ public:
+  CGenXCCLIdOp(const std::string& type,
+               const framework::VariableNameMap& inputs,
+               const framework::VariableNameMap& outputs,
+               const framework::AttributeMap& attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  void RunImpl(const framework::Scope& scope,
+               const platform::Place& dev_place) const override {}
+};
+
+#endif
+
+class CGenXCCLIdOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddOutput("Out", "Raw variable containing an XCCL UniqueId instance.");
+    AddComment(R"DOC(
+CGenXCCLId operator
+
+For trainer 0: generate a new UniqueId and send it to all the other trainers.
+For trainer 1~n: start a gRPC server to get the UniqueId; once it is received, stop the server.
+)DOC");
+    AddAttr<std::string>("endpoint",
+                         "(string), e.g. 127.0.0.1:6175 "
+                         "current listen endpoint");
+    AddAttr<std::vector<std::string>>(
+        "other_endpoints",
+        "['trainer1_ip:port', 'trainer2_ip:port', ...] "
+        "list of other trainer endpoints")
+        .SetDefault({});
+    AddAttr<int>("rank",
+                 "(int default 0) "
+                 "The rank of the trainer in distributed training.")
+        .SetDefault(0);
+    AddAttr<int>("ring_id", "(int default 0) user specified ring id")
+        .SetDefault(0);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(c_gen_xccl_id, ops::CGenXCCLIdOp, ops::CGenXCCLIdOpMaker);
diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc
index a95fc0cd320..f4d9aff0374 100644
--- a/paddle/fluid/operators/custom_device_common_op_registry.cc
+++ b/paddle/fluid/operators/custom_device_common_op_registry.cc
@@ -19,6 +19,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/load_combine_op.h"
 #include "paddle/fluid/operators/run_program_op.h"
 #include "paddle/fluid/operators/save_combine_op.h"
+#include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/phi/api/backward/backward_api.h"
 #include "paddle/phi/api/include/api.h"
 #include "paddle/phi/backends/device_manager.h"
@@ -497,6 +498,131 @@ class CSyncCalcStreamCustomDeviceKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T, phi::ccl::CCLReduceOp red_type>
+class CAllReduceOpCustomDeviceKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto in = ctx.Input<phi::DenseTensor>("X");
+    auto out = ctx.Output<phi::DenseTensor>("Out");
+    int rid = ctx.Attr<int>("ring_id");
+
+    auto place = ctx.GetPlace();
+    auto dtype = phi::ccl::ToCCLDataType(in->dtype());
+    int64_t numel = in->numel();
+    const void* sendbuff = in->data();
+    out->Resize(in->dims());
+    void* recvbuff = ctx.device_context().Alloc<T>(out);
+
+    auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();
+    if (map->has(rid)) {
+      // Use ProcessGroup
+      paddle::distributed::ProcessGroup* pg = map->get(rid);
+      std::vector<phi::DenseTensor> in_tensor;
+      std::vector<phi::DenseTensor> out_tensor;
+      in_tensor.push_back(*in);
+      out_tensor.push_back(*out);
+
+      paddle::distributed::AllreduceOptions opts;
+      switch (red_type) {
+        case phi::ccl::CCLReduceOp::SUM:
+          opts.reduce_op = paddle::distributed::ReduceOp::SUM;
+          break;
+
+        case phi::ccl::CCLReduceOp::MAX:
+          opts.reduce_op = paddle::distributed::ReduceOp::MAX;
+          break;
+
+        case phi::ccl::CCLReduceOp::MIN:
+          opts.reduce_op = paddle::distributed::ReduceOp::MIN;
+          break;
+
+        case phi::ccl::CCLReduceOp::PRODUCT:
+          opts.reduce_op = paddle::distributed::ReduceOp::PRODUCT;
+          break;
+
+        default:
+          PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+              "Invalid reduce type: %d", red_type));
+      }
+
+      auto task = pg->AllReduce(in_tensor, out_tensor, opts);
+      task->Wait();
+      return;
+    }
+
+    auto comm =
+        paddle::platform::XCCLCommContext::Instance(place.GetDeviceType())
+            .Get(rid, place);
+
+    std::shared_ptr<phi::stream::Stream> stream;
+    if (ctx.Attr<bool>("use_calc_stream")) {
+      auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
+      stream = static_cast<paddle::platform::CustomDeviceContext*>(dev_ctx)
+                   ->GetStream();
+    } else {
+      stream = comm->stream();
+    }
+    phi::DeviceManager::CCLAllReduce(place.GetDeviceType(),
+                                     const_cast<void*>(sendbuff),
+                                     recvbuff,
+                                     numel,
+                                     dtype,
+                                     red_type,
+                                     comm->comm(),
+                                     *stream);
+  }
+};
+
+template <typename T>
+class CBroadcastOpCustomDeviceKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto x = ctx.Input<phi::DenseTensor>("X");
+    auto out = ctx.Output<phi::DenseTensor>("Out");
+    const auto& place = ctx.GetPlace();
+    ctx.device_context().Alloc<T>(out);
+    int root = ctx.Attr<int>("root");
+    int rid = ctx.Attr<int>("ring_id");
+
+    auto stream =
+        static_cast<const platform::CustomDeviceContext&>(ctx.device_context())
+            .GetStream();
+
+    int numel = x->numel();
+    auto dtype = phi::ccl::ToCCLDataType(x->dtype());
+    auto comm = platform::XCCLCommContext::Instance(place.GetDeviceType())
+                    .Get(rid, place);
+    if (root == comm->rank()) {
+      phi::DeviceManager::CCLBroadcast(place.GetDeviceType(),
+                                       const_cast<void*>(x->data()),
+                                       numel,
+                                       dtype,
+                                       root,
+                                       comm->comm(),
+                                       *stream);
+      VLOG(3) << "rank " << comm->rank() << " invoke Bcast. sent "
+              << x->numel();
sent " + << x->numel(); + if (out != x) { + framework::TensorCopy( + *static_cast(x), + place, + *platform::DeviceContextPool::Instance().Get(place), + static_cast(out)); + } + } else { + phi::DeviceManager::CCLBroadcast(place.GetDeviceType(), + out->data(), + numel, + dtype, + root, + comm->comm(), + *stream); + VLOG(3) << "rank " << comm->rank() << " invoke Bcast. received " + << phi::product(out->dims()); + } + out->set_lod(x->lod()); + } +}; + template void FeedDenseTensorKernel(const Context& dev_ctx, const phi::ExtendedTensor& x, @@ -636,6 +762,107 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { paddle::operators::CSyncCalcStreamCustomDeviceKernel< paddle::platform::CustomDeviceContext, paddle::platform::float16>) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + c_allreduce_sum, + device_type, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + float, + phi::ccl::CCLReduceOp::SUM>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + double, + phi::ccl::CCLReduceOp::SUM>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + paddle::platform::float16, + phi::ccl::CCLReduceOp::SUM>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + int32_t, + phi::ccl::CCLReduceOp::SUM>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + int64_t, + phi::ccl::CCLReduceOp::SUM>) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + c_allreduce_min, + device_type, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + float, + phi::ccl::CCLReduceOp::MIN>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + double, + phi::ccl::CCLReduceOp::MIN>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + paddle::platform::float16, + phi::ccl::CCLReduceOp::MIN>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + int32_t, + phi::ccl::CCLReduceOp::MIN>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + int64_t, + phi::ccl::CCLReduceOp::MIN>) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + c_allreduce_max, + device_type, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + float, + phi::ccl::CCLReduceOp::MAX>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + double, + phi::ccl::CCLReduceOp::MAX>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + paddle::platform::float16, + phi::ccl::CCLReduceOp::MAX>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + int32_t, + phi::ccl::CCLReduceOp::MAX>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + int64_t, + phi::ccl::CCLReduceOp::MAX>) {} + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + c_allreduce_prod, + device_type, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + float, + phi::ccl::CCLReduceOp::PRODUCT>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + double, + phi::ccl::CCLReduceOp::PRODUCT>, + paddle::operators::CAllReduceOpCustomDeviceKernel< + paddle::platform::CustomDeviceContext, + paddle::platform::float16, + 
+          phi::ccl::CCLReduceOp::PRODUCT>,
+      paddle::operators::CAllReduceOpCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          int32_t,
+          phi::ccl::CCLReduceOp::PRODUCT>,
+      paddle::operators::CAllReduceOpCustomDeviceKernel<
+          paddle::platform::CustomDeviceContext,
+          int64_t,
+          phi::ccl::CCLReduceOp::PRODUCT>) {}
+  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
+      c_broadcast,
+      device_type,
+      paddle::operators::CBroadcastOpCustomDeviceKernel<int32_t>,
+      paddle::operators::CBroadcastOpCustomDeviceKernel<int64_t>,
+      paddle::operators::CBroadcastOpCustomDeviceKernel<float>,
+      paddle::operators::CBroadcastOpCustomDeviceKernel<double>,
+      paddle::operators::CBroadcastOpCustomDeviceKernel<
+          paddle::platform::float16>) {}
 #endif
 }
diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt
index 1faba336682..2375c01e6bd 100755
--- a/paddle/fluid/pybind/CMakeLists.txt
+++ b/paddle/fluid/pybind/CMakeLists.txt
@@ -91,6 +91,17 @@ if(WITH_XPU_BKCL)
   set(PYBIND_DEPS ${PYBIND_DEPS} heter_ccl_context)
 endif()
 
+if(WITH_CUSTOM_DEVICE)
+  set(PYBIND_DEPS ${PYBIND_DEPS} xccl_context)
+  if(NOT
+     (WITH_NCCL
+      OR WITH_RCCL
+      OR WITH_XPU_BKCL))
+    set(PYBIND_DEPS ${PYBIND_DEPS} reducer)
+    set(PYBIND_DEPS ${PYBIND_DEPS} heter_ccl_context)
+  endif()
+endif()
+
 if(NOT WIN32)
   set(PYBIND_DEPS ${PYBIND_DEPS} data_loader)
   if(WITH_NCCL OR WITH_RCCL)
diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc
index 65eac1e3dc6..5bbd66fd09c 100644
--- a/paddle/fluid/pybind/imperative.cc
+++ b/paddle/fluid/pybind/imperative.cc
@@ -52,6 +52,7 @@ limitations under the License. */
 #include "paddle/fluid/imperative/reducer.h"
 #include "paddle/fluid/imperative/tracer.h"
 #include "paddle/fluid/imperative/type_defs.h"
+#include "paddle/fluid/imperative/xccl_context.h"
 #include "paddle/fluid/memory/allocation/mmap_allocator.h"
 #include "paddle/fluid/operators/utils.h"
 #include "paddle/fluid/pybind/cuda_streams_py.h"
@@ -2476,8 +2477,9 @@ void BindImperative(py::module *m_ptr) {
       },
       py::call_guard<py::gil_scoped_release>());
 
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
-    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) ||     \
+    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO) || \
+    defined(PADDLE_WITH_CUSTOM_DEVICE)
   py::class_<imperative::ParallelContext,
              std::shared_ptr<imperative::ParallelContext>>(m,
                                                            "ParallelContext");
@@ -2517,6 +2519,19 @@ void BindImperative(py::module *m_ptr) {
            py::arg("ring_id"));
 #endif
 
+#if defined(PADDLE_WITH_CUSTOM_DEVICE)
+  py::class_<imperative::XCCLParallelContext,
+             imperative::ParallelContext,
+             std::shared_ptr<imperative::XCCLParallelContext>>(
+      m, "XCCLParallelContext")
+      .def(py::init<const imperative::ParallelStrategy &,
+                    const platform::CustomPlace &>())
+      .def("init", [](imperative::XCCLParallelContext &self) { self.Init(); })
+      .def("init_with_ring_id",
+           &imperative::XCCLParallelContext::InitWithRingID,
+           py::arg("ring_id"));
+#endif
+
 #if defined(PADDLE_WITH_XPU_BKCL)
   py::class_<imperative::BKCLParallelContext,
              imperative::ParallelContext,
             std::shared_ptr<imperative::BKCLParallelContext>>(
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index f4dfb133c1c..a9b1bf398f8 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -158,6 +158,7 @@ limitations under the License. */
 
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
 #include "paddle/fluid/operators/custom_device_common_op_registry.h"
+#include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/phi/capi/capi.h"
 #endif
 
@@ -990,6 +991,7 @@ PYBIND11_MODULE(libpaddle, m) {
       []() { phi::KernelFactory::Instance().kernels().clear(); });
   m.def("clear_device_manager", []() {
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
+    platform::XCCLCommContext::Release();
     phi::DeviceManager::Clear();
 #endif
   });
-- 
GitLab
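
Usage note (illustrative sketch, not part of the commit): with the new ops registered, a
startup program for a custom device can bootstrap the model-parallel communicator the
same way the NCCL path does, i.e. rank 0 runs c_gen_xccl_id to generate and broadcast
the unique id, and every rank then runs c_comm_init on that id. The sketch below
assumes the paddle.static Block API; the helper name build_xccl_comm_init, the variable
name xccl_id_0, and the endpoint values are invented for the example.

    import paddle
    from paddle.framework import core

    paddle.enable_static()

    def build_xccl_comm_init(block, rank, nranks, endpoint, other_endpoints,
                             ring_id=0):
        # RAW variable that will hold the XCCL unique id produced by c_gen_xccl_id.
        comm_id = block.create_var(
            name="xccl_id_0", persistable=True, type=core.VarDesc.VarType.RAW)
        # Rank 0 generates the id and broadcasts it; the other ranks receive it
        # from rank 0 over the listed endpoints.
        block.append_op(
            type="c_gen_xccl_id",
            outputs={"Out": comm_id},
            attrs={"rank": rank, "endpoint": endpoint,
                   "other_endpoints": other_endpoints, "ring_id": ring_id})
        # Every rank then creates the communicator for ring `ring_id`;
        # a negative device_id makes c_comm_init fall back to place.device.
        block.append_op(
            type="c_comm_init",
            inputs={"X": comm_id},
            attrs={"nranks": nranks, "rank": rank,
                   "ring_id": ring_id, "device_id": -1})

    startup = paddle.static.Program()
    build_xccl_comm_init(startup.global_block(), rank=0, nranks=2,
                         endpoint="127.0.0.1:6175",
                         other_endpoints=["127.0.0.1:6176"])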