From 34d13d6abb4ddf9ef407a61f3efc1894989d8bfa Mon Sep 17 00:00:00 2001 From: ronnywang Date: Mon, 31 Oct 2022 14:06:28 +0800 Subject: [PATCH] [CustomDevice] GetCCLComm add custom device support (#47168) * [CustomDevice] GetCCLComm add custom device support * update * update * update --- .../distributed/collective/CMakeLists.txt | 9 +-- .../collective/ProcessGroupCustom.cc | 18 +++-- .../collective/ProcessGroupCustom.h | 2 + paddle/phi/backends/CMakeLists.txt | 12 ++++ .../phi/backends/processgroup_comm_utils.cc | 65 +++++++++++++++++++ paddle/phi/kernels/CMakeLists.txt | 1 + .../phi/kernels/gpu/sync_batch_norm_kernel.cu | 20 ------ 7 files changed, 96 insertions(+), 31 deletions(-) create mode 100644 paddle/phi/backends/processgroup_comm_utils.cc diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt index 7f6a5e262b..aa816f26f9 100644 --- a/paddle/fluid/distributed/collective/CMakeLists.txt +++ b/paddle/fluid/distributed/collective/CMakeLists.txt @@ -86,11 +86,6 @@ if(WITH_CUSTOM_DEVICE) cc_library( processgroup_custom SRCS ProcessGroupCustom.cc CustomCCLTools.cc Common.cc - DEPS phi_backends - place - enforce - collective_helper - device_context - phi_api - eager_api) + DEPS processgroup phi_backends place enforce collective_helper + device_context) endif() diff --git a/paddle/fluid/distributed/collective/ProcessGroupCustom.cc b/paddle/fluid/distributed/collective/ProcessGroupCustom.cc index f18765a05f..87bd474477 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupCustom.cc +++ b/paddle/fluid/distributed/collective/ProcessGroupCustom.cc @@ -19,7 +19,6 @@ #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/place.h" -#include "paddle/phi/api/include/api.h" #include "paddle/phi/common/place.h" DECLARE_bool(xccl_blocking_wait); @@ -386,9 +385,10 @@ std::shared_ptr ProcessGroupCustom::Barrier( for (auto& place : places) { 
phi::DeviceGuard guard(place); - auto dt = full({1}, 0, phi::DataType::FLOAT32, place); - barrierTensors.push_back( - *std::dynamic_pointer_cast(dt.impl())); + phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1})); + auto allocator = std::unique_ptr( + new paddle::experimental::DefaultAllocator(place)); + barrierTensors.emplace_back(allocator.get(), meta); } auto task = ProcessGroupCustom::AllReduce(barrierTensors, barrierTensors); auto xccl_task = dynamic_cast(task.get()); @@ -396,5 +396,15 @@ std::shared_ptr ProcessGroupCustom::Barrier( return task; } +phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const { + std::vector places = {place}; + const auto& iter = places_to_customcomm_.find(GetKeyFromPlaces(places)); + PADDLE_ENFORCE_NE(iter, + places_to_customcomm_.end(), + platform::errors::InvalidArgument( + "Cannot find custom ccl comm in process group.")); + return iter->second[0]->GetCustomCCLComm(); +} + } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/collective/ProcessGroupCustom.h b/paddle/fluid/distributed/collective/ProcessGroupCustom.h index ce3532bbb6..38a794a0e7 100644 --- a/paddle/fluid/distributed/collective/ProcessGroupCustom.h +++ b/paddle/fluid/distributed/collective/ProcessGroupCustom.h @@ -96,6 +96,8 @@ class ProcessGroupCustom : public ProcessGroup { std::shared_ptr Barrier( const BarrierOptions& = BarrierOptions()) override; + phi::ccl::CCLComm CustomCCLComm(const Place& place) const; + protected: virtual std::shared_ptr CreateTask( std::vector places, diff --git a/paddle/phi/backends/CMakeLists.txt b/paddle/phi/backends/CMakeLists.txt index 9bc9573529..b2095f7983 100644 --- a/paddle/phi/backends/CMakeLists.txt +++ b/paddle/phi/backends/CMakeLists.txt @@ -58,3 +58,15 @@ if(WITH_CUSTOM_DEVICE) SRCS custom/capi_test.cc DEPS phi_capi) endif() + +set(COMM_UTILS_DEPS processgroup) +if(WITH_NCCL OR WITH_RCCL) + set(COMM_UTILS_DEPS ${COMM_UTILS_DEPS} processgroup_nccl) 
+endif() +if(WITH_CUSTOM_DEVICE) + set(COMM_UTILS_DEPS ${COMM_UTILS_DEPS} processgroup_custom) +endif() +cc_library( + processgroup_comm_utils + SRCS processgroup_comm_utils.cc + DEPS ${COMM_UTILS_DEPS}) diff --git a/paddle/phi/backends/processgroup_comm_utils.cc b/paddle/phi/backends/processgroup_comm_utils.cc new file mode 100644 index 0000000000..580aebd17e --- /dev/null +++ b/paddle/phi/backends/processgroup_comm_utils.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/collective/ProcessGroup.h" +#include "paddle/phi/backends/c_comm_lib.h" +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h" +#endif +#if defined(PADDLE_WITH_CUSTOM_DEVICE) +#include "paddle/fluid/distributed/collective/ProcessGroupCustom.h" +#endif + +namespace phi { +namespace detail { + +// FIXME(paddle-dev): Since the singleton of ProcessGroup in fluid is used in +// SyncBN, the fluid symbol will be dependent on external hardware access. +// Here, the part that depends on the fluid symbol is individually encapsulated +// as a temporary function to isolate external symbol dependencies. +// In the future, the dependence on the singleton in fluid in SyncBN needs +// to be removed. 
+// In principle, the PHI Kernel cannot use the global singleton internally, +// and the required members need to be passed in from the outside. +ccl::CCLComm GetCCLComm(const Place& place, int global_gid) { + paddle::distributed::ProcessGroup* pg = nullptr; + if (paddle::distributed::ProcessGroupMapFromGid::getInstance()->has( + global_gid)) { + pg = paddle::distributed::ProcessGroupMapFromGid::getInstance()->get( + global_gid); + } else { + return nullptr; + } + + if (paddle::platform::is_gpu_place(place)) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + return static_cast(pg)->NCCLComm( + place); +#else + return nullptr; +#endif + } else if (paddle::platform::is_custom_place(place)) { +#if defined(PADDLE_WITH_CUSTOM_DEVICE) + return static_cast(pg) + ->CustomCCLComm(place); +#else + return nullptr; +#endif + } else { + return nullptr; + } +} + +} // namespace detail +} // namespace phi diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 7cbd218543..8e45da27a8 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -83,6 +83,7 @@ set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup) if(WITH_NCCL OR WITH_RCCL) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup_nccl) endif() +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup_comm_utils) copy_if_different(${kernel_declare_file} ${kernel_declare_file_final}) diff --git a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu index 106b3d6642..d41f50677f 100644 --- a/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu @@ -18,26 +18,6 @@ #include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h" namespace phi { -namespace detail { - -ccl::CCLComm GetCCLComm(const Place &place, int global_gid) { -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - ncclComm_t comm = nullptr; - - if 
(paddle::distributed::ProcessGroupMapFromGid::getInstance()->has( - global_gid)) { - auto *nccl_pg = static_cast( - paddle::distributed::ProcessGroupMapFromGid::getInstance()->get( - global_gid)); - comm = nccl_pg->NCCLComm(place); - } - return comm; -#else - return nullptr; -#endif -} - -} // namespace detail template void SyncBatchNormKernel(const Context &ctx, -- GitLab