未验证 提交 4552be48 编写于 作者: W Wen Sun 提交者: GitHub

Refactor collective communication static check (#48646)

* refactor: classify static check

* refactor: rename to static_check & use forward decl

* refactor: switch to unary & binary funcs
上级 f9815bfe
......@@ -21,7 +21,7 @@ endif()
if(WITH_NCCL OR WITH_RCCL)
cc_library(
processgroup_nccl
SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc
SRCS ProcessGroupNCCL.cc NCCLTools.cc Common.cc static_check.cc
DEPS processgroup
processgroup_stream
place
......
......@@ -44,109 +44,5 @@ std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID) {
return oss.str();
}
// Static check for a single tensor (used by the point-to-point send/recv
// paths). Verifies that `tensor` is placed on a GPU and that `rank` lies in
// the half-open range [0, world_size). Each failed condition is reported as an
// InvalidArgument error through the PADDLE_ENFORCE_* macros.
//
// tensor:     tensor about to be communicated.
// rank:       rank of the calling process inside the process group.
// world_size: number of ranks in the process group.
void StaticCheckTensor(const phi::DenseTensor& tensor,
int rank,
int world_size) {
// place check: the tensor must be on a GPU place
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(tensor.place()),
true,
platform::errors::InvalidArgument("Tensor should be in GPU place."));
// rank check: 0 <= rank < world_size
PADDLE_ENFORCE_GE(rank,
0,
platform::errors::InvalidArgument(
"Rank should be greater than or equal to 0."));
PADDLE_ENFORCE_LT(
rank,
world_size,
platform::errors::InvalidArgument("Rank is out of the process group."));
}
// Static check for collective operations that take an input and an output
// tensor. The checks run in this order, and the first failing one is reported
// as an InvalidArgument error through the PADDLE_ENFORCE_* macros:
//   1. place: both tensors must be on a GPU place;
//   2. rank:  0 <= rank < world_size;
//   3. shape: both tensors must be non-empty and satisfy
//        out_tensor.numel() * out_size_factor
//          == in_tensor.numel() * in_size_factor;
//   4. dtype: both tensors must have the same data type.
//
// out_size_factor / in_size_factor: multipliers applied to the element counts
// before they are compared. Passing 0 for both makes the comparison 0 == 0,
// which effectively skips the size-matching check.
void StaticCheckTensors(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size,
int out_size_factor,
int in_size_factor) {
// place check: both tensors must be on a GPU place
PADDLE_ENFORCE_EQ(platform::is_gpu_place(out_tensor.place()),
true,
platform::errors::InvalidArgument(
"Output tensor should be in GPU place."));
PADDLE_ENFORCE_EQ(platform::is_gpu_place(in_tensor.place()),
true,
platform::errors::InvalidArgument(
"Input tensor should be in GPU place."));
// rank check: 0 <= rank < world_size
PADDLE_ENFORCE_GE(rank,
0,
platform::errors::InvalidArgument(
"Rank should be greater than or equal to 0."));
PADDLE_ENFORCE_LT(
rank,
world_size,
platform::errors::InvalidArgument("Rank is out of the process group."));
// shape check: non-empty tensors whose scaled element counts agree
int64_t out_size = out_tensor.numel();
PADDLE_ENFORCE_GT(out_size,
0,
platform::errors::InvalidArgument(
"Size of output tensor should be greater than 0."));
int64_t in_size = in_tensor.numel();
PADDLE_ENFORCE_GT(in_size,
0,
platform::errors::InvalidArgument(
"Size of input tensor should be greater than 0."));
PADDLE_ENFORCE_EQ(
out_size * out_size_factor,
in_size * in_size_factor,
platform::errors::InvalidArgument(
"Input and output tensors should have matching sizes."));
// dtype check: input and output must share one data type
PADDLE_ENFORCE_EQ(
out_tensor.dtype(),
in_tensor.dtype(),
platform::errors::InvalidArgument(
"Input and output tensors should have the same data type."));
}
// Static check for collectives whose input and output hold the same number of
// elements: forwards to StaticCheckTensors with both size factors set to 1.
void StaticCheckTensorsSameShape(const phi::DenseTensor& out_tensor,
                                 const phi::DenseTensor& in_tensor,
                                 int rank,
                                 int world_size) {
  // Neither element count is scaled, so the sizes must be equal.
  constexpr int kOutSizeFactor = 1;
  constexpr int kInSizeFactor = 1;
  StaticCheckTensors(
      out_tensor, in_tensor, rank, world_size, kOutSizeFactor, kInSizeFactor);
}
// Static check for scatter-like collectives: forwards to StaticCheckTensors
// requiring out_tensor.numel() * world_size == in_tensor.numel().
void StaticCheckTensorsScatterLikeShape(const phi::DenseTensor& out_tensor,
                                        const phi::DenseTensor& in_tensor,
                                        int rank,
                                        int world_size) {
  // The output element count is scaled by the group size; the input is not.
  const int out_factor = world_size;
  constexpr int kInSizeFactor = 1;
  StaticCheckTensors(
      out_tensor, in_tensor, rank, world_size, out_factor, kInSizeFactor);
}
// Static check for gather-like collectives: forwards to StaticCheckTensors
// requiring out_tensor.numel() == in_tensor.numel() * world_size.
void StaticCheckTensorsGatherLikeShape(const phi::DenseTensor& out_tensor,
                                       const phi::DenseTensor& in_tensor,
                                       int rank,
                                       int world_size) {
  // The input element count is scaled by the group size; the output is not.
  constexpr int kOutSizeFactor = 1;
  const int in_factor = world_size;
  StaticCheckTensors(
      out_tensor, in_tensor, rank, world_size, kOutSizeFactor, in_factor);
}
} // namespace distributed
} // namespace paddle
......@@ -63,32 +63,5 @@ ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);
// static check for p2p: validates a single tensor's place and the rank range
void StaticCheckTensor(const phi::DenseTensor& tensor,
int rank,
int world_size);
// static check for collective: validates place, rank, shape and dtype of an
// input/output tensor pair. The shape check compares
// out_tensor.numel() * out_size_factor with in_tensor.numel() * in_size_factor.
void StaticCheckTensors(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size,
int out_size_factor,
int in_size_factor);
// Collective check with both size factors equal to 1.
void StaticCheckTensorsSameShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size);
// Collective check with out_size_factor = world_size and in_size_factor = 1.
void StaticCheckTensorsScatterLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size);
// Collective check with out_size_factor = 1 and in_size_factor = world_size.
void StaticCheckTensorsGatherLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size);
} // namespace distributed
} // namespace paddle
......@@ -16,6 +16,7 @@
#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/distributed/collective/NCCLTools.h"
#include "paddle/fluid/distributed/collective/static_check.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
......@@ -138,8 +139,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
// numel > 0 indicates the tensor needs to be sliced
const phi::DenseTensor& in_tensor_maybe_partial =
numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
StaticCheckTensorsGatherLikeShape(
*out_tensor, in_tensor_maybe_partial, rank_, size_);
CommStaticCheck::GatherLikeShape(*out_tensor,
in_tensor_maybe_partial,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclAllGather(
......@@ -162,7 +166,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
const AllreduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclAllReduce(
......@@ -214,12 +222,13 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
// NOTE: Since `all_to_all` needs other processes' participation, it cannot
// simply be covered by static checks. Factors are set to 0 here to skip the
// shape check. Its shape check will be done by dynamic checks in debug mode.
StaticCheckTensors(*out_tensor,
in_tensor,
rank_,
size_,
/*out_size_factor*/ 0,
/*in_size_factor*/ 0);
CommStaticCheck::CheckShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_,
/*out_size_factor*/ 0,
/*in_size_factor*/ 0);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
int64_t in_row_size = in_tensor.numel() / in_dim[0],
......@@ -287,7 +296,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
const BroadcastOptions& opts,
bool sync_op,
bool use_calc_stream) {
StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
int root = opts.source_rank + opts.source_root;
......@@ -312,7 +325,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ opts.root_rank,
/*cur_rank*/ rank_,
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclReduce(
......@@ -337,7 +354,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
const ReduceScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
StaticCheckTensorsScatterLikeShape(*out_tensor, in_tensor, rank_, size_);
CommStaticCheck::ScatterLikeShape(*out_tensor,
in_tensor,
/*dst_rank*/ rank_,
/*cur_rank*/ rank_,
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclReduceScatter(
......@@ -361,7 +382,11 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
const ScatterOptions& opts,
bool sync_op,
bool use_calc_stream) {
StaticCheckTensorsScatterLikeShape(*out_tensor, in_tensor, rank_, size_);
CommStaticCheck::ScatterLikeShape(*out_tensor,
in_tensor,
/*dst_rank*/ opts.root_rank,
/*cur_rank*/ rank_,
size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
int64_t numel = in_tensor.numel() / size_;
......@@ -418,7 +443,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
tensor = &partial_tensor;
}
StaticCheckTensor(*tensor, rank_, size_);
CommStaticCheck::SingleTensor(*tensor, rank_, size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclRecv(
......@@ -446,7 +471,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
const phi::DenseTensor& tensor_maybe_partial =
numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
StaticCheckTensor(tensor_maybe_partial, rank_, size_);
CommStaticCheck::SingleTensor(tensor_maybe_partial, rank_, size_);
return RunFnInNCCLEnv(
[&](ncclComm_t comm, gpuStream_t stream) {
NCCL_CHECK(platform::dynload::ncclSend(
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/distributed/collective/static_check.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
namespace paddle {
namespace distributed {
// Verifies that `rank` is a valid rank for a group of `world_size` processes,
// i.e. 0 <= rank < world_size. A violation is reported as an InvalidArgument
// error through the PADDLE_ENFORCE_* macros.
void CommStaticCheck::CheckRank(int rank, int world_size) {
// lower bound: rank must not be negative
PADDLE_ENFORCE_GE(rank,
0,
phi::errors::InvalidArgument(
"Rank should be greater than or equal to 0."));
// upper bound: rank must be strictly less than the group size
PADDLE_ENFORCE_LT(
rank,
world_size,
phi::errors::InvalidArgument("Rank is out of the process group."));
}
// Verifies that `tensor` is placed on a GPU. A violation is reported as an
// InvalidArgument error through PADDLE_ENFORCE_EQ.
void CommStaticCheck::CheckPlace(const phi::DenseTensor& tensor) {
  // Use phi::errors like every other check in this file, rather than the
  // platform::errors spelling.
  PADDLE_ENFORCE_EQ(
      platform::is_gpu_place(tensor.place()),
      true,
      phi::errors::InvalidArgument("Tensor should be in GPU place."));
}
// Verifies the placement of an input/output tensor pair: each tensor must be
// on a GPU place (checked via the single-tensor overload, output first), and
// both tensors must report the same place. A mismatch is reported as an
// InvalidArgument error through PADDLE_ENFORCE_EQ.
void CommStaticCheck::CheckPlace(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor) {
CheckPlace(out_tensor);
CheckPlace(in_tensor);
// both tensors must report an identical place
PADDLE_ENFORCE_EQ(
out_tensor.place(),
in_tensor.place(),
phi::errors::InvalidArgument(
"Input and output tensors should be on the same place."));
}
// Verifies that the input and output tensors have the same data type. A
// mismatch is reported as an InvalidArgument error through PADDLE_ENFORCE_EQ.
void CommStaticCheck::CheckDataType(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor) {
PADDLE_ENFORCE_EQ(
out_tensor.dtype(),
in_tensor.dtype(),
phi::errors::InvalidArgument(
"Input and output tensors should have the same data type."));
}
// Verifies that `tensor` is non-empty, i.e. tensor.numel() > 0. A violation is
// reported as an InvalidArgument error through PADDLE_ENFORCE_GT.
void CommStaticCheck::CheckShape(const phi::DenseTensor& tensor) {
PADDLE_ENFORCE_GT(
tensor.numel(),
0,
phi::errors::InvalidArgument("Size of tensor should be greater than 0."));
}
// Verifies the size relation between an input/output tensor pair. Both tensors
// must be non-empty, and their scaled element counts must agree:
//   out_tensor.numel() * out_size_factor == in_tensor.numel() * in_size_factor
// A violation is reported as an InvalidArgument error.
//
// Passing 0 for both factors reduces the comparison to 0 == 0, so only the
// non-empty checks remain effective.
void CommStaticCheck::CheckShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int out_size_factor,
int in_size_factor) {
// each tensor must contain at least one element
CheckShape(out_tensor);
CheckShape(in_tensor);
int64_t out_size = out_tensor.numel(), in_size = in_tensor.numel();
// compare the element counts after scaling by the caller-supplied factors
PADDLE_ENFORCE_EQ(
out_size * out_size_factor,
in_size * in_size_factor,
phi::errors::InvalidArgument(
"Input and output tensors should have matching sizes."));
}
// Full static check for a collective that takes an input/output tensor pair.
// Validates both ranks against `world_size`, then the placement and the data
// type of the pair. The size relation
//   out_tensor.numel() * out_size_factor == in_tensor.numel() * in_size_factor
// is enforced only when `dst_rank == cur_rank`; on any other rank each tensor
// is merely required to be non-empty.
void CommStaticCheck::CheckShape(const phi::DenseTensor& out_tensor,
                                 const phi::DenseTensor& in_tensor,
                                 int dst_rank,
                                 int cur_rank,
                                 int world_size,
                                 int out_size_factor,
                                 int in_size_factor) {
  CheckRank(dst_rank, world_size);
  CheckRank(cur_rank, world_size);
  CheckPlace(out_tensor, in_tensor);
  CheckDataType(out_tensor, in_tensor);

  const bool is_dst_rank = (dst_rank == cur_rank);
  if (!is_dst_rank) {
    // Not the destination rank: only require that both tensors are non-empty.
    CheckShape(out_tensor);
    CheckShape(in_tensor);
    return;
  }
  // Destination rank: the scaled element counts must also match.
  CheckShape(out_tensor, in_tensor, out_size_factor, in_size_factor);
}
// Static check for operations on a single tensor (the header marks this entry
// point as the one for p2p). Checks, in order, that `tensor` is on a GPU place
// and that `rank` lies in [0, world_size).
void CommStaticCheck::SingleTensor(const phi::DenseTensor& tensor,
int rank,
int world_size) {
CheckPlace(tensor);
CheckRank(rank, world_size);
}
// Static check for collectives whose input and output hold the same number of
// elements: forwards to the rank-aware CheckShape with both size factors set
// to 1.
void CommStaticCheck::SameShape(const phi::DenseTensor& out_tensor,
                                const phi::DenseTensor& in_tensor,
                                int dst_rank,
                                int cur_rank,
                                int world_size) {
  // Neither element count is scaled, so the sizes must be equal.
  constexpr int kOutSizeFactor = 1;
  constexpr int kInSizeFactor = 1;
  CheckShape(out_tensor,
             in_tensor,
             dst_rank,
             cur_rank,
             world_size,
             kOutSizeFactor,
             kInSizeFactor);
}
// Static check for scatter-like collectives: forwards to the rank-aware
// CheckShape with out_size_factor = world_size and in_size_factor = 1, i.e.
// out_tensor.numel() * world_size is compared with in_tensor.numel().
void CommStaticCheck::ScatterLikeShape(const phi::DenseTensor& out_tensor,
                                       const phi::DenseTensor& in_tensor,
                                       int dst_rank,
                                       int cur_rank,
                                       int world_size) {
  // The output element count is scaled by the group size; the input is not.
  const int out_factor = world_size;
  constexpr int kInSizeFactor = 1;
  CheckShape(out_tensor,
             in_tensor,
             dst_rank,
             cur_rank,
             world_size,
             out_factor,
             kInSizeFactor);
}
// Static check for gather-like collectives: forwards to the rank-aware
// CheckShape with out_size_factor = 1 and in_size_factor = world_size, i.e.
// out_tensor.numel() is compared with in_tensor.numel() * world_size.
void CommStaticCheck::GatherLikeShape(const phi::DenseTensor& out_tensor,
                                      const phi::DenseTensor& in_tensor,
                                      int dst_rank,
                                      int cur_rank,
                                      int world_size) {
  // The input element count is scaled by the group size; the output is not.
  constexpr int kOutSizeFactor = 1;
  const int in_factor = world_size;
  CheckShape(out_tensor,
             in_tensor,
             dst_rank,
             cur_rank,
             world_size,
             kOutSizeFactor,
             in_factor);
}
} // namespace distributed
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// forward declaration to reduce deps
namespace phi {
class DenseTensor;
}
namespace paddle {
namespace distributed {
// Collection of static (argument-only) checks run before a communication
// operation is issued. All members are static; the struct serves purely as a
// named scope for the helpers.
struct CommStaticCheck {
// Checks that `rank` is within [0, world_size).
static void CheckRank(int rank, int world_size);
// Checks the placement of a single tensor.
static void CheckPlace(const phi::DenseTensor& tensor);
// Checks the placement of an input/output tensor pair.
static void CheckPlace(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor);
// Checks that the input and output tensors share a data type.
static void CheckDataType(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor);
// Checks the size of a single tensor.
static void CheckShape(const phi::DenseTensor& tensor);
// Checks the size relation of a tensor pair, with each tensor's element count
// scaled by its size factor.
static void CheckShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int out_size_factor,
int in_size_factor);
// Combined check for a tensor pair that additionally takes the destination
// rank, the current rank and the group size.
static void CheckShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size,
int out_size_factor,
int in_size_factor);
// for p2p
static void SingleTensor(const phi::DenseTensor& tensor,
int rank,
int world_size);
// for collective
static void SameShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size);
static void ScatterLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size);
static void GatherLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int dst_rank,
int cur_rank,
int world_size);
};
} // namespace distributed
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册