Unverified commit ecae7b31, authored by Wen Sun and committed by GitHub

Support both use_calc_stream and sync_op in allgather API (#46295)

Parent 255890ff
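For reference, below is a minimal usage sketch of the Python API this change adds (paddle.distributed.stream.all_gather) with the new sync_op and use_calc_stream flags. It is illustrative only: it assumes a two-GPU launch via paddle.distributed.launch and follows the constraint enforced further down that use_calc_stream=True is only valid together with sync_op=True.

# Illustrative only: run with e.g. `python -m paddle.distributed.launch --gpus=0,1 demo.py`.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
data = paddle.to_tensor([dist.get_rank()])

# Asynchronous all_gather on the separate communication stream; wait on the returned task.
tensor_list = []
task = dist.stream.all_gather(tensor_list, data, sync_op=False)
task.wait()

# Synchronous all_gather issued directly on the calculation stream
# (valid only together with sync_op=True).
tensor_list_calc = []
dist.stream.all_gather(tensor_list_calc, data, sync_op=True, use_calc_stream=True)
print(tensor_list_calc)  # two tensors, one gathered from each rank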
......@@ -193,7 +193,16 @@ class ProcessGroup {
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllGather", GetBackendName()));
"ProcessGroup%s does not support all_gather", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
bool) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support all_gather with sync_op flag",
GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
......
......@@ -936,6 +936,39 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
CommType::ALLGATHER);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op,
bool use_calc_stream) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All outputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](const phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
return platform::dynload::ncclAllGather(
input.data(),
output.data(),
input.numel(),
platform::ToNCCLDataType(input.dtype()),
comm,
stream);
},
CommType::ALLGATHER,
sync_op,
use_calc_stream);
}
void* GetPointerByOffset(void* raw_pointer,
size_t offset,
experimental::DataType type) {
......@@ -1250,6 +1283,14 @@ ncclComm_t ProcessGroupNCCL::NCCLComm(const Place& place) const {
phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
const Place& place) const {
return GetDeviceContext(place, /*use_calc_stream*/ false);
}
phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
const Place& place, bool use_calc_stream) const {
if (use_calc_stream) {
return platform::DeviceContextPool::Instance().Get(place);
} else {
std::vector<Place> places = {place};
const auto& iter = places_to_ctx_.find(GetKeyFromPlaces(places));
PADDLE_ENFORCE_NE(iter,
......@@ -1257,6 +1298,7 @@ phi::DeviceContext* ProcessGroupNCCL::GetDeviceContext(
platform::errors::InvalidArgument(
"Cannot find device context in process group."));
return iter->second[0].get();
}
}
} // namespace distributed
......
......@@ -98,6 +98,9 @@ class ProcessGroupNCCL : public ProcessGroupStream {
phi::DeviceContext* GetDeviceContext(const Place& place) const override;
phi::DeviceContext* GetDeviceContext(const Place& place,
bool use_calc_stream) const override;
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
......@@ -167,6 +170,12 @@ class ProcessGroupNCCL : public ProcessGroupStream {
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors) override;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
bool sync_op,
bool use_calc_stream) override;
std::shared_ptr<ProcessGroup::Task> AllGather_Partial(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
......
......@@ -23,6 +23,31 @@ ProcessGroupStream::ProcessGroupStream(int rank,
int gid)
: ProcessGroup(rank, size, place, gid) {}
phi::DeviceContext* ProcessGroupStream::GetDeviceContext(
const Place& place, bool use_calc_stream) const {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support get device_context.", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
bool sync_op) {
return AllGather(input_tensors,
output_tensors,
sync_op,
/*use_calc_stream*/ false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllGather(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do all_gather", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
......@@ -42,7 +67,7 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::AllReduce(
bool sync_op,
bool use_calc_stream) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support do allreduce", GetBackendName()));
"ProcessGroup%s does not support do all_reduce", GetBackendName()));
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupStream::Send(
......
......@@ -54,6 +54,20 @@ class ProcessGroupStream : public ProcessGroup {
ProcessGroupStream(int rank, int size, const platform::Place& place, int gid);
virtual ~ProcessGroupStream() = default;
virtual phi::DeviceContext* GetDeviceContext(const Place& place,
bool use_calc_stream) const;
std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
bool sync_op) override;
virtual std::shared_ptr<ProcessGroup::Task> AllGather(
std::vector<phi::DenseTensor>& in_tensors, // NOLINT
std::vector<phi::DenseTensor>& out_tensors, // NOLINT
bool sync_op,
bool use_calc_stream);
std::shared_ptr<ProcessGroup::Task> AllReduce(
std::vector<phi::DenseTensor>& input_tensors, // NOLINT
std::vector<phi::DenseTensor>& output_tensors, // NOLINT
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/device_manager.h"
namespace paddle {
namespace distributed {
template <typename DeviceContext, typename T>
struct SplitDenseTensor {
void operator()(const DeviceContext *context,
const phi::DenseTensor &in,
std::vector<phi::DenseTensor *> *out,
int axis = 0) {
std::vector<const phi::DenseTensor *> shape_refer;
shape_refer.reserve(out->size());
for (auto *p_tensor : *out) {
shape_refer.emplace_back(p_tensor);
}
operators::math::SplitFunctor<DeviceContext, T> split_functor_;
split_functor_(*context, in, shape_refer, axis, out);
}
};
#ifdef PADDLE_WITH_CUSTOM_DEVICE
template <typename T>
struct SplitDenseTensor<platform::CustomDeviceContext, T> {
void operator()(const platform::CustomDeviceContext *context,
const phi::DenseTensor &in,
std::vector<phi::DenseTensor *> *out) {
auto *in_data = in.data<T>();
auto *device = phi::DeviceManager::GetDeviceWithPlace(context->GetPlace());
size_t offset = 0;
for (auto *p_tensor : *out) {
auto *out_data = p_tensor->data<T>();
auto sz = p_tensor->numel() * sizeof(T);
device->MemoryCopyD2D(out_data, in_data + offset, sz, nullptr);
offset += sz;
}
}
};
#endif
template <typename DeviceContext>
void SplitDenseTensorWithType(const DeviceContext *dev_ctx,
const phi::DenseTensor &p_dense,
std::vector<phi::DenseTensor *> *p_list,
phi::DataType type) {
switch (type) {
case phi::DataType::BOOL:
SplitDenseTensor<DeviceContext, bool>()(dev_ctx, p_dense, p_list);
break;
case phi::DataType::UINT8:
SplitDenseTensor<DeviceContext, uint8_t>()(dev_ctx, p_dense, p_list);
break;
case phi::DataType::INT8:
SplitDenseTensor<DeviceContext, int8_t>()(dev_ctx, p_dense, p_list);
break;
case phi::DataType::INT32:
SplitDenseTensor<DeviceContext, int32_t>()(dev_ctx, p_dense, p_list);
break;
case phi::DataType::INT64:
SplitDenseTensor<DeviceContext, int64_t>()(dev_ctx, p_dense, p_list);
break;
case phi::DataType::FLOAT16:
SplitDenseTensor<DeviceContext, platform::float16>()(
dev_ctx, p_dense, p_list);
break;
case phi::DataType::FLOAT32:
SplitDenseTensor<DeviceContext, float>()(dev_ctx, p_dense, p_list);
break;
case phi::DataType::FLOAT64:
SplitDenseTensor<DeviceContext, double>()(dev_ctx, p_dense, p_list);
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Data type (%s) is not supported when it splits tensors for "
"allgather.",
type));
}
}
void SplitTensor(const phi::DeviceContext *dev_ctx,
const phi::DenseTensor &tensor,
const std::vector<experimental::Tensor> *tensor_list) {
std::vector<phi::DenseTensor *> dense_list;
for (auto &tensor : *tensor_list) {
auto p_tensor =
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()).get();
dense_list.emplace_back(p_tensor);
}
const auto &place = dev_ctx->GetPlace();
if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
SplitDenseTensorWithType(static_cast<const phi::GPUContext *>(dev_ctx),
tensor,
&dense_list,
tensor.dtype());
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't split tensor since it's not support NCCL/RCCL, please "
"recompile or reinstall Paddle with NCCL/RCCL support."));
#endif
} else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
SplitDenseTensorWithType(
static_cast<const platform::CustomDeviceContext *>(dev_ctx),
tensor,
&dense_list,
tensor.dtype());
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't split tensor since it's not compiled with CUSTOM_DEVICE, "
"please recompile or reinstall Paddle with CUSTOM_DEVICE support."));
#endif
} else if (platform::is_cpu_place(place)) {
SplitDenseTensorWithType(static_cast<const phi::CPUContext *>(dev_ctx),
tensor,
&dense_list,
tensor.dtype());
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Split tensor not supported on place (%s)", place));
}
}
} // namespace distributed
} // namespace paddle
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/distributed/collective/ProcessGroup.h"
#include "paddle/fluid/distributed/collective/ProcessGroupStream.h"
#include "paddle/fluid/distributed/collective/Types.h"
#include "paddle/fluid/distributed/collective/Utils.h"
#include "paddle/fluid/distributed/collective/reducer.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h"
......@@ -358,6 +359,57 @@ void BindDistributed(py::module *m) {
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"allgather",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor_list,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
const auto *dev_ctx = self.GetDeviceContext(in_tensor.place());
auto task = self.AllGather(in_wrapper, out_wrapper, sync_op);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
return task;
},
py::arg("in"),
py::arg("out"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"allgather_base",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
bool sync_op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
return self.AllGather(in_wrapper, out_wrapper, sync_op);
},
py::arg("in"),
py::arg("out"),
py::arg("sync_op"),
py::call_guard<py::gil_scoped_release>())
.def(
"all_gather_partial",
[](distributed::ProcessGroup &self,
......@@ -494,6 +546,60 @@ void BindDistributed(py::module *m) {
py::class_<distributed::ProcessGroupStream,
std::shared_ptr<distributed::ProcessGroupStream>>(
*m, "ProcessGroupStream", ProcessGroup)
.def(
"allgather_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor_list) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor_list =
CastPyArg2VectorOfTensor(py_out_tensor_list.ptr(), 0);
Tensor concat_out_tensor = paddle::concat(out_tensor_list, 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
concat_out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
const auto *dev_ctx =
self.GetDeviceContext(in_tensor.place(), true);
auto task = self.AllGather(in_wrapper,
out_wrapper,
/*sync_op*/ true,
/*use_calc_stream*/ true);
distributed::SplitTensor(dev_ctx, *out_dense, &out_tensor_list);
return task;
},
py::arg("in"),
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"allgather_base_on_calc_stream",
[](distributed::ProcessGroupStream &self,
py::handle py_in_tensor,
py::handle py_out_tensor) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
std::vector<phi::DenseTensor> in_wrapper = {*in_dense};
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> out_wrapper = {*out_dense};
return self.AllGather(in_wrapper,
out_wrapper,
/*sync_op*/ true,
/*use_calc_stream*/ true);
},
py::arg("in"),
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"allreduce_on_calc_stream",
[](distributed::ProcessGroupStream &self,
......
......@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .all_gather import all_gather
from .all_reduce import all_reduce
from .send import send
from .recv import recv
__all__ = ["all_reduce", "send", "recv"]
__all__ = ["all_gather", "all_reduce", "send", "recv"]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid.framework as framework
from paddle.distributed import collective
def _check_tensor_shape(tensor, shape, nranks=1):
expect_shape = list(shape)
expect_shape[0] *= nranks
if list(tensor.shape) != expect_shape:
raise RuntimeError('The tensor for all_gather is not correctly-sized.')
def _check_tensor_list_shape(tensor_list, shape, nranks=1):
if len(tensor_list) != nranks:
raise RuntimeError(
'The tensor_list for all_gather is not correctly-sized.')
for tensor in tensor_list:
if tensor.shape != shape:
raise RuntimeError(
'The tensor_list for all_gather is not correctly-sized.')
def _all_gather_base_in_dygraph(out_tensor, in_tensor, group, sync_op,
use_calc_stream):
group = collective._get_default_group() if group is None else group
_check_tensor_shape(out_tensor, in_tensor.shape, group.nranks)
if use_calc_stream:
return group.process_group.allgather_base_on_calc_stream(
in_tensor, out_tensor)
task = group.process_group.allgather_base(in_tensor, out_tensor, sync_op)
if sync_op:
task.wait()
return task
def _all_gather_in_dygraph(tensor_list, tensor, group, sync_op,
use_calc_stream):
group = collective._get_default_group() if group is None else group
if len(tensor_list) == 0:
tensor_list += [paddle.empty_like(tensor) for _ in range(group.nranks)]
else:
_check_tensor_list_shape(tensor_list, tensor.shape, group.nranks)
if use_calc_stream:
return group.process_group.allgather_on_calc_stream(tensor, tensor_list)
task = group.process_group.allgather(tensor, tensor_list, sync_op)
if sync_op:
task.wait()
return task
def all_gather(tensor_or_tensor_list,
tensor,
group=None,
sync_op=True,
use_calc_stream=False):
"""
Gather tensors across devices to a correctly-sized tensor or a tensor list.
Args:
tensor_or_tensor_list (Union[Tensor, List[Tensor]]): The output. If it is a tensor, it should be correctly-sized. If it is a list, it
should be empty or contain correctly-sized tensors.
tensor (Tensor): The input tensor on each rank to be gathered into tensor_or_tensor_list. Support
float16, float32, float64, int32 or int64 as the input data type.
group (Group, optional): The group to communicate in. If None, the global default group is used.
sync_op (bool, optional): Indicate whether the communication is synchronous. Defaults to True.
use_calc_stream (bool, optional): Indicate whether the communication is performed on the calculation stream. Defaults to False. This
option is designed for high-performance scenarios; only enable it when you clearly understand its implications.
Returns:
Return a task object.
Warning:
This API only supports the dygraph mode now.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
local_rank = dist.get_rank()
tensor_list = []
if local_rank == 0:
data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
else:
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
task = dist.stream.all_gather(tensor_list, data, sync_op=False)
task.wait()
print(tensor_list)
# [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
"""
if group is not None and not group.is_member():
raise RuntimeError(
"The group should not be None and all ranks which invoke this operation should be the member of this group."
)
if not sync_op and use_calc_stream:
raise RuntimeError(
"use_calc_stream can only be true in sync op behavior.")
if framework.in_dygraph_mode():
if paddle.is_tensor(tensor_or_tensor_list):
return _all_gather_base_in_dygraph(tensor_or_tensor_list, tensor,
group, sync_op, use_calc_stream)
else:
return _all_gather_in_dygraph(tensor_or_tensor_list, tensor, group,
sync_op, use_calc_stream)
raise RuntimeError(
"paddle.distributed.stream.all_gather is only supported in dygraph mode now."
)
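Complementing the tensor-list example in the docstring above, passing a single pre-sized output tensor takes the allgather_base path (_all_gather_base_in_dygraph). A minimal sketch under the same two-GPU assumption, where the output's first dimension must be nranks times that of the input (as checked by _check_tensor_shape):

# Illustrative sketch of the single-tensor output form of dist.stream.all_gather.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
data = paddle.ones([2, 3], dtype="float32") * (dist.get_rank() + 1)

# The output must already be correctly sized: first dim multiplied by the world size.
out = paddle.empty([data.shape[0] * dist.get_world_size()] + data.shape[1:], dtype=data.dtype)
task = dist.stream.all_gather(out, data, sync_op=False)
task.wait()
print(out.shape)  # [4, 3] on 2 ranks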
......@@ -266,6 +266,14 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
set_tests_properties(test_collective_wait PROPERTIES TIMEOUT "300" LABELS
"RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_communication_stream_allgather_api MODULES
test_communication_stream_allgather_api ENVS
"PYTHONPATH=..:${PADDLE_BINARY_DIR}/python;http_proxy=;https_proxy=")
set_tests_properties(test_communication_stream_allgather_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_GPU OR WITH_ROCM) AND (LINUX))
py_test_modules(
test_communication_stream_allreduce_api MODULES
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.distributed as dist
import test_communication_api_base as test_base
import test_collective_api_base as test_collective_base
class StreamAllgatherTestCase():
def __init__(self):
self._sync_op = eval(os.getenv("sync_op"))
self._use_calc_stream = eval(os.getenv("use_calc_stream"))
self._backend = os.getenv("backend")
self._shape = eval(os.getenv("shape"))
self._dtype = os.getenv("dtype")
self._seeds = eval(os.getenv("seeds"))
if self._backend not in ["nccl", "gloo"]:
raise NotImplementedError(
"Only support nccl and gloo as the backend for now.")
os.environ["PADDLE_DISTRI_BACKEND"] = self._backend
def run_test_case(self):
dist.init_parallel_env()
test_data_list = []
for seed in self._seeds:
test_data_list.append(
test_collective_base.create_test_data(shape=self._shape,
dtype=self._dtype,
seed=seed))
rank = dist.get_rank()
tensor = paddle.to_tensor(test_data_list[rank])
# case 1: pass an empty tensor list
empty_tensor_list = []
task = dist.stream.all_gather(empty_tensor_list,
tensor,
sync_op=self._sync_op,
use_calc_stream=self._use_calc_stream)
if not self._sync_op:
task.wait()
assert np.allclose(empty_tensor_list,
test_data_list,
rtol=1e-05,
atol=1e-05)
# case 2: pass a pre-sized tensor list
full_tensor_list = [paddle.empty_like(tensor) for _ in test_data_list]
task = dist.stream.all_gather(full_tensor_list,
tensor,
sync_op=self._sync_op,
use_calc_stream=self._use_calc_stream)
if not self._sync_op:
task.wait()
assert np.allclose(full_tensor_list,
test_data_list,
rtol=1e-05,
atol=1e-05)
# case 3: pass a pre-sized tensor
result_tensor = paddle.concat(
[paddle.to_tensor(data) for data in test_data_list])
out_tensor = paddle.empty_like(result_tensor)
task = dist.stream.all_gather(out_tensor,
tensor,
sync_op=self._sync_op,
use_calc_stream=self._use_calc_stream)
if not self._sync_op:
task.wait()
assert np.allclose(out_tensor, result_tensor, rtol=1e-05, atol=1e-05)
if __name__ == "__main__":
StreamAllgatherTestCase().run_test_case()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import itertools
import test_communication_api_base as test_base
class TestCommunicationStreamAllgatherAPI(test_base.CommunicationTestDistBase):
def setUp(self):
super(TestCommunicationStreamAllgatherAPI, self).setUp(num_of_devices=2,
timeout=120)
self._default_envs = {
"backend": "nccl",
"shape": "(100, 200)",
"dtype": "float32",
"seeds": str(self._seeds)
}
self._changeable_envs = {
"sync_op": ["True", "False"],
"use_calc_stream": ["True", "False"]
}
def test_allgather_stream(self):
envs_list = test_base.gen_product_envs_list(self._default_envs,
self._changeable_envs)
for envs in envs_list:
if eval(envs["use_calc_stream"]) and not eval(envs["sync_op"]):
continue
self.run_test_case("communication_stream_allgather_api_dygraph.py",
user_defined_envs=envs)
def tearDown(self):
super(TestCommunicationStreamAllgatherAPI, self).tearDown()
if __name__ == '__main__':
unittest.main()
......@@ -32,6 +32,7 @@ test_collective_split_col_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_
test_collective_split_embedding_none_divisible,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_split_row_linear,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_collective_wait,linux,gpu;rocm,300,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_communication_stream_allgather_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_communication_stream_allreduce_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_communication_stream_sendrecv_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_eager_dist_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
......