Unverified commit 34d13d6a, authored by ronnywang, committed by GitHub

[CustomDevice] GetCCLComm add custom device support (#47168)

* [CustomDevice] GetCCLComm add custom device support

* update

* update

* update
Parent 520adc0e
paddle/fluid/distributed/collective/CMakeLists.txt
@@ -86,11 +86,6 @@ if(WITH_CUSTOM_DEVICE)
   cc_library(
     processgroup_custom
    SRCS ProcessGroupCustom.cc CustomCCLTools.cc Common.cc
-    DEPS phi_backends
-         place
-         enforce
-         collective_helper
-         device_context
-         phi_api
-         eager_api)
+    DEPS processgroup phi_backends place enforce collective_helper
+         device_context)
 endif()
paddle/fluid/distributed/collective/ProcessGroupCustom.cc
@@ -19,7 +19,6 @@
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/place.h"
-#include "paddle/phi/api/include/api.h"
 #include "paddle/phi/common/place.h"
 
 DECLARE_bool(xccl_blocking_wait);
@@ -386,9 +385,10 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
   for (auto& place : places) {
     phi::DeviceGuard guard(place);
-    auto dt = full({1}, 0, phi::DataType::FLOAT32, place);
-    barrierTensors.push_back(
-        *std::dynamic_pointer_cast<phi::DenseTensor>(dt.impl()));
+    phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1}));
+    auto allocator = std::unique_ptr<phi::Allocator>(
+        new paddle::experimental::DefaultAllocator(place));
+    barrierTensors.emplace_back(allocator.get(), meta);
   }
 
   auto task = ProcessGroupCustom::AllReduce(barrierTensors, barrierTensors);
   auto xccl_task = dynamic_cast<ProcessGroupCustom::CustomTask*>(task.get());
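
The Barrier hunk above swaps the high-level paddle::experimental::full call for a DenseTensor built directly from a DenseTensorMeta and a DefaultAllocator, which is what allows the CMake hunk earlier in this commit to drop the phi_api and eager_api dependencies. A minimal standalone sketch of the same pattern; the function name and the include paths are assumptions for illustration, not part of the diff:

// Sketch: build the same one-element FLOAT32 tensor on `place` that the
// Barrier hunk builds, without going through paddle::experimental::full.
#include <memory>

#include "paddle/phi/api/lib/utils/allocator.h"  // DefaultAllocator (assumed path)
#include "paddle/phi/core/dense_tensor.h"

phi::DenseTensor MakeBarrierTensor(const phi::Place& place) {
  // Describe the tensor: dtype and shape only, no storage yet.
  phi::DenseTensorMeta meta(phi::DataType::FLOAT32, phi::DDim({1}));
  // Place-aware allocator wrapper; the (Allocator*, meta) constructor
  // allocates the storage immediately, so the wrapper is only needed at
  // construction time -- the diff uses the same construct-then-drop pattern.
  auto allocator =
      std::make_unique<paddle::experimental::DefaultAllocator>(place);
  return phi::DenseTensor(allocator.get(), meta);
}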
@@ -396,5 +396,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Barrier(
   return task;
 }
 
+phi::ccl::CCLComm ProcessGroupCustom::CustomCCLComm(const Place& place) const {
+  std::vector<Place> places = {place};
+  const auto& iter = places_to_customcomm_.find(GetKeyFromPlaces(places));
+  PADDLE_ENFORCE_NE(iter,
+                    places_to_customcomm_.end(),
+                    platform::errors::InvalidArgument(
+                        "Cannot find custom ccl comm in process group."));
+  return iter->second[0]->GetCustomCCLComm();
+}
+
 }  // namespace distributed
 }  // namespace paddle
paddle/fluid/distributed/collective/ProcessGroupCustom.h
@@ -96,6 +96,8 @@ class ProcessGroupCustom : public ProcessGroup {
   std::shared_ptr<ProcessGroup::Task> Barrier(
       const BarrierOptions& = BarrierOptions()) override;
 
+  phi::ccl::CCLComm CustomCCLComm(const Place& place) const;
+
  protected:
   virtual std::shared_ptr<ProcessGroupCustom::CustomTask> CreateTask(
       std::vector<Place> places,
...
paddle/phi/backends/CMakeLists.txt
@@ -58,3 +58,15 @@ if(WITH_CUSTOM_DEVICE)
     SRCS custom/capi_test.cc
     DEPS phi_capi)
 endif()
+
+set(COMM_UTILS_DEPS processgroup)
+if(WITH_NCCL OR WITH_RCCL)
+  set(COMM_UTILS_DEPS ${COMM_UTILS_DEPS} processgroup_nccl)
+endif()
+if(WITH_CUSTOM_DEVICE)
+  set(COMM_UTILS_DEPS ${COMM_UTILS_DEPS} processgroup_custom)
+endif()
+cc_library(
+  processgroup_comm_utils
+  SRCS processgroup_comm_utils.cc
+  DEPS ${COMM_UTILS_DEPS})
paddle/phi/backends/processgroup_comm_utils.cc (new file)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/phi/backends/c_comm_lib.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#endif
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
#include "paddle/fluid/distributed/collective/ProcessGroupCustom.h"
#endif
namespace phi {
namespace detail {
// FIXME(paddle-dev): Since the ProcessGroup singleton in fluid is used in
// SyncBN, the fluid symbols would otherwise have to depend on external
// hardware access. Here, the part that depends on the fluid symbols is
// encapsulated in a standalone function to isolate the external symbol
// dependencies. In the future, SyncBN's dependence on the fluid singleton
// needs to be removed. In principle, a PHI kernel must not use a global
// singleton internally; the required members need to be passed in from
// the outside.
ccl::CCLComm GetCCLComm(const Place& place, int global_gid) {
paddle::distributed::ProcessGroup* pg = nullptr;
if (paddle::distributed::ProcessGroupMapFromGid::getInstance()->has(
global_gid)) {
pg = paddle::distributed::ProcessGroupMapFromGid::getInstance()->get(
global_gid);
} else {
return nullptr;
}
if (paddle::platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
return static_cast<paddle::distributed::ProcessGroupNCCL*>(pg)->NCCLComm(
place);
#else
return nullptr;
#endif
} else if (paddle::platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
return static_cast<paddle::distributed::ProcessGroupCustom*>(pg)
->CustomCCLComm(place);
#else
return nullptr;
#endif
} else {
return nullptr;
}
}
}  // namespace detail
}  // namespace phi
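
Because GetCCLComm returns nullptr whenever no process group is registered under the gid or the backend for the place was compiled out, callers are expected to branch on the result. A hypothetical caller sketch; LaunchSyncCollective and the fallback policy are illustrative only, not part of this commit:

// Hypothetical caller: fetch the raw communicator through the new helper
// without linking against any backend-specific process group.
#include "paddle/phi/backends/c_comm_lib.h"
#include "paddle/phi/common/place.h"

namespace phi {
namespace detail {
// Declaration matching the definition above (normally provided by a header).
ccl::CCLComm GetCCLComm(const Place& place, int global_gid);
}  // namespace detail
}  // namespace phi

void LaunchSyncCollective(const phi::Place& place, int global_gid) {
  phi::ccl::CCLComm comm = phi::detail::GetCCLComm(place, global_gid);
  if (comm == nullptr) {
    // No process group registered under this gid, or the backend for this
    // place was compiled out: take a single-process fallback path.
    return;
  }
  // A backend-specific launch would cast `comm` here, e.g. to ncclComm_t on
  // GPU builds, mirroring the code removed from the SyncBN kernel below.
}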
paddle/phi/kernels/CMakeLists.txt
@@ -83,6 +83,7 @@ set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup)
 if(WITH_NCCL OR WITH_RCCL)
   set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup_nccl)
 endif()
+set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} processgroup_comm_utils)
 
 copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
...
paddle/phi/kernels/gpu/sync_batch_norm_kernel.cu
@@ -18,26 +18,6 @@
 #include "paddle/phi/kernels/gpu/sync_batch_norm_utils.h"
 
 namespace phi {
 
-namespace detail {
-
-ccl::CCLComm GetCCLComm(const Place &place, int global_gid) {
-#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
-  ncclComm_t comm = nullptr;
-  if (paddle::distributed::ProcessGroupMapFromGid::getInstance()->has(
-          global_gid)) {
-    auto *nccl_pg = static_cast<paddle::distributed::ProcessGroupNCCL *>(
-        paddle::distributed::ProcessGroupMapFromGid::getInstance()->get(
-            global_gid));
-    comm = nccl_pg->NCCLComm(place);
-  }
-  return comm;
-#else
-  return nullptr;
-#endif
-}
-
-}  // namespace detail
-
 template <typename T, typename Context>
 void SyncBatchNormKernel(const Context &ctx,
...