Unverified commit d828ca46, authored by Wen Sun, committed by GitHub

Add static checks for collective communication on NCCL (#48256)

* feat: static check
Parent 88cac16b
...@@ -44,5 +44,109 @@ std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID) {
  return oss.str();
}
void StaticCheckTensor(const phi::DenseTensor& tensor,
int rank,
int world_size) {
// place check
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(tensor.place()),
true,
platform::errors::InvalidArgument("Tensor should be in GPU place."));
// rank check
PADDLE_ENFORCE_GE(rank,
0,
platform::errors::InvalidArgument(
"Rank should be greater than or equal to 0."));
PADDLE_ENFORCE_LT(
rank,
world_size,
platform::errors::InvalidArgument("Rank is out of the process group."));
}
// static check for collective
void StaticCheckTensors(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size,
int out_size_factor,
int in_size_factor) {
// place check
PADDLE_ENFORCE_EQ(platform::is_gpu_place(out_tensor.place()),
true,
platform::errors::InvalidArgument(
"Output tensor should be in GPU place."));
PADDLE_ENFORCE_EQ(platform::is_gpu_place(in_tensor.place()),
true,
platform::errors::InvalidArgument(
"Input tensor should be in GPU place."));
// rank check
PADDLE_ENFORCE_GE(rank,
0,
platform::errors::InvalidArgument(
"Rank should be greater than or equal to 0."));
PADDLE_ENFORCE_LT(
rank,
world_size,
platform::errors::InvalidArgument("Rank is out of the process group."));
// shape check
int64_t out_size = out_tensor.numel();
PADDLE_ENFORCE_GT(out_size,
0,
platform::errors::InvalidArgument(
"Size of output tensor should be greater than 0."));
int64_t in_size = in_tensor.numel();
PADDLE_ENFORCE_GT(in_size,
0,
platform::errors::InvalidArgument(
"Size of input tensor should be greater than 0."));
PADDLE_ENFORCE_EQ(
out_size * out_size_factor,
in_size * in_size_factor,
platform::errors::InvalidArgument(
"Input and output tensors should have matching sizes."));
// dtype check
PADDLE_ENFORCE_EQ(
out_tensor.dtype(),
in_tensor.dtype(),
platform::errors::InvalidArgument(
"Input and output tensors should have the same data type."));
}
void StaticCheckTensorsSameShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size) {
StaticCheckTensors(out_tensor,
in_tensor,
rank,
world_size,
/*out_size_factor*/ 1,
/*in_size_factor*/ 1);
}
void StaticCheckTensorsScatterLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size) {
StaticCheckTensors(out_tensor,
in_tensor,
rank,
world_size,
/*out_size_factor*/ world_size,
/*in_size_factor*/ 1);
}
void StaticCheckTensorsGatherLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size) {
StaticCheckTensors(out_tensor,
in_tensor,
rank,
world_size,
/*out_size_factor*/ 1,
/*in_size_factor*/ world_size);
}
} // namespace distributed
} // namespace paddle
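The three wrappers above only vary the size factors passed to StaticCheckTensors; the invariant being enforced is out_numel * out_size_factor == in_numel * in_size_factor. The standalone snippet below is a minimal sketch (not part of Paddle; all names and values are illustrative) of how those factor choices map onto same-shape, scatter-like and gather-like collectives, and why factors of 0 skip the shape check entirely.

// Minimal standalone sketch (not Paddle code) of the size-factor invariant
// enforced by StaticCheckTensors:
//   out_numel * out_size_factor == in_numel * in_size_factor
#include <cassert>
#include <cstdint>
#include <iostream>

// Returns true when the tensor sizes are consistent for the given factors.
bool SizesMatch(int64_t out_numel, int64_t in_numel,
                int out_size_factor, int in_size_factor) {
  return out_numel * out_size_factor == in_numel * in_size_factor;
}

int main() {
  const int world_size = 4;
  const int64_t in_numel = 8;  // hypothetical per-rank input size

  // Same shape (all_reduce, broadcast, reduce): factors 1 / 1.
  assert(SizesMatch(/*out_numel=*/8, in_numel, 1, 1));

  // Scatter-like (reduce_scatter, scatter): output is 1/world_size of the
  // input, so factors are world_size / 1.
  assert(SizesMatch(/*out_numel=*/2, in_numel, world_size, 1));

  // Gather-like (all_gather): output is world_size times the input, so
  // factors are 1 / world_size.
  assert(SizesMatch(/*out_numel=*/32, in_numel, 1, world_size));

  // all_to_all uses factors 0 / 0, so any pair of sizes passes the check.
  assert(SizesMatch(/*out_numel=*/5, /*in_numel=*/7, 0, 0));

  std::cout << "all size-factor checks hold\n";
  return 0;
}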
...@@ -63,5 +63,32 @@ ncclRedOp_t ToNCCLRedType(ReduceOp reduction);
std::string SerializeNCCLUniqueId(const ncclUniqueId& ncclID);
// static check for p2p
void StaticCheckTensor(const phi::DenseTensor& tensor,
int rank,
int world_size);
// static check for collective
void StaticCheckTensors(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size,
int out_size_factor,
int in_size_factor);
void StaticCheckTensorsSameShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size);
void StaticCheckTensorsScatterLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size);
void StaticCheckTensorsGatherLikeShape(const phi::DenseTensor& out_tensor,
const phi::DenseTensor& in_tensor,
int rank,
int world_size);
} // namespace distributed
} // namespace paddle
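The declarations above are consumed by the ProcessGroupNCCL methods later in this commit: all_reduce, broadcast and reduce use the same-shape check, reduce_scatter and scatter use the scatter-like check, all_gather uses the gather-like check, send/recv use the single-tensor check, and all_to_all calls StaticCheckTensors with both factors set to 0. The snippet below is only a hypothetical illustration of that calling pattern (HypotheticalAllGatherCheck is not a real Paddle function); it assumes the NCCLTools.h header added by this commit.

// Hypothetical caller (not part of Paddle) showing where a static check sits
// relative to the NCCL launch; assumes the header added by this commit.
#include "paddle/fluid/distributed/collective/NCCLTools.h"

namespace paddle {
namespace distributed {

// Which helper each collective in this commit uses:
//   all_reduce / broadcast / reduce  -> StaticCheckTensorsSameShape
//   reduce_scatter / scatter         -> StaticCheckTensorsScatterLikeShape
//   all_gather                       -> StaticCheckTensorsGatherLikeShape
//   send / recv (p2p)                -> StaticCheckTensor
//   all_to_all                       -> StaticCheckTensors with factors 0 / 0
void HypotheticalAllGatherCheck(const phi::DenseTensor& out_tensor,
                                const phi::DenseTensor& in_tensor,
                                int rank, int world_size) {
  // Raises InvalidArgument unless both tensors are on GPU, the rank is valid,
  // dtypes match, and out numel == world_size * in numel.
  StaticCheckTensorsGatherLikeShape(out_tensor, in_tensor, rank, world_size);
  // ... the real implementation would launch ncclAllGather here ...
}

}  // namespace distributed
}  // namespace paddle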
...@@ -15,6 +15,7 @@
#include "paddle/fluid/distributed/collective/ProcessGroupNCCL.h"
#include "paddle/fluid/distributed/collective/Common.h"
#include "paddle/fluid/distributed/collective/NCCLTools.h"
#include "paddle/fluid/distributed/collective/utils.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/fluid/platform/place.h"
...@@ -137,6 +138,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
  // numel > 0 indicates the tensor need to be sliced
  const phi::DenseTensor& in_tensor_maybe_partial =
      numel > 0 ? GetPartialTensor(in_tensor, offset, numel) : in_tensor;
StaticCheckTensorsGatherLikeShape(
*out_tensor, in_tensor_maybe_partial, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclAllGather(
...@@ -159,6 +162,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllReduce(
    const AllreduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclAllReduce(
...@@ -207,6 +211,15 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
  CheckSizeOnEachRank(out_dim, out_size_each_rank, size_);
  CheckSizeOnEachRank(in_dim, in_size_each_rank, size_);
  // NOTE: Since `all_to_all` needs other processes' participation, it cannot
  // simply be covered by static checks. Both size factors are set to 0 here,
  // so the size equality (out_size * 0 == in_size * 0) holds trivially and the
  // shape check is skipped; place, rank and dtype are still verified. The
  // shape check will be done by dynamic checks in debug mode.
StaticCheckTensors(*out_tensor,
in_tensor,
rank_,
size_,
/*out_size_factor*/ 0,
/*in_size_factor*/ 0);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        int64_t in_row_size = in_tensor.numel() / in_dim[0],
...@@ -274,6 +287,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Broadcast(
    const BroadcastOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        int root = opts.source_rank + opts.source_root;
...@@ -298,6 +312,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
    const ReduceOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
StaticCheckTensorsSameShape(*out_tensor, in_tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclReduce(
...@@ -322,6 +337,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::ReduceScatter(
    const ReduceScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
StaticCheckTensorsScatterLikeShape(*out_tensor, in_tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclReduceScatter(
...@@ -345,6 +361,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
    const ScatterOptions& opts,
    bool sync_op,
    bool use_calc_stream) {
StaticCheckTensorsScatterLikeShape(*out_tensor, in_tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        int64_t numel = in_tensor.numel() / size_;
...@@ -400,6 +417,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
    partial_tensor = GetPartialTensor(*tensor, offset, numel);
    tensor = &partial_tensor;
  }
StaticCheckTensor(*tensor, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclRecv(
...@@ -426,6 +445,8 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Send(
  // numel > 0 indicates the tensor need to be sliced
  const phi::DenseTensor& tensor_maybe_partial =
      numel > 0 ? GetPartialTensor(tensor, offset, numel) : tensor;
StaticCheckTensor(tensor_maybe_partial, rank_, size_);
  return RunFnInNCCLEnv(
      [&](ncclComm_t comm, gpuStream_t stream) {
        NCCL_CHECK(platform::dynload::ncclSend(
...
...@@ -210,6 +210,8 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
  void CreateNCCLEnvCache(const Place& place, const std::string& place_key);
void SyncCalcStream(const Place& place);
  std::shared_ptr<ProcessGroup::Task> RunFnInNCCLEnv(
      std::function<void(ncclComm_t, gpuStream_t)> fn,
      const phi::DenseTensor& tensor,
...@@ -217,8 +219,6 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
      bool sync_op,
      bool use_calc_stream);
void SyncCalcStream(const Place& place);
  // TODO(sunyilun): methods below will be removed later
  std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
      std::vector<Place> places,
...@@ -245,6 +245,7 @@ class ProcessGroupNCCL final : public ProcessGroupStream {
 private:
  std::shared_ptr<Store> store_;
  std::unordered_map<std::string, platform::DeviceEvent>
      place_to_calc_event_;  // event on calc stream
  std::unordered_map<std::string, phi::GPUContext*> place_to_calc_ctx_;
...
...@@ -19,7 +19,7 @@
namespace paddle {
namespace distributed {
inline phi::DenseTensor GetPartialTensor(const phi::DenseTensor &tensor,
inline phi::DenseTensor GetPartialTensor(const phi::DenseTensor& tensor,
                                         int64_t offset,
                                         int64_t numel) {
  phi::DenseTensor tensor_flattened;
...
...@@ -17,32 +17,11 @@ import paddle.fluid.framework as framework
from paddle.distributed import collective
def _check_tensor_shape(tensor, shape, nranks=1):
expect_shape = list(shape)
expect_shape[0] *= nranks
if list(tensor.shape) != expect_shape:
raise RuntimeError("The tensor for all_gather is not correctly-sized.")
def _check_tensor_list_shape(tensor_list, shape, nranks=1):
if len(tensor_list) != nranks:
raise RuntimeError(
"The tensor_list for all_gather is not correctly-sized."
)
for tensor in tensor_list:
if tensor.shape != shape:
raise RuntimeError(
"The tensor_list for all_gather is not correctly-sized."
)
def _all_gather_into_tensor_in_dygraph(
    out_tensor, in_tensor, group, sync_op, use_calc_stream
):
    group = collective._get_default_group() if group is None else group
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
    if use_calc_stream:
        return group.process_group.all_gather_into_tensor_on_calc_stream(
            out_tensor,
...@@ -65,8 +44,6 @@ def _all_gather_in_dygraph(
    if len(tensor_list) == 0:
        tensor_list += [paddle.empty_like(tensor) for _ in range(group.nranks)]
else:
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
    if use_calc_stream:
        return group.process_group.all_gather_on_calc_stream(
...
...@@ -23,29 +23,9 @@ from paddle.distributed.communication.group import (
)
def _check_tensor_shape(tensor, shape, nranks=1):
if tensor.shape != shape:
raise RuntimeError('The tensor for alltoall is not correctly-sized.')
def _check_tensor_list_shape(tensor_list, shape, nranks=1):
if len(tensor_list) != nranks:
raise RuntimeError(
'The tensor_list for alltoall is not correctly-sized.'
)
for tensor in tensor_list:
if tensor.shape != shape:
raise RuntimeError(
'The tensor_list for alltoall is not correctly-sized.'
)
def _all_to_all_tensor_in_dygraph(
    out_tensor, in_tensor, group, sync_op, use_calc_stream
):
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
    if use_calc_stream:
        return group.process_group.all_to_all_tensor_on_calc_stream(
            in_tensor, out_tensor
...@@ -68,10 +48,6 @@ def _all_to_all_in_dygraph(
        out_tensor_list += [
            paddle.empty_like(tensor) for tensor in in_tensor_list
        ]
else:
_check_tensor_list_shape(
out_tensor_list, in_tensor_list[0].shape, group.nranks
)
    if use_calc_stream:
        return group.process_group.all_to_all_on_calc_stream(
...
...@@ -21,27 +21,6 @@ from paddle.distributed.communication.group import (
from paddle.distributed.communication.reduce import _get_reduce_op, ReduceOp
def _check_tensor_shape(tensor, shape, nranks=1):
expect_shape = list(shape)
expect_shape[0] //= nranks
if list(tensor.shape) != expect_shape:
raise RuntimeError(
"The in_tensor for reduce_scatter is not correctly-sized."
)
def _check_tensor_list_shape(tensor_list, shape, nranks=1):
if len(tensor_list) != nranks:
raise RuntimeError(
"The tensor_list for reduce_scatter is not correctly-sized."
)
for tensor in tensor_list:
if tensor.shape != shape:
raise RuntimeError(
"The tensor_list for reduce_scatter is not correctly-sized."
)
def _reduce_scatter_tensor_in_dygraph(
    out_tensor,
    in_tensor,
...@@ -53,8 +32,6 @@ def _reduce_scatter_tensor_in_dygraph(
):
    op_type = _get_reduce_op(op, caller)
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
    if use_calc_stream:
        return group.process_group.reduce_scatter_tensor_on_calc_stream(
            out_tensor, in_tensor, op_type
...@@ -74,8 +51,6 @@ def _reduce_scatter_in_dygraph(
):
    op_type = _get_reduce_op(op, "reduce_scatter")
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
    if use_calc_stream:
        return group.process_group.reduce_scatter_on_calc_stream(
            tensor, tensor_list, op_type
...
...@@ -25,31 +25,10 @@ from paddle.distributed.communication.group import (
)
def _check_tensor_shape(tensor, shape, nranks=1):
expect_shape = list(shape)
expect_shape[0] //= nranks
if list(tensor.shape) != expect_shape:
raise RuntimeError("The in_tensor for scatter is not correctly-sized.")
def _check_tensor_list_shape(tensor_list, shape, nranks=1):
if len(tensor_list) != nranks:
raise RuntimeError(
"The tensor_list for scatter is not correctly-sized."
)
for tensor in tensor_list:
if tensor.shape != shape:
raise RuntimeError(
"The tensor_list for scatter is not correctly-sized."
)
def _scatter_tensor_in_dygraph(
    out_tensor, in_tensor, src_rank_in_group, group, sync_op, use_calc_stream
):
    nranks = group.nranks
if group.rank == src_rank_in_group:
_check_tensor_shape(out_tensor, in_tensor.shape, nranks)
    if use_calc_stream:
        return group.process_group.scatter_tensor_on_calc_stream(
...@@ -74,7 +53,6 @@ def _scatter_in_dygraph(
        raise RuntimeError(
            "The tensor_list should not be empty on src rank."
        )
_check_tensor_list_shape(tensor_list, tensor.shape, nranks)
    else:
        tensor_list = [tensor for _ in range(nranks)]
...