Unverified Commit 77d24854 authored by zhenhailiu, committed by GitHub

gather with doc (#52105)

* gather with doc

* resolve comment

* polish

* polish

* code style

* polish doc

* add_test

* polish

* polish

* add test check

* add test check

* polish

* polish

* polish

* polish

* fix_time_out

* polish

* fix timeout

* fix_timeout

* polish

* polish

* polish

* polish

* polish
Parent 20ee0d7f
......@@ -332,6 +332,30 @@ class ProcessGroup {
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Gather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support gather "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Gather(
std::vector<phi::DenseTensor>* gather_tensors_ptr,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(
phi::errors::Unimplemented("ProcessGroup%s does not support gather "
"with sync_op and use_calc_stream flag.",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
bool sync_op,
......
......@@ -475,6 +475,71 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Gather(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) {
std::vector<phi::DenseTensor> partial_tensors;
if (rank_ == opts.root_rank) {
partial_tensors.reserve(size_);
size_t offset = 0;
size_t numel = out_tensor->numel() / size_;
for (auto i = 0; i < size_; i++) {
partial_tensors.push_back(GetPartialTensor(*out_tensor, offset, numel));
offset += numel;
}
}
return Gather(&partial_tensors, in_tensor, opts, sync_op, use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Gather(
std::vector<phi::DenseTensor>* gather_tensors_ptr,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) {
auto& gather_tensors = *gather_tensors_ptr;
PADDLE_ENFORCE_GT(size_,
opts.root_rank,
phi::errors::InvalidArgument(
"root world size [%d] is less than root rank [%d]",
size_,
opts.root_rank));
auto gather_func = [&](ncclComm_t comm, gpuStream_t stream) {
// shape check
if (FLAGS_enable_nccl_dynamic_check) {
phi::distributed::NCCLDynamicCheck::CheckGatherShape(
in_tensor, gather_tensors, opts.root_rank, rank_, size_, comm);
}
GroupStart();
// the root receives from all ranks
if (rank_ == opts.root_rank) {
for (auto i = 0; i < size_; i++) {
auto& gather_tensor = gather_tensors[i];
NCCL_CHECK(
phi::dynload::ncclRecv(gather_tensor.data(),
gather_tensor.numel(),
phi::ToNCCLDataType(gather_tensor.dtype()),
i,
comm,
stream));
}
}
// every rank sends its tensor to the root
NCCL_CHECK(phi::dynload::ncclSend(in_tensor.data(),
in_tensor.numel(),
phi::ToNCCLDataType(in_tensor.dtype()),
opts.root_rank,
comm,
stream));
GroupEnd();
};
return RunFnInNCCLEnv(
gather_func, in_tensor, CommType::GATHER, sync_op, use_calc_stream);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Recv(
phi::DenseTensor* tensor,
int src_rank,
......
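The NCCL implementation above decomposes gather into point-to-point transfers batched inside a single `GroupStart()`/`GroupEnd()` section: the single-buffer overload first slices `out_tensor` into `size_` equal views via `GetPartialTensor` and forwards to the vector overload, where every rank posts a send to the root and the root additionally posts one receive per rank (its own contribution included, which NCCL resolves within the group call). Below is a minimal sketch of the same decomposition using Paddle's public `isend`/`irecv` API, for illustration only; it sidesteps self-communication with a local copy instead of relying on NCCL group semantics, and assumes a dygraph environment already initialized with `dist.init_parallel_env()`:

```python
import paddle
import paddle.distributed as dist

def gather_via_p2p(tensor, root=0):
    rank, nranks = dist.get_rank(), dist.get_world_size()
    tasks, gathered = [], []
    if rank == root:
        gathered = [paddle.empty_like(tensor) for _ in range(nranks)]
        paddle.assign(tensor, gathered[rank])  # root's own slot: local copy
        # one receive per remote rank
        tasks = [dist.irecv(gathered[s], src=s)
                 for s in range(nranks) if s != root]
    else:
        tasks = [dist.isend(tensor, dst=root)]  # everyone else sends to root
    for t in tasks:
        t.wait()
    return gathered  # non-empty only on the root
```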
......@@ -136,6 +136,19 @@ class ProcessGroupNCCL final : public ProcessGroupWithStream {
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Gather(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Gather(
std::vector<phi::DenseTensor>* gather_tensors_ptr,
const phi::DenseTensor& in_tensor,
const GatherOptions& opts,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> Recv(phi::DenseTensor* tensor,
int src_rank,
int64_t offset,
......
......@@ -48,6 +48,10 @@ struct ScatterOptions {
int root_rank = 0;
};
struct GatherOptions {
int root_rank = 0;
};
struct ReduceScatterOptions {
ReduceOp reduce_op = ReduceOp::SUM;
};
......
......@@ -113,6 +113,10 @@ void BindDistributed(py::module *m) {
.def_readwrite("reduce_op", &distributed::ReduceOptions::reduce_op)
.def_readwrite("source_root", &distributed::ReduceOptions::root_rank);
py::class_<distributed::GatherOptions>(*m, "GatherOptions")
.def(py::init<>())
.def_readwrite("root_rank", &distributed::GatherOptions::root_rank);
auto ProcessGroup =
py::class_<distributed::ProcessGroup,
std::shared_ptr<distributed::ProcessGroup>>(*m, "ProcessGroup")
......@@ -521,7 +525,44 @@ void BindDistributed(py::module *m) {
py::arg("src"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"gather",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_gather_tensor_list,
int dst,
bool sync_op,
bool use_calc_stream) {
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_gather_tensor_list.ptr(), 0);
Tensor stack_out_tensor = paddle::stack(out_tensor_list, 0);
auto p_out_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
stack_out_tensor.impl());
auto *out_dense = p_out_tensor.get();
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto p_in_tensor = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto in_dense = *p_in_tensor;
auto *dev_ctx =
self.GetDeviceContext(in_tensor.place(), use_calc_stream);
distributed::GatherOptions gather_ops{dst};
auto task = self.Gather(
out_dense, in_dense, gather_ops, sync_op, use_calc_stream);
SplitTensor(*dev_ctx, *out_dense, &out_tensor_list);
if (!use_calc_stream) {
// the calculation stream will wait for the comm stream
task->UpdateWaitChain(*dev_ctx);
}
return task;
},
py::arg("in"),
py::arg("out"),
py::arg("dst"),
py::arg("sync_op"),
py::arg("use_calc_stream"),
py::call_guard<py::gil_scoped_release>())
.def(
"barrier",
[](distributed::ProcessGroup &self, int8_t device_id) {
......
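The binding above adapts the list-of-tensors Python signature to the single-buffer C++ overload: it stacks `gather_list` into one contiguous dense tensor, lets `Gather` fill equal-sized slices of it, and `SplitTensor` then copies the slices back into the caller's list. A hedged numpy illustration of that buffer layout only (no communication involved; all names are illustrative):

```python
import numpy as np

nranks, numel = 2, 3
flat = np.empty(nranks * numel, dtype=np.int64)  # the "stacked" buffer
# ... the backend writes rank i's contribution at offset i * numel ...
flat[:] = np.arange(nranks * numel)              # stand-in for gathered data
out_list = [flat[i * numel:(i + 1) * numel].copy()  # SplitTensor, conceptually
            for i in range(nranks)]
```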
......@@ -64,7 +64,7 @@ void NCCLDynamicCheck::CheckDataType(const phi::DenseTensor& tensor,
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclBroadcast(dtype_device,
dtype_device,
kSize,
1,
ncclInt64,
root_rank,
comm,
......@@ -106,7 +106,7 @@ void NCCLDynamicCheck::CheckShape(const phi::DenseTensor& tensor,
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclBroadcast(shape_device,
shape_device,
kSize,
1,
ncclInt64,
root_rank,
comm,
......@@ -141,10 +141,9 @@ void NCCLDynamicCheck::CheckShape(const phi::DenseTensor& out_tensor,
PADDLE_ENFORCE_GPU_SUCCESS(gpuMalloc(&in_shape_device, kSize));
PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy(
in_shape_device, &in_shape_host, kSize, gpuMemcpyHostToDevice));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclReduce(in_shape_device,
in_shape_device,
kSize,
1,
ncclInt64,
ncclSum,
rank,
......@@ -159,5 +158,42 @@ void NCCLDynamicCheck::CheckShape(const phi::DenseTensor& out_tensor,
PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(in_shape_device));
}
}
void NCCLDynamicCheck::CheckGatherShape(
const phi::DenseTensor& in_tensor,
const std::vector<phi::DenseTensor>& out_tensors,
int root_rank,
int cur_rank,
int world_size,
ncclComm_t comm) {
std::vector<int64_t> shapes(world_size, 0);
shapes[cur_rank] = in_tensor.numel();
int64_t* in_shape_device;
PADDLE_ENFORCE_GPU_SUCCESS(
gpuMalloc(&in_shape_device, world_size * sizeof(int64_t)));
PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy(in_shape_device,
shapes.data(),
world_size * sizeof(int64_t),
gpuMemcpyHostToDevice));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclAllReduce(in_shape_device,
in_shape_device,
world_size,
ncclInt64,
ncclSum,
comm,
kDefaultStream));
PADDLE_ENFORCE_GPU_SUCCESS(gpuMemcpy(shapes.data(),
in_shape_device,
world_size * sizeof(int64_t),
gpuMemcpyDeviceToHost));
PADDLE_ENFORCE_GPU_SUCCESS(gpuFree(in_shape_device));
if (cur_rank == root_rank) {
for (int i = 0; i < world_size; i++) {
CheckShape(out_tensors[i], shapes[i]);
}
}
}
} // namespace distributed
} // namespace phi
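`CheckGatherShape` above validates shapes with a simple protocol: each rank writes its input numel at its own index of a zero-initialized vector, a sum all-reduce makes every rank's numel visible everywhere, and the root then compares each entry against the receive buffer it has prepared for that rank. A minimal sketch of the same protocol through Paddle's public API (illustrative only; the real check runs on raw device buffers with `ncclAllReduce`):

```python
import paddle
import paddle.distributed as dist

def check_gather_shape(in_tensor, out_tensors, root=0):
    rank, nranks = dist.get_rank(), dist.get_world_size()
    numels = paddle.zeros([nranks], dtype="int64")
    numels[rank] = int(in_tensor.numel())  # contribute own numel at own index
    dist.all_reduce(numels)                # sum: entry i holds rank i's numel
    if rank == root:
        for i, expected in enumerate(numels.numpy().tolist()):
            assert int(out_tensors[i].numel()) == expected, (
                f"rank {i} sends {expected} elements but the root's buffer "
                f"holds {int(out_tensors[i].numel())}"
            )
```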
......@@ -52,6 +52,14 @@ struct NCCLDynamicCheck {
int world_size,
ncclComm_t comm);
// can be used to check gather and all-gather
static void CheckGatherShape(const phi::DenseTensor& in_tensor,
const std::vector<phi::DenseTensor>& out_tensors,
int root_rank,
int cur_rank,
int world_size,
ncclComm_t comm);
private:
// `0` represents default stream for both cuda & hip
static constexpr gpuStream_t kDefaultStream = 0;
......
......@@ -46,6 +46,7 @@ from .communication import (
reduce,
send,
scatter,
gather,
scatter_object_list,
isend,
recv,
......@@ -82,6 +83,7 @@ __all__ = [ # noqa
"spawn",
"launch",
"scatter",
"gather",
"scatter_object_list",
"broadcast",
"broadcast_object_list",
......
......@@ -18,6 +18,7 @@ from .reduce import reduce, ReduceOp
from .send import send, isend
from .recv import recv, irecv
from .scatter import scatter, scatter_object_list
from .gather import gather
from .batch_isend_irecv import batch_isend_irecv, P2POp
from .reduce_scatter import reduce_scatter
from .all_to_all import alltoall, alltoall_single
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle import framework
from paddle.distributed.communication import stream
def gather(tensor, gather_list=None, dst=0, group=None, sync_op=True):
"""
Gather tensors from all participants.
Args:
tensor (Tensor): The input Tensor. Its data type
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
gather_list (list): A list of Tensors to hold the gathered tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None.
dst (int): The destination rank id. Default value is 0.
group (Group, optional): The group instance returned by new_group, or None for the global default group.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
Returns:
    Async work handle, which can be waited on, if sync_op is set to False.
    None, if sync_op is True.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
gather_list = []
if dist.get_rank() == 0:
data = paddle.to_tensor([1, 2, 3])
dist.gather(data, gather_list, dst=0)
else:
data = paddle.to_tensor([4, 5, 6])
dist.gather(data, gather_list, dst=0)
print(gather_list)
# [[1, 2, 3], [4, 5, 6]] (2 GPUs, out for rank 0)
# [] (2 GPUs, out for rank 1)
"""
assert (
framework.in_dygraph_mode()
), "gather doesn't support static graph mode yet."
return stream.gather(tensor, gather_list, dst, group, sync_op)
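Beyond the docstring example, the `group` argument restricts the collective to a sub-group. A hedged usage sketch (assumes two trainers; the rank list passed to `new_group` is illustrative):

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
group = dist.new_group([0, 1])                 # sub-group; ranks illustrative
data = paddle.to_tensor([dist.get_rank()] * 3)
# Only the dst rank needs a list; passing None elsewhere avoids the warning.
gather_list = [] if dist.get_rank() == 0 else None
dist.gather(data, gather_list, dst=0, group=group)
if dist.get_rank() == 0:
    print([t.numpy() for t in gather_list])    # [[0, 0, 0], [1, 1, 1]]
```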
......@@ -21,6 +21,7 @@ from .reduce_scatter import reduce_scatter
from .recv import recv
from .scatter import scatter
from .send import send
from .gather import gather
__all__ = [
"all_gather",
......@@ -33,4 +34,5 @@ __all__ = [
"recv",
"scatter",
"send",
"gather",
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
import paddle
import paddle.distributed as dist
from paddle import framework
from paddle.distributed.communication.group import (
_get_global_group,
_get_or_throw_group_rank,
_warn_cur_rank_not_in_group,
)
def _gather_in_dygraph(
tensor, gather_list, dst_rank_in_group, group, sync_op, use_calc_stream
):
nranks = group.nranks
if group.rank == dst_rank_in_group:
if len(gather_list) == 0:
gather_list += [paddle.empty_like(tensor) for _ in range(nranks)]
else:
gather_list = [tensor for _ in range(nranks)]
    assert (
        len(gather_list) == nranks
    ), "gather_list length {} and nranks {} are not equal".format(
        len(gather_list), nranks
    )
task = group.process_group.gather(
tensor, gather_list, dst_rank_in_group, sync_op, use_calc_stream
)
if sync_op:
task.wait()
return task
def gather(
tensor,
gather_list=None,
dst=0,
group=None,
sync_op=True,
use_calc_stream=False,
):
"""
Gather tensors from all participants.
Args:
tensor (Tensor): The input Tensor. Its data type
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16.
gather_list (list): A list of Tensors to hold the gathered tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32, int64, int8, uint8, bool or bfloat16. Default value is None.
dst (int): The destination rank id. Default value is 0.
group (Group, optional): The group instance returned by new_group, or None for the global default group.
sync_op (bool, optional): Whether this op is a sync op. The default value is True.
use_calc_stream (bool, optional): Indicate whether the communication is done on the calculation stream. If none is given, use false as default. This
    option is designed for high performance demands; be careful to turn it on unless you clearly understand its meaning.
Returns:
Async work handle, which can be waited on, if sync_op is set to False.
None, if sync_op is True.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
gather_list = []
if dist.get_rank() == 0:
data = paddle.to_tensor([1, 2, 3])
dist.stream.gather(data, gather_list, dst=0)
else:
data = paddle.to_tensor([4, 5, 6])
dist.stream.gather(data, gather_list, dst=0)
print(gather_list)
# [[1, 2, 3], [4, 5, 6]] (2 GPUs, out for rank 0)
# [] (2 GPUs, out for rank 1)
"""
assert (
framework.in_dygraph_mode()
), "gather doesn't support static graph mode yet."
if _warn_cur_rank_not_in_group(group):
return
if not sync_op and use_calc_stream:
raise RuntimeError(
"use_calc_stream can only be true in sync op behavior."
)
# NOTE(liuzhenhai): Only the dst rank needs to specify the gather_list argument.
# A gather_list passed in by any other rank is meaningless and will be ignored with a warning.
if dst != dist.get_rank():
if gather_list is not None:
warnings.warn(
"Specific `gather_list` is meaningless for rank which is not dst."
)
gather_list = []
else:
        assert (
            gather_list is not None
        ), "gather_list must not be None for the dst rank"
group = _get_global_group() if group is None else group
dst_rank_in_group = _get_or_throw_group_rank(dst, group)
return _gather_in_dygraph(
tensor, gather_list, dst_rank_in_group, group, sync_op, use_calc_stream
)
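With `sync_op=False` the call returns immediately and the caller must wait on the returned task before reading `gather_list`; `use_calc_stream=True` is only legal together with `sync_op=True`, as enforced above. A minimal async usage sketch (assumes an initialized dygraph environment):

```python
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
data = paddle.to_tensor([dist.get_rank()])
gather_list = [] if dist.get_rank() == 0 else None  # list only on dst
task = dist.stream.gather(data, gather_list, dst=0, sync_op=False)
task.wait()  # block until the gather on the comm stream completes
if dist.get_rank() == 0:
    print([t.numpy() for t in gather_list])
```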
......@@ -71,7 +71,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_allreduce_api MODULES test_collective_allreduce_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_allreduce_api
PROPERTIES TIMEOUT "250" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......@@ -195,7 +195,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_reduce_api MODULES test_collective_reduce_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_reduce_api
PROPERTIES TIMEOUT "230" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "500" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
bash_test_modules(
......@@ -239,6 +239,13 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
set_tests_properties(test_collective_scatter_object_list_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_gather_api MODULES test_collective_gather_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_gather_api
PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_collective_sendrecv MODULES test_collective_sendrecv ENVS
......@@ -251,7 +258,7 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
test_collective_sendrecv_api MODULES test_collective_sendrecv_api ENVS
"http_proxy=;https_proxy=;PYTHONPATH=..:${PADDLE_BINARY_DIR}/python")
set_tests_properties(test_collective_sendrecv_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
......
......@@ -21,7 +21,7 @@
```bash
python3 ${PADDLE_ROOT}/tools/gen_ut_cmakelists.py -f ${PADDLE_ROOT}/python/paddle/fluid/tests/unittests/collective/testslist.csv
```
Then the cmd generates a file named CMakeLists.txt in the save directory with the testslist.csv.
Then the cmd generates a file named CMakeLists.txt in the same directory with the testslist.csv.
* usage:
The command accepts --files/-f or --dirpaths/-d options, both of which accept multiple values.
Option -f accepts a list of testslist.csv.
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import test_collective_api_base as test_base
import paddle
import paddle.distributed as dist
from paddle import fluid
class TestCollectiveGatherAPI(test_base.TestCollectiveAPIRunnerBase):
def __init__(self):
self.global_ring_id = 0
def get_model(self, main_prog, startup_program, rank, indata=None):
with fluid.program_guard(main_prog, startup_program):
gather_list = []
# NOTE: this is a hack relying on an undocumented behavior that `to_tensor` uses uint16 to replace bfloat16
if indata.dtype == "bfloat16":
tindata = paddle.to_tensor(indata, "float32").cast("uint16")
dist.gather(tindata, gather_list, dst=0)
return [e.cast("float32").numpy() for e in gather_list]
else:
tindata = paddle.to_tensor(indata)
dist.gather(tindata, gather_list, dst=0)
return [e.numpy() for e in gather_list]
if __name__ == "__main__":
test_base.runtime_main(TestCollectiveGatherAPI, "gather")
......@@ -2,8 +2,8 @@
# Please don't modify this file manually.
# If you need to change unittests in this file, please modify testslist.csv in the current directory
# and then run the command `python3 ${PADDLE_ROOT}/tools/gen_ut_cmakelists.py -f ${CURRENT_DIRECTORY}/testslist.csv`
set(LOCAL_ALL_PLAT ON)
set(LOCAL_ALL_ARCH ON)
set(LOCAL_ALL_PLAT ON)
if((WITH_GPU
OR WITH_XPU
OR WITH_ASCEND
......@@ -650,6 +650,18 @@ if((WITH_GPU
"PADDLE_DIST_UT_PORT=21270;http_proxy=;https_proxy=")
set_tests_properties(test_c_comm_init_op PROPERTIES TIMEOUT "120")
endif()
if((WITH_GPU) AND (LINUX))
bash_test_modules(
test_fused_attention_pass_with_mp
START_BASH
test_fused_attention_pass_with_mp.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21400;http_proxy=;https_proxy=")
set_tests_properties(test_fused_attention_pass_with_mp PROPERTIES TIMEOUT
"120")
endif()
if(LOCAL_ALL_ARCH AND LOCAL_ALL_PLAT)
bash_test_modules(
test_ir_pass_pipeline
......@@ -882,15 +894,3 @@ if((WITH_GPU) AND (LINUX))
set_tests_properties(test_dygraph_save_for_auto_infer
PROPERTIES TIMEOUT "300" LABELS "RUN_TYPE=DIST")
endif()
if(WITH_GPU)
bash_test_modules(
test_fused_attention_pass_with_mp
START_BASH
test_fused_attention_pass_with_mp.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21400;http_proxy=;https_proxy=")
set_tests_properties(test_fused_attention_pass_with_mp PROPERTIES TIMEOUT
"120")
endif()
......@@ -56,7 +56,7 @@ test_fleet_recompute_meta_optimizer,LINUX;WIN32,GPU;XPU;ASCEND;ASCEND_CL,,,test_
test_fleet_private_function,LINUX;WIN32,,,,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_new_group,,GPU;XPU;ASCEND;ASCEND_CL,,DIST,test_new_group.sh,2,,http_proxy=;https_proxy=,
test_c_comm_init_op,LINUX,GPU;XPU;ASCEND;ASCEND_CL,120,DIST,test_c_comm_init_op.sh,2,,http_proxy=;https_proxy=,
test_fused_attention_pass_with_mp,LINUX,GPU;;;,120,DIST,test_fused_attention_pass_with_mp.sh,2,,http_proxy=;https_proxy=,
test_fused_attention_pass_with_mp,LINUX,GPU,120,DIST,test_fused_attention_pass_with_mp.sh,2,,http_proxy=;https_proxy=,
test_ir_pass_pipeline,,,120,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_mnist,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
test_parallel_dygraph_se_resnext,,GPU;ROCM,200,DIST,../../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=../..,
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test_collective_api_base import TestDistBase
import paddle
paddle.enable_static()
class TestCollectiveGatherAPI(TestDistBase):
def _setup_config(self):
pass
def test_gather_nccl_dygraph(self):
dtypes_to_test = [
"float16",
"float32",
"float64",
"int32",
"int64",
"int8",
"uint8",
"bool",
]
if self._nccl_version >= 2100:
dtypes_to_test.append("bfloat16")
for dtype in dtypes_to_test:
self.check_with_place(
"collective_gather_api_dygraph.py",
"gather",
"nccl",
static_mode="0",
dtype=dtype,
)
if __name__ == "__main__":
unittest.main()
......@@ -7,13 +7,13 @@ test_c_split,linux,gpu;rocm,120,DIST,test_runner.py,2,,PYTHONPATH=..;http_proxy=
test_collective_split_embedding,linux,rocm;gpu,300,DIST,../dist_test.sh,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_collective_allgather_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_allgather_object_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_allreduce_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_allreduce_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_single,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_alltoall_single_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_barrier_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_batch_isend_irecv,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_broadcast_object_list_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_cpu_barrier_with_gloo,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_global_gather,linux,gpu;rocm,200,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
......@@ -22,14 +22,15 @@ test_collective_isend_irecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_p
test_collective_optimizer,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_process_group,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_api,linux,gpu;rocm,500,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_scatter,linux,gpu;rocm,350,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_reduce_scatter_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_scatter_object_list_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_gather_api,linux,gpu;rocm,180,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_sendrecv_api,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_embedding_none_divisible,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_row_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
......
......@@ -392,6 +392,14 @@ class TestDistBase(unittest.TestCase):
need_result2 = [need_result[len(need_result) // 2 :]]
self.assertEqual(need_result1, tr0_out)
self.assertEqual(need_result2, tr1_out)
elif col_type == "gather":
# rank 0 gathers all tensors
self.assertEqual(len(tr0_out), 2)
# rank 1 gets nothing
self.assertEqual(len(tr1_out), 0)
# check values
np.testing.assert_equal(input1, tr0_out[0])
np.testing.assert_equal(input2, tr0_out[1])
elif col_type == "reduce_scatter":
need_result = input1 + input2
need_result1 = need_result[0 : need_result.shape[0] // 2]
......