From c24e7fe15a7b474ba29a9d46ddf224076c2de6c5 Mon Sep 17 00:00:00 2001 From: Wen Sun <35923278+HermitSun@users.noreply.github.com> Date: Thu, 12 Jan 2023 20:49:09 +0800 Subject: [PATCH] Migrate collective communication checks to PHI (#49754) * refactor: migrate comm checks * refactor: add check in comm context * feat: add gloo static check * refactor: add place param in static check --- .../distributed/collective/CMakeLists.txt | 6 +- paddle/fluid/distributed/collective/check.cc | 290 ---------------- paddle/fluid/distributed/collective/check.h | 115 ------- .../collective/process_group_nccl.cc | 322 +++++++++--------- .../collective/process_group_nccl.h | 11 +- paddle/fluid/distributed/collective/types.h | 2 + paddle/phi/core/distributed/CMakeLists.txt | 6 +- .../phi/core/distributed/check/CMakeLists.txt | 11 + .../distributed/check/nccl_dynamic_check.cc | 163 +++++++++ .../distributed/check/nccl_dynamic_check.h | 61 ++++ .../core/distributed/check/static_check.cc | 167 +++++++++ .../phi/core/distributed/check/static_check.h | 85 +++++ .../core/distributed/comm_context_manager.cc | 2 +- .../phi/core/distributed/gloo_comm_context.cc | 12 +- .../phi/core/distributed/gloo_comm_context.h | 8 +- paddle/phi/core/distributed/gloo_utils.cc | 2 +- paddle/phi/core/distributed/gloo_utils.h | 6 +- .../phi/core/distributed/nccl_comm_context.cc | 28 +- 18 files changed, 700 insertions(+), 597 deletions(-) delete mode 100644 paddle/fluid/distributed/collective/check.cc delete mode 100644 paddle/fluid/distributed/collective/check.h create mode 100644 paddle/phi/core/distributed/check/CMakeLists.txt create mode 100644 paddle/phi/core/distributed/check/nccl_dynamic_check.cc create mode 100644 paddle/phi/core/distributed/check/nccl_dynamic_check.h create mode 100644 paddle/phi/core/distributed/check/static_check.cc create mode 100644 paddle/phi/core/distributed/check/static_check.h diff --git a/paddle/fluid/distributed/collective/CMakeLists.txt b/paddle/fluid/distributed/collective/CMakeLists.txt index 39e4c00462..8329eb366a 100644 --- a/paddle/fluid/distributed/collective/CMakeLists.txt +++ b/paddle/fluid/distributed/collective/CMakeLists.txt @@ -18,7 +18,7 @@ endif() if(WITH_NCCL OR WITH_RCCL) cc_library( process_group_nccl - SRCS process_group_nccl.cc nccl_tools.cc common.cc check.cc + SRCS process_group_nccl.cc nccl_tools.cc common.cc DEPS process_group tcp_store place @@ -26,7 +26,9 @@ if(WITH_NCCL OR WITH_RCCL) collective_helper device_context ${DEVICE_EVENT_LIBS} - dense_tensor) + dense_tensor + comm_static_check + nccl_dynamic_check) endif() if(WITH_XPU_BKCL) diff --git a/paddle/fluid/distributed/collective/check.cc b/paddle/fluid/distributed/collective/check.cc deleted file mode 100644 index a5cd37dbc3..0000000000 --- a/paddle/fluid/distributed/collective/check.cc +++ /dev/null @@ -1,290 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/distributed/collective/check.h" - -#include "paddle/fluid/distributed/collective/nccl_tools.h" -#include "paddle/fluid/platform/place.h" -#include "paddle/phi/core/dense_tensor.h" -#include "paddle/phi/core/enforce.h" -#include "paddle/phi/core/errors.h" - -#ifdef PADDLE_WITH_HIP -#define gpuMalloc hipMalloc -#define gpuMemcpy hipMemcpy -#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost -#define gpuMemcpyHostToDevice hipMemcpyHostToDevice -#define gpuFree hipFree -#else -#define gpuMalloc cudaMalloc -#define gpuMemcpy cudaMemcpy -#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost -#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice -#define gpuFree cudaFree -#endif - -namespace paddle { -namespace distributed { - -// static checks -void CommStaticCheck::CheckRank(int rank, int world_size) { - PADDLE_ENFORCE_GE(rank, - 0, - phi::errors::InvalidArgument( - "Rank should be greater than or equal to 0.")); - PADDLE_ENFORCE_LT( - rank, - world_size, - phi::errors::InvalidArgument("Rank is out of the process group.")); -} - -void CommStaticCheck::CheckPlace(const phi::DenseTensor& tensor) { - PADDLE_ENFORCE_EQ( - platform::is_gpu_place(tensor.place()), - true, - platform::errors::InvalidArgument("Tensor should be in GPU place.")); -} - -void CommStaticCheck::CheckPlace(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor) { - CheckPlace(out_tensor); - CheckPlace(in_tensor); - PADDLE_ENFORCE_EQ( - out_tensor.place(), - in_tensor.place(), - phi::errors::InvalidArgument( - "Input and output tensors should be on the same place.")); -} - -void CommStaticCheck::CheckDataType(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor) { - PADDLE_ENFORCE_EQ( - out_tensor.dtype(), - in_tensor.dtype(), - phi::errors::InvalidArgument( - "Input and output tensors should have the same data type.")); -} - -void CommStaticCheck::CheckShape(const phi::DenseTensor& tensor) { - PADDLE_ENFORCE_GT( - tensor.numel(), - 0, - phi::errors::InvalidArgument("Size of tensor should be greater than 0.")); -} - -void CommStaticCheck::CheckShape(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor, - int out_size_factor, - int in_size_factor) { - CheckShape(out_tensor); - CheckShape(in_tensor); - int64_t out_size = out_tensor.numel(), in_size = in_tensor.numel(); - PADDLE_ENFORCE_EQ( - out_size * out_size_factor, - in_size * in_size_factor, - phi::errors::InvalidArgument( - "Input and output tensors should have matching sizes.")); -} - -void CommStaticCheck::CheckShape(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor, - int dst_rank, - int cur_rank, - int world_size, - int out_size_factor, - int in_size_factor) { - CheckRank(dst_rank, world_size); - CheckRank(cur_rank, world_size); - - CheckPlace(out_tensor, in_tensor); - CheckDataType(out_tensor, in_tensor); - - if (dst_rank == cur_rank) { - CheckShape(out_tensor, in_tensor, out_size_factor, in_size_factor); - } else { - CheckShape(out_tensor); - CheckShape(in_tensor); - } -} - -void CommStaticCheck::CheckShape(const phi::DenseTensor& tensor, - int rank, - int world_size) { - CheckPlace(tensor); - CheckRank(rank, world_size); -} - -void CommStaticCheck::SameShape(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor, - int dst_rank, - int cur_rank, - int world_size) { - CheckShape(out_tensor, - in_tensor, - dst_rank, - cur_rank, - world_size, - /*out_size_factor*/ 1, - /*in_size_factor*/ 1); -} - -void CommStaticCheck::ScatterLikeShape(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor, - int dst_rank, - int cur_rank, - int world_size) { - CheckShape(out_tensor, - in_tensor, - dst_rank, - cur_rank, - world_size, - /*out_size_factor*/ world_size, - /*in_size_factor*/ 1); -} - -void CommStaticCheck::GatherLikeShape(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor, - int dst_rank, - int cur_rank, - int world_size) { - CheckShape(out_tensor, - in_tensor, - dst_rank, - cur_rank, - world_size, - /*out_size_factor*/ 1, - /*in_size_factor*/ world_size); -} - -// dynamic checks -void CommDynamicCheck::CheckDataType(const phi::DenseTensor& tensor, - int64_t dtype) { - PADDLE_ENFORCE_EQ( - static_cast(tensor.dtype()), - dtype, - phi::errors::InvalidArgument( - "Tensors in communication are expected to have the same data type.")); -} - -void CommDynamicCheck::CheckDataType(const phi::DenseTensor& tensor, - int root_rank, - int cur_rank, - ncclComm_t comm) { - constexpr int kSize = sizeof(int64_t); - int64_t dtype_host = static_cast(tensor.dtype()); - int64_t* dtype_device; - PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&dtype_device, kSize)); - PADDLE_ENFORCE_GPU_SUCCESS( - gpuMemcpy(dtype_device, &dtype_host, kSize, gpuMemcpyHostToDevice)); - - NCCL_CHECK(phi::dynload::ncclBroadcast(dtype_device, - dtype_device, - kSize, - ncclInt64, - root_rank, - comm, - kDefaultStream)); - - if (root_rank == cur_rank) { - VLOG(3) << "Dynamic check broadcast metadata, dtype: " << dtype_host; - } else { - PADDLE_ENFORCE_GPU_SUCCESS( - gpuMemcpy(&dtype_host, dtype_device, kSize, gpuMemcpyDeviceToHost)); - VLOG(3) << "Dynamic check recv metadata, dtype: " << dtype_host; - CheckDataType(tensor, dtype_host); - } - PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(dtype_device)); -} - -void CommDynamicCheck::CheckShape(const phi::DenseTensor& tensor, - int64_t shape) { - PADDLE_ENFORCE_EQ( - tensor.numel(), - shape, - phi::errors::InvalidArgument( - "Tensors in communication are expected to have matching sizes.")); -} - -void CommDynamicCheck::CheckShape(const phi::DenseTensor& tensor, - int root_rank, - int cur_rank, - ncclComm_t comm) { - CheckDataType(tensor, root_rank, cur_rank, comm); - - constexpr int kSize = sizeof(int64_t); - int64_t shape_host = tensor.numel(); - int64_t* shape_device; - - PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&shape_device, kSize)); - PADDLE_ENFORCE_GPU_SUCCESS( - gpuMemcpy(shape_device, &shape_host, kSize, gpuMemcpyHostToDevice)); - - NCCL_CHECK(phi::dynload::ncclBroadcast(shape_device, - shape_device, - kSize, - ncclInt64, - root_rank, - comm, - kDefaultStream)); - - if (root_rank == cur_rank) { - VLOG(3) << "Dynamic check broadcast metadata, shape: " << shape_host; - } else { - PADDLE_ENFORCE_GPU_SUCCESS( - gpuMemcpy(&shape_host, shape_device, kSize, gpuMemcpyDeviceToHost)); - VLOG(3) << "Dynamic check recv metadata, shape: " << shape_host; - CheckShape(tensor, shape_host); - } - PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(shape_device)); -} - -void CommDynamicCheck::CheckShape(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor, - const std::vector& in_size_each_rank, - int cur_rank, - int world_size, - ncclComm_t comm) { - CheckDataType(out_tensor, /*root_rank*/ 0, cur_rank, comm); - CheckDataType(in_tensor, /*root_rank*/ 0, cur_rank, comm); - - constexpr int kSize = sizeof(int64_t); - int64_t in_row_size = in_tensor.numel() / in_tensor.dims()[0]; - - for (int rank = 0; rank < world_size; ++rank) { - int64_t in_shape_host = in_size_each_rank[rank] * in_row_size; - int64_t* in_shape_device; - PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&in_shape_device, kSize)); - PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy( - in_shape_device, &in_shape_host, kSize, gpuMemcpyHostToDevice)); - - NCCL_CHECK(phi::dynload::ncclReduce(in_shape_device, - in_shape_device, - kSize, - ncclInt64, - ncclSum, - rank, - comm, - kDefaultStream)); - if (rank == cur_rank) { - PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy( - &in_shape_host, in_shape_device, kSize, gpuMemcpyDeviceToHost)); - VLOG(3) << "Dynamic check recv metadata, shape: " << in_shape_host; - CheckShape(out_tensor, in_shape_host); - } - PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(in_shape_device)); - } -} - -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/collective/check.h b/paddle/fluid/distributed/collective/check.h deleted file mode 100644 index be9bfb5f78..0000000000 --- a/paddle/fluid/distributed/collective/check.h +++ /dev/null @@ -1,115 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "paddle/phi/backends/gpu/forwards.h" - -#ifdef PADDLE_WITH_HIP -using gpuStream_t = hipStream_t; -#else -using gpuStream_t = cudaStream_t; -#endif - -// forward declarations -namespace phi { -class DenseTensor; -} - -namespace paddle { -namespace distributed { - -struct CommStaticCheck { - static void CheckRank(int rank, int world_size); - - static void CheckPlace(const phi::DenseTensor& tensor); - - static void CheckPlace(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor); - - static void CheckDataType(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor); - - static void CheckShape(const phi::DenseTensor& tensor); - - static void CheckShape(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor, - int out_size_factor, - int in_size_factor); - - static void CheckShape(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor, - int dst_rank, - int cur_rank, - int world_size, - int out_size_factor, - int in_size_factor); - - // for p2p - static void CheckShape(const phi::DenseTensor& tensor, - int rank, - int world_size); - - // for collective - static void SameShape(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor, - int dst_rank, - int cur_rank, - int world_size); - - static void ScatterLikeShape(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor, - int dst_rank, - int cur_rank, - int world_size); - - static void GatherLikeShape(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor, - int dst_rank, - int cur_rank, - int world_size); -}; - -struct CommDynamicCheck { - static void CheckDataType(const phi::DenseTensor& tensor, int64_t dtype); - - static void CheckDataType(const phi::DenseTensor& tensor, - int root_rank, - int cur_rank, - ncclComm_t comm); - - static void CheckShape(const phi::DenseTensor& tensor, int64_t shape); - - static void CheckShape(const phi::DenseTensor& tensor, - int root_rank, - int cur_rank, - ncclComm_t comm); - - static void CheckShape(const phi::DenseTensor& out_tensor, - const phi::DenseTensor& in_tensor, - const std::vector& in_size_each_rank, - int cur_rank, - int world_size, - ncclComm_t comm); - - private: - // `0` represents default stream for both cuda & hip - static constexpr gpuStream_t kDefaultStream = 0; -}; - -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/collective/process_group_nccl.cc b/paddle/fluid/distributed/collective/process_group_nccl.cc index 37fb6312e6..9f9fa42589 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.cc +++ b/paddle/fluid/distributed/collective/process_group_nccl.cc @@ -14,15 +14,16 @@ #include "paddle/fluid/distributed/collective/process_group_nccl.h" -#include "paddle/fluid/distributed/collective/check.h" #include "paddle/fluid/distributed/collective/common.h" #include "paddle/fluid/distributed/collective/nccl_tools.h" #include "paddle/fluid/distributed/collective/utils.h" #include "paddle/fluid/platform/cuda_device_guard.h" #include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#include "paddle/fluid/platform/place.h" #include "paddle/phi/api/lib/utils/allocator.h" +#include "paddle/phi/core/distributed/check/nccl_dynamic_check.h" +#include "paddle/phi/core/distributed/check/static_check.h" #include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/utils/data_type.h" DECLARE_bool(nccl_blocking_wait); DECLARE_bool(use_stream_safe_cuda_allocator); @@ -142,24 +143,24 @@ std::shared_ptr ProcessGroupNCCL::AllGather( // numel > 0 indicates the tensor need to be sliced const phi::DenseTensor& in_tensor_maybe_partial = numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor; - CommStaticCheck::GatherLikeShape(*out_tensor, - in_tensor_maybe_partial, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_); + phi::distributed::CommStaticCheck::GatherLikeShape(*out_tensor, + in_tensor_maybe_partial, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_); return RunFnInNCCLEnv( [&](ncclComm_t comm, gpuStream_t stream) { if (FLAGS_enable_nccl_dynamic_check) { - CommDynamicCheck::CheckShape(*out_tensor, - /*root_rank*/ 0, - rank_, - comm); + phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor, + /*root_rank*/ 0, + rank_, + comm); } NCCL_CHECK(phi::dynload::ncclAllGather( in_tensor_maybe_partial.data(), out_tensor->data(), in_tensor_maybe_partial.numel(), - platform::ToNCCLDataType(in_tensor_maybe_partial.dtype()), + phi::ToNCCLDataType(in_tensor_maybe_partial.dtype()), comm, stream)); }, @@ -175,27 +176,27 @@ std::shared_ptr ProcessGroupNCCL::AllReduce( const AllreduceOptions& opts, bool sync_op, bool use_calc_stream) { - CommStaticCheck::SameShape(*out_tensor, - in_tensor, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_); + phi::distributed::CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_); return RunFnInNCCLEnv( [&](ncclComm_t comm, gpuStream_t stream) { if (FLAGS_enable_nccl_dynamic_check) { - CommDynamicCheck::CheckShape(*out_tensor, - /*root_rank*/ 0, - rank_, - comm); + phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor, + /*root_rank*/ 0, + rank_, + comm); } - NCCL_CHECK(phi::dynload::ncclAllReduce( - in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - platform::ToNCCLDataType(in_tensor.dtype()), - ToNCCLRedType(opts.reduce_op), - comm, - stream)); + NCCL_CHECK( + phi::dynload::ncclAllReduce(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + phi::ToNCCLDataType(in_tensor.dtype()), + ToNCCLRedType(opts.reduce_op), + comm, + stream)); }, in_tensor, CommType::ALLREDUCE, @@ -238,17 +239,17 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( // simply be covered by static checks. Factors are set to 0 here to skip the // shape check. Its shape check will be done by dynamic checks with // FLAGS_enable_nccl_dynamic_check. - CommStaticCheck::CheckShape(*out_tensor, - in_tensor, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_, - /*out_size_factor*/ 0, - /*in_size_factor*/ 0); + phi::distributed::CommStaticCheck::CheckShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + /*out_size_factor*/ 0, + /*in_size_factor*/ 0); return RunFnInNCCLEnv( [&](ncclComm_t comm, gpuStream_t stream) { if (FLAGS_enable_nccl_dynamic_check) { - CommDynamicCheck::CheckShape( + phi::distributed::NCCLDynamicCheck::CheckShape( *out_tensor, in_tensor, in_size_each_rank, rank_, size_, comm); } int64_t in_row_size = in_tensor.numel() / in_dim[0], @@ -260,13 +261,13 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( for (auto i = 0; i < size_; i++) { in_numel = in_size_each_rank[i] * in_row_size; input_partial = GetPartialTensor(in_tensor, in_offset, in_numel); - NCCL_CHECK(phi::dynload::ncclSend( - input_partial.data(), - in_numel, - platform::ToNCCLDataType(input_partial.dtype()), - i, - comm, - stream)); + NCCL_CHECK( + phi::dynload::ncclSend(input_partial.data(), + in_numel, + phi::ToNCCLDataType(input_partial.dtype()), + i, + comm, + stream)); in_offset += in_numel; out_numel = out_size_each_rank[i] * out_row_size; @@ -274,7 +275,7 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( NCCL_CHECK(phi::dynload::ncclRecv( output_partial.data(), out_numel, - platform::ToNCCLDataType(output_partial.dtype()), + phi::ToNCCLDataType(output_partial.dtype()), i, comm, stream)); @@ -316,25 +317,26 @@ std::shared_ptr ProcessGroupNCCL::Broadcast( const BroadcastOptions& opts, bool sync_op, bool use_calc_stream) { - CommStaticCheck::SameShape(*out_tensor, - in_tensor, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_); + phi::distributed::CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_); return RunFnInNCCLEnv( [&](ncclComm_t comm, gpuStream_t stream) { int root = opts.source_rank + opts.source_root; if (FLAGS_enable_nccl_dynamic_check) { - CommDynamicCheck::CheckShape(*out_tensor, root, rank_, comm); + phi::distributed::NCCLDynamicCheck::CheckShape( + *out_tensor, root, rank_, comm); } - NCCL_CHECK(phi::dynload::ncclBroadcast( - in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - platform::ToNCCLDataType(in_tensor.dtype()), - root, - comm, - stream)); + NCCL_CHECK( + phi::dynload::ncclBroadcast(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + phi::ToNCCLDataType(in_tensor.dtype()), + root, + comm, + stream)); }, in_tensor, CommType::BROADCAST, @@ -348,28 +350,29 @@ std::shared_ptr ProcessGroupNCCL::Reduce( const ReduceOptions& opts, bool sync_op, bool use_calc_stream) { - CommStaticCheck::SameShape(*out_tensor, - in_tensor, - /*dst_rank*/ opts.root_rank, - /*cur_rank*/ rank_, - size_); + phi::distributed::CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ opts.root_rank, + /*cur_rank*/ rank_, + size_); return RunFnInNCCLEnv( [&](ncclComm_t comm, gpuStream_t stream) { if (FLAGS_enable_nccl_dynamic_check) { - CommDynamicCheck::CheckShape(*out_tensor, - /*root_rank*/ opts.root_rank, - rank_, - comm); + phi::distributed::NCCLDynamicCheck::CheckShape( + *out_tensor, + /*root_rank*/ opts.root_rank, + rank_, + comm); } - NCCL_CHECK(phi::dynload::ncclReduce( - in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - platform::ToNCCLDataType(in_tensor.dtype()), - ToNCCLRedType(opts.reduce_op), - opts.root_rank, - comm, - stream)); + NCCL_CHECK( + phi::dynload::ncclReduce(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + phi::ToNCCLDataType(in_tensor.dtype()), + ToNCCLRedType(opts.reduce_op), + opts.root_rank, + comm, + stream)); }, in_tensor, CommType::REDUCE, @@ -383,24 +386,24 @@ std::shared_ptr ProcessGroupNCCL::ReduceScatter( const ReduceScatterOptions& opts, bool sync_op, bool use_calc_stream) { - CommStaticCheck::ScatterLikeShape(*out_tensor, - in_tensor, - /*dst_rank*/ rank_, - /*cur_rank*/ rank_, - size_); + phi::distributed::CommStaticCheck::ScatterLikeShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_); return RunFnInNCCLEnv( [&](ncclComm_t comm, gpuStream_t stream) { if (FLAGS_enable_nccl_dynamic_check) { - CommDynamicCheck::CheckShape(*out_tensor, - /*root_rank*/ 0, - rank_, - comm); + phi::distributed::NCCLDynamicCheck::CheckShape(*out_tensor, + /*root_rank*/ 0, + rank_, + comm); } NCCL_CHECK(phi::dynload::ncclReduceScatter( in_tensor.data(), out_tensor->data(), out_tensor->numel(), - platform::ToNCCLDataType(in_tensor.dtype()), + phi::ToNCCLDataType(in_tensor.dtype()), ToNCCLRedType(opts.reduce_op), comm, stream)); @@ -417,18 +420,20 @@ std::shared_ptr ProcessGroupNCCL::Scatter( const ScatterOptions& opts, bool sync_op, bool use_calc_stream) { - CommStaticCheck::ScatterLikeShape(*out_tensor, - in_tensor, - /*dst_rank*/ opts.root_rank, - /*cur_rank*/ rank_, - size_); + phi::distributed::CommStaticCheck::ScatterLikeShape( + *out_tensor, + in_tensor, + /*dst_rank*/ opts.root_rank, + /*cur_rank*/ rank_, + size_); return RunFnInNCCLEnv( [&](ncclComm_t comm, gpuStream_t stream) { if (FLAGS_enable_nccl_dynamic_check) { - CommDynamicCheck::CheckShape(*out_tensor, - /*root_rank*/ opts.root_rank, - rank_, - comm); + phi::distributed::NCCLDynamicCheck::CheckShape( + *out_tensor, + /*root_rank*/ opts.root_rank, + rank_, + comm); } int64_t numel = in_tensor.numel() / size_; if (rank_ == opts.root_rank) { @@ -440,28 +445,28 @@ std::shared_ptr ProcessGroupNCCL::Scatter( NCCL_CHECK(phi::dynload::ncclSend( partial_tensor.data(), numel, - platform::ToNCCLDataType(partial_tensor.dtype()), + phi::ToNCCLDataType(partial_tensor.dtype()), i, comm, stream)); offset += numel; } - NCCL_CHECK(phi::dynload::ncclRecv( - out_tensor->data(), - numel, - platform::ToNCCLDataType(out_tensor->dtype()), - opts.root_rank, - comm, - stream)); + NCCL_CHECK( + phi::dynload::ncclRecv(out_tensor->data(), + numel, + phi::ToNCCLDataType(out_tensor->dtype()), + opts.root_rank, + comm, + stream)); GroupEnd(); } else { - NCCL_CHECK(phi::dynload::ncclRecv( - out_tensor->data(), - numel, - platform::ToNCCLDataType(out_tensor->dtype()), - opts.root_rank, - comm, - stream)); + NCCL_CHECK( + phi::dynload::ncclRecv(out_tensor->data(), + numel, + phi::ToNCCLDataType(out_tensor->dtype()), + opts.root_rank, + comm, + stream)); } }, in_tensor, @@ -484,22 +489,21 @@ std::shared_ptr ProcessGroupNCCL::Recv( tensor = &partial_tensor; } - CommStaticCheck::CheckShape(*tensor, rank_, size_); + phi::distributed::CommStaticCheck::CheckShape(*tensor, rank_, size_); return RunFnInNCCLEnv( [&](ncclComm_t comm, gpuStream_t stream) { if (FLAGS_enable_nccl_dynamic_check) { - CommDynamicCheck::CheckShape(*tensor, - /*root_rank*/ src_rank, - rank_, - comm); + phi::distributed::NCCLDynamicCheck::CheckShape(*tensor, + /*root_rank*/ src_rank, + rank_, + comm); } - NCCL_CHECK( - phi::dynload::ncclRecv(tensor->data(), - tensor->numel(), - platform::ToNCCLDataType(tensor->dtype()), - src_rank, - comm, - stream)); + NCCL_CHECK(phi::dynload::ncclRecv(tensor->data(), + tensor->numel(), + phi::ToNCCLDataType(tensor->dtype()), + src_rank, + comm, + stream)); }, *tensor, CommType::RECV, @@ -518,19 +522,20 @@ std::shared_ptr ProcessGroupNCCL::Send( const phi::DenseTensor& tensor_maybe_partial = numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor; - CommStaticCheck::CheckShape(tensor_maybe_partial, rank_, size_); + phi::distributed::CommStaticCheck::CheckShape( + tensor_maybe_partial, rank_, size_); return RunFnInNCCLEnv( [&](ncclComm_t comm, gpuStream_t stream) { if (FLAGS_enable_nccl_dynamic_check) { - CommDynamicCheck::CheckShape(tensor_maybe_partial, - /*root_rank*/ rank_, - rank_, - comm); + phi::distributed::NCCLDynamicCheck::CheckShape(tensor_maybe_partial, + /*root_rank*/ rank_, + rank_, + comm); } NCCL_CHECK(phi::dynload::ncclSend( tensor_maybe_partial.data(), tensor_maybe_partial.numel(), - platform::ToNCCLDataType(tensor_maybe_partial.dtype()), + phi::ToNCCLDataType(tensor_maybe_partial.dtype()), dst_rank, comm, stream)); @@ -848,14 +853,13 @@ std::shared_ptr ProcessGroupNCCL::AllReduce( phi::DenseTensor& output, ncclComm_t comm, const gpuStream_t& stream) { - return phi::dynload::ncclAllReduce( - input.data(), - output.data(), - input.numel(), - platform::ToNCCLDataType(input.type()), - ToNCCLRedType(opts.reduce_op), - comm, - stream); + return phi::dynload::ncclAllReduce(input.data(), + output.data(), + input.numel(), + phi::ToNCCLDataType(input.type()), + ToNCCLRedType(opts.reduce_op), + comm, + stream); }, CommType::ALLREDUCE); } @@ -878,14 +882,13 @@ std::shared_ptr ProcessGroupNCCL::Broadcast( const gpuStream_t& stream) { const auto root = opts.source_rank * in_tensors.size() + opts.source_root; - return phi::dynload::ncclBroadcast( - input.data(), - output.data(), - input.numel(), - platform::ToNCCLDataType(input.type()), - root, - comm, - stream); + return phi::dynload::ncclBroadcast(input.data(), + output.data(), + input.numel(), + phi::ToNCCLDataType(input.type()), + root, + comm, + stream); }, CommType::BROADCAST); } @@ -930,7 +933,7 @@ std::shared_ptr ProcessGroupNCCL::Send( int dst_rank) { return phi::dynload::ncclSend(input.data(), input.numel(), - platform::ToNCCLDataType(input.dtype()), + phi::ToNCCLDataType(input.dtype()), dst_rank, comm, stream); @@ -952,7 +955,7 @@ std::shared_ptr ProcessGroupNCCL::Recv( int src_rank) { return phi::dynload::ncclRecv(output.data(), output.numel(), - platform::ToNCCLDataType(output.dtype()), + phi::ToNCCLDataType(output.dtype()), src_rank, comm, stream); @@ -980,13 +983,12 @@ std::shared_ptr ProcessGroupNCCL::AllGather( phi::DenseTensor& output, ncclComm_t comm, const gpuStream_t& stream) { - return phi::dynload::ncclAllGather( - input.data(), - output.data(), - input.numel(), - platform::ToNCCLDataType(input.dtype()), - comm, - stream); + return phi::dynload::ncclAllGather(input.data(), + output.data(), + input.numel(), + phi::ToNCCLDataType(input.dtype()), + comm, + stream); }, CommType::ALLGATHER); } @@ -1052,14 +1054,14 @@ std::shared_ptr ProcessGroupNCCL::AllToAll( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend( GetPointerByOffset(input.data(), offset, input.dtype()), input.numel() / size_, - platform::ToNCCLDataType(input.dtype()), + phi::ToNCCLDataType(input.dtype()), i, comm, stream)); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclRecv( GetPointerByOffset(output.data(), offset, input.dtype()), input.numel() / size_, - platform::ToNCCLDataType(input.dtype()), + phi::ToNCCLDataType(input.dtype()), i, comm, stream)); @@ -1089,7 +1091,7 @@ std::shared_ptr ProcessGroupNCCL::Reduce( phi::dynload::ncclReduce(input.data(), output.data(), input.numel(), - platform::ToNCCLDataType(input.dtype()), + phi::ToNCCLDataType(input.dtype()), ToNCCLRedType(opts.reduce_op), opts.root_rank, comm, @@ -1124,7 +1126,7 @@ std::shared_ptr ProcessGroupNCCL::Scatter( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclSend( GetPointerByOffset(input.data(), offset, input.dtype()), input.numel() / size_, - platform::ToNCCLDataType(input.dtype()), + phi::ToNCCLDataType(input.dtype()), i, comm, stream)); @@ -1133,7 +1135,7 @@ std::shared_ptr ProcessGroupNCCL::Scatter( PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::ncclRecv(output.data(), input.numel() / size_, - platform::ToNCCLDataType(input.dtype()), + phi::ToNCCLDataType(input.dtype()), opts.root_rank, comm, stream)); @@ -1142,7 +1144,7 @@ std::shared_ptr ProcessGroupNCCL::Scatter( PADDLE_ENFORCE_GPU_SUCCESS( phi::dynload::ncclRecv(output.data(), input.numel() / size_, - platform::ToNCCLDataType(input.dtype()), + phi::ToNCCLDataType(input.dtype()), opts.root_rank, comm, stream)); diff --git a/paddle/fluid/distributed/collective/process_group_nccl.h b/paddle/fluid/distributed/collective/process_group_nccl.h index cb83b0ddfe..c5d3842a09 100644 --- a/paddle/fluid/distributed/collective/process_group_nccl.h +++ b/paddle/fluid/distributed/collective/process_group_nccl.h @@ -23,20 +23,11 @@ #include "paddle/fluid/distributed/collective/process_group.h" #include "paddle/fluid/distributed/collective/process_group_with_stream.h" #include "paddle/fluid/platform/device_event.h" +#include "paddle/phi/backends/gpu/forwards.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/device_context.h" #include "paddle/phi/core/distributed/store/store.h" -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#include "paddle/fluid/distributed/collective/nccl_tools.h" -#endif - -#ifdef PADDLE_WITH_RCCL -#include "paddle/phi/backends/dynload/rccl.h" -#else -#include "paddle/phi/backends/dynload/nccl.h" -#endif - namespace paddle { namespace distributed { diff --git a/paddle/fluid/distributed/collective/types.h b/paddle/fluid/distributed/collective/types.h index 11628ea1f0..5dfb611821 100644 --- a/paddle/fluid/distributed/collective/types.h +++ b/paddle/fluid/distributed/collective/types.h @@ -13,9 +13,11 @@ // limitations under the License. #pragma once + #include #include #include + #include "paddle/phi/common/place.h" namespace paddle { diff --git a/paddle/phi/core/distributed/CMakeLists.txt b/paddle/phi/core/distributed/CMakeLists.txt index 4e0794e042..3c4d9d8500 100644 --- a/paddle/phi/core/distributed/CMakeLists.txt +++ b/paddle/phi/core/distributed/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(check) add_subdirectory(store) set(COMM_CONTEXT_MANAGER_DEPS tcp_store) @@ -6,7 +7,8 @@ if(WITH_NCCL OR WITH_RCCL) cc_library( nccl_comm_context SRCS nccl_comm_context.cc - DEPS dense_tensor) + DEPS dense_tensor comm_static_check nccl_dynamic_check) + list(APPEND COMM_CONTEXT_MANAGER_DEPS nccl_comm_context) endif() @@ -19,7 +21,7 @@ if(WITH_GLOO) cc_library( gloo_comm_context SRCS gloo_comm_context.cc - DEPS gloo_utils) + DEPS gloo_utils comm_static_check) list(APPEND COMM_CONTEXT_MANAGER_DEPS gloo_comm_context gloo_store) endif() diff --git a/paddle/phi/core/distributed/check/CMakeLists.txt b/paddle/phi/core/distributed/check/CMakeLists.txt new file mode 100644 index 0000000000..76f4977263 --- /dev/null +++ b/paddle/phi/core/distributed/check/CMakeLists.txt @@ -0,0 +1,11 @@ +cc_library( + comm_static_check + SRCS static_check.cc + DEPS place dense_tensor enforce) + +if(WITH_NCCL OR WITH_RCCL) + cc_library( + nccl_dynamic_check + SRCS nccl_dynamic_check.cc + DEPS dense_tensor) +endif() diff --git a/paddle/phi/core/distributed/check/nccl_dynamic_check.cc b/paddle/phi/core/distributed/check/nccl_dynamic_check.cc new file mode 100644 index 0000000000..6cb4c8cfe1 --- /dev/null +++ b/paddle/phi/core/distributed/check/nccl_dynamic_check.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/check/nccl_dynamic_check.h" + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" + +#if defined(PADDLE_WITH_RCCL) +#include + +#include "paddle/phi/backends/dynload/rccl.h" + +#define gpuMalloc hipMalloc +#define gpuMemcpy hipMemcpy +#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost +#define gpuMemcpyHostToDevice hipMemcpyHostToDevice +#define gpuFree hipFree +#else +#include + +#include "paddle/phi/backends/dynload/nccl.h" + +#define gpuMalloc cudaMalloc +#define gpuMemcpy cudaMemcpy +#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost +#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice +#define gpuFree cudaFree +#endif + +namespace phi { +namespace distributed { +void NCCLDynamicCheck::CheckDataType(const phi::DenseTensor& tensor, + int64_t dtype) { + PADDLE_ENFORCE_EQ( + static_cast(tensor.dtype()), + dtype, + phi::errors::InvalidArgument( + "Tensors in communication are expected to have the same data type.")); +} + +void NCCLDynamicCheck::CheckDataType(const phi::DenseTensor& tensor, + int root_rank, + int cur_rank, + ncclComm_t comm) { + constexpr int kSize = sizeof(int64_t); + int64_t dtype_host = static_cast(tensor.dtype()); + int64_t* dtype_device; + PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&dtype_device, kSize)); + PADDLE_ENFORCE_GPU_SUCCESS( + gpuMemcpy(dtype_device, &dtype_host, kSize, gpuMemcpyHostToDevice)); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclBroadcast(dtype_device, + dtype_device, + kSize, + ncclInt64, + root_rank, + comm, + kDefaultStream)); + + if (root_rank == cur_rank) { + VLOG(3) << "Dynamic check broadcast metadata, dtype: " << dtype_host; + } else { + PADDLE_ENFORCE_GPU_SUCCESS( + gpuMemcpy(&dtype_host, dtype_device, kSize, gpuMemcpyDeviceToHost)); + VLOG(3) << "Dynamic check recv metadata, dtype: " << dtype_host; + CheckDataType(tensor, dtype_host); + } + PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(dtype_device)); +} + +void NCCLDynamicCheck::CheckShape(const phi::DenseTensor& tensor, + int64_t shape) { + PADDLE_ENFORCE_EQ( + tensor.numel(), + shape, + phi::errors::InvalidArgument( + "Tensors in communication are expected to have matching sizes.")); +} + +void NCCLDynamicCheck::CheckShape(const phi::DenseTensor& tensor, + int root_rank, + int cur_rank, + ncclComm_t comm) { + CheckDataType(tensor, root_rank, cur_rank, comm); + + constexpr int kSize = sizeof(int64_t); + int64_t shape_host = tensor.numel(); + int64_t* shape_device; + + PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&shape_device, kSize)); + PADDLE_ENFORCE_GPU_SUCCESS( + gpuMemcpy(shape_device, &shape_host, kSize, gpuMemcpyHostToDevice)); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclBroadcast(shape_device, + shape_device, + kSize, + ncclInt64, + root_rank, + comm, + kDefaultStream)); + + if (root_rank == cur_rank) { + VLOG(3) << "Dynamic check broadcast metadata, shape: " << shape_host; + } else { + PADDLE_ENFORCE_GPU_SUCCESS( + gpuMemcpy(&shape_host, shape_device, kSize, gpuMemcpyDeviceToHost)); + VLOG(3) << "Dynamic check recv metadata, shape: " << shape_host; + CheckShape(tensor, shape_host); + } + PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(shape_device)); +} + +void NCCLDynamicCheck::CheckShape(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& in_size_each_rank, + int cur_rank, + int world_size, + ncclComm_t comm) { + CheckDataType(out_tensor, /*root_rank*/ 0, cur_rank, comm); + CheckDataType(in_tensor, /*root_rank*/ 0, cur_rank, comm); + + constexpr int kSize = sizeof(int64_t); + int64_t in_row_size = in_tensor.numel() / in_tensor.dims()[0]; + + for (int rank = 0; rank < world_size; ++rank) { + int64_t in_shape_host = in_size_each_rank[rank] * in_row_size; + int64_t* in_shape_device; + PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&in_shape_device, kSize)); + PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy( + in_shape_device, &in_shape_host, kSize, gpuMemcpyHostToDevice)); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclReduce(in_shape_device, + in_shape_device, + kSize, + ncclInt64, + ncclSum, + rank, + comm, + kDefaultStream)); + if (rank == cur_rank) { + PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy( + &in_shape_host, in_shape_device, kSize, gpuMemcpyDeviceToHost)); + VLOG(3) << "Dynamic check recv metadata, shape: " << in_shape_host; + CheckShape(out_tensor, in_shape_host); + } + PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(in_shape_device)); + } +} +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/check/nccl_dynamic_check.h b/paddle/phi/core/distributed/check/nccl_dynamic_check.h new file mode 100644 index 0000000000..64c13e2a76 --- /dev/null +++ b/paddle/phi/core/distributed/check/nccl_dynamic_check.h @@ -0,0 +1,61 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/phi/backends/gpu/forwards.h" + +#if defined(PADDLE_WITH_RCCL) +using gpuStream_t = hipStream_t; +#else +using gpuStream_t = cudaStream_t; +#endif + +namespace phi { +// forward declaration +class DenseTensor; + +namespace distributed { +struct NCCLDynamicCheck { + static void CheckDataType(const phi::DenseTensor& tensor, int64_t dtype); + + static void CheckDataType(const phi::DenseTensor& tensor, + int root_rank, + int cur_rank, + ncclComm_t comm); + + static void CheckShape(const phi::DenseTensor& tensor, int64_t shape); + + static void CheckShape(const phi::DenseTensor& tensor, + int root_rank, + int cur_rank, + ncclComm_t comm); + + static void CheckShape(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + const std::vector& in_size_each_rank, + int cur_rank, + int world_size, + ncclComm_t comm); + + private: + // `0` represents default stream for both cuda & hip + static constexpr gpuStream_t kDefaultStream = 0; +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/check/static_check.cc b/paddle/phi/core/distributed/check/static_check.cc new file mode 100644 index 0000000000..8ec3e19e60 --- /dev/null +++ b/paddle/phi/core/distributed/check/static_check.cc @@ -0,0 +1,167 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/distributed/check/static_check.h" + +#include +#include + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/errors.h" + +namespace phi { +namespace distributed { + +void CommStaticCheck::CheckRank(int rank, int world_size) { + PADDLE_ENFORCE_GE(rank, + 0, + phi::errors::InvalidArgument( + "Rank should be greater than or equal to 0.")); + PADDLE_ENFORCE_LT( + rank, + world_size, + phi::errors::InvalidArgument("Rank is out of the process group.")); +} + +void CommStaticCheck::CheckPlace(const phi::DenseTensor& tensor, + phi::AllocationType place) { + PADDLE_ENFORCE_EQ( + tensor.place().GetType(), + place, + phi::errors::InvalidArgument("Tensor should be in backend's place.")); +} + +void CommStaticCheck::CheckPlace(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + phi::AllocationType place) { + CheckPlace(out_tensor, place); + CheckPlace(in_tensor, place); + PADDLE_ENFORCE_EQ( + out_tensor.place(), + in_tensor.place(), + phi::errors::InvalidArgument( + "Input and output tensors should be on the same place.")); +} + +void CommStaticCheck::CheckDataType(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor) { + PADDLE_ENFORCE_EQ( + out_tensor.dtype(), + in_tensor.dtype(), + phi::errors::InvalidArgument( + "Input and output tensors should have the same data type.")); +} + +void CommStaticCheck::CheckShape(const phi::DenseTensor& tensor) { + PADDLE_ENFORCE_GT( + tensor.numel(), + 0, + phi::errors::InvalidArgument("Size of tensor should be greater than 0.")); +} + +void CommStaticCheck::CheckShape(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + int out_size_factor, + int in_size_factor) { + CheckShape(out_tensor); + CheckShape(in_tensor); + int64_t out_size = out_tensor.numel(), in_size = in_tensor.numel(); + PADDLE_ENFORCE_EQ( + out_size * out_size_factor, + in_size * in_size_factor, + phi::errors::InvalidArgument( + "Input and output tensors should have matching sizes.")); +} + +void CommStaticCheck::CheckShape(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + int dst_rank, + int cur_rank, + int world_size, + int out_size_factor, + int in_size_factor, + phi::AllocationType place) { + CheckRank(dst_rank, world_size); + CheckRank(cur_rank, world_size); + + CheckPlace(out_tensor, in_tensor, place); + CheckDataType(out_tensor, in_tensor); + + if (dst_rank == cur_rank) { + CheckShape(out_tensor, in_tensor, out_size_factor, in_size_factor); + } else { + CheckShape(out_tensor); + CheckShape(in_tensor); + } +} + +void CommStaticCheck::CheckShape(const phi::DenseTensor& tensor, + int rank, + int world_size, + phi::AllocationType place) { + CheckPlace(tensor, place); + CheckRank(rank, world_size); +} + +void CommStaticCheck::SameShape(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + int dst_rank, + int cur_rank, + int world_size, + phi::AllocationType place) { + CheckShape(out_tensor, + in_tensor, + dst_rank, + cur_rank, + world_size, + /*out_size_factor*/ 1, + /*in_size_factor*/ 1, + place); +} + +void CommStaticCheck::ScatterLikeShape(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + int dst_rank, + int cur_rank, + int world_size, + phi::AllocationType place) { + CheckShape(out_tensor, + in_tensor, + dst_rank, + cur_rank, + world_size, + /*out_size_factor*/ world_size, + /*in_size_factor*/ 1, + place); +} + +void CommStaticCheck::GatherLikeShape(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + int dst_rank, + int cur_rank, + int world_size, + phi::AllocationType place) { + CheckShape(out_tensor, + in_tensor, + dst_rank, + cur_rank, + world_size, + /*out_size_factor*/ 1, + /*in_size_factor*/ world_size, + place); +} + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/check/static_check.h b/paddle/phi/core/distributed/check/static_check.h new file mode 100644 index 0000000000..2b14f7cb32 --- /dev/null +++ b/paddle/phi/core/distributed/check/static_check.h @@ -0,0 +1,85 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/place.h" + +namespace phi { +// forward declaration +class DenseTensor; + +namespace distributed { +struct CommStaticCheck { + static void CheckRank(int rank, int world_size); + + static void CheckPlace(const phi::DenseTensor& tensor, + phi::AllocationType place); + + static void CheckPlace(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + phi::AllocationType place); + + static void CheckDataType(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor); + + static void CheckShape(const phi::DenseTensor& tensor); + + static void CheckShape(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + int out_size_factor, + int in_size_factor); + + static void CheckShape(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + int dst_rank, + int cur_rank, + int world_size, + int out_size_factor, + int in_size_factor, + phi::AllocationType place = phi::AllocationType::GPU); + + // for p2p + static void CheckShape(const phi::DenseTensor& tensor, + int rank, + int world_size, + phi::AllocationType place = phi::AllocationType::GPU); + + // for collective + static void SameShape(const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + int dst_rank, + int cur_rank, + int world_size, + phi::AllocationType place = phi::AllocationType::GPU); + + static void ScatterLikeShape( + const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + int dst_rank, + int cur_rank, + int world_size, + phi::AllocationType place = phi::AllocationType::GPU); + + static void GatherLikeShape( + const phi::DenseTensor& out_tensor, + const phi::DenseTensor& in_tensor, + int dst_rank, + int cur_rank, + int world_size, + phi::AllocationType place = phi::AllocationType::GPU); +}; + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/core/distributed/comm_context_manager.cc b/paddle/phi/core/distributed/comm_context_manager.cc index 7ad44f29f4..7192946555 100644 --- a/paddle/phi/core/distributed/comm_context_manager.cc +++ b/paddle/phi/core/distributed/comm_context_manager.cc @@ -13,7 +13,7 @@ // limitations under the License. #if defined(PADDLE_WITH_GLOO) -#include "gloo/rendezvous/prefix_store.h" +#include #include "paddle/phi/core/distributed/gloo_comm_context.h" #include "paddle/phi/core/distributed/gloo_utils.h" diff --git a/paddle/phi/core/distributed/gloo_comm_context.cc b/paddle/phi/core/distributed/gloo_comm_context.cc index d51db3bee8..9830ebdc10 100644 --- a/paddle/phi/core/distributed/gloo_comm_context.cc +++ b/paddle/phi/core/distributed/gloo_comm_context.cc @@ -14,11 +14,12 @@ #include "paddle/phi/core/distributed/gloo_comm_context.h" -#include "gloo/broadcast.h" -#include "gloo/types.h" +#include +#include #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/check/static_check.h" #include "paddle/phi/core/distributed/gloo_utils.h" #include "paddle/phi/core/enforce.h" @@ -38,6 +39,13 @@ GlooCommContext::GlooCommContext( void GlooCommContext::Broadcast(phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, int root) { + // gloo only uses CPU now + CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_, + phi::AllocationType::CPU); gloo::BroadcastOptions opts(gloo_context_); const auto& dtype = in_tensor.dtype(); GENERATE_FUNC(dtype, SetOutput, &opts, out_tensor); diff --git a/paddle/phi/core/distributed/gloo_comm_context.h b/paddle/phi/core/distributed/gloo_comm_context.h index f3bcd7c1e5..dccd187529 100644 --- a/paddle/phi/core/distributed/gloo_comm_context.h +++ b/paddle/phi/core/distributed/gloo_comm_context.h @@ -13,11 +13,11 @@ // limitations under the License. #pragma once -#include +#include +#include +#include -#include "gloo/rendezvous/context.h" -#include "gloo/rendezvous/store.h" -#include "gloo/transport/tcp/device.h" +#include #include "paddle/phi/core/distributed/comm_context.h" #include "paddle/phi/core/macros.h" diff --git a/paddle/phi/core/distributed/gloo_utils.cc b/paddle/phi/core/distributed/gloo_utils.cc index 76ef17f0f0..d853e6fd47 100644 --- a/paddle/phi/core/distributed/gloo_utils.cc +++ b/paddle/phi/core/distributed/gloo_utils.cc @@ -13,9 +13,9 @@ // limitations under the License. #ifdef _WIN32 +#include #include #include -#include "gloo/common/win.h" #else #include #include diff --git a/paddle/phi/core/distributed/gloo_utils.h b/paddle/phi/core/distributed/gloo_utils.h index 3101c7949b..1efdd40efb 100644 --- a/paddle/phi/core/distributed/gloo_utils.h +++ b/paddle/phi/core/distributed/gloo_utils.h @@ -14,13 +14,13 @@ #pragma once +#include +#include + #include #include #include -#include "gloo/transport/tcp/device.h" -#include "gloo/types.h" - #include "paddle/phi/common/data_type.h" #include "paddle/phi/core/dense_tensor.h" diff --git a/paddle/phi/core/distributed/nccl_comm_context.cc b/paddle/phi/core/distributed/nccl_comm_context.cc index 32c1a2e744..5de134c009 100644 --- a/paddle/phi/core/distributed/nccl_comm_context.cc +++ b/paddle/phi/core/distributed/nccl_comm_context.cc @@ -14,12 +14,17 @@ #include "paddle/phi/core/distributed/nccl_comm_context.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/distributed/check/nccl_dynamic_check.h" +#include "paddle/phi/core/distributed/check/static_check.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/utils/data_type.h" namespace phi { namespace distributed { +// set this flag to `true` and recompile to enable dynamic checks +constexpr bool FLAGS_enable_nccl_dynamic_check = false; + NCCLCommContext::NCCLCommContext(int rank, int size, ncclUniqueId nccl_id) : CommContext(rank, size) { PADDLE_ENFORCE_GPU_SUCCESS( @@ -30,13 +35,22 @@ void NCCLCommContext::Broadcast(phi::DenseTensor* out_tensor, const phi::DenseTensor& in_tensor, int root, gpuStream_t stream) { - phi::dynload::ncclBroadcast(in_tensor.data(), - out_tensor->data(), - in_tensor.numel(), - ToNCCLDataType(in_tensor.type()), - root, - nccl_comm_, - stream); + CommStaticCheck::SameShape(*out_tensor, + in_tensor, + /*dst_rank*/ rank_, + /*cur_rank*/ rank_, + size_); + if (FLAGS_enable_nccl_dynamic_check) { + NCCLDynamicCheck::CheckShape(*out_tensor, root, rank_, nccl_comm_); + } + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::ncclBroadcast(in_tensor.data(), + out_tensor->data(), + in_tensor.numel(), + ToNCCLDataType(in_tensor.type()), + root, + nccl_comm_, + stream)); } } // namespace distributed -- GitLab