Unverified commit 37216a8f authored by Haohongxiang, committed by GitHub

[Dygraph] Support new apis in ProcessGroupNCCL (#43918)

* fix conflict

* new pg apis

* add docs of new apis

* update

* fix coverage

* update

* fix bug

* fix reduce scatter

* fix api

* update
Co-authored-by: ForFishes <2282912238@qq.com>
Parent 02e4f1f8
......@@ -46,6 +46,7 @@ enum class CommType : std::uint8_t {
SEND = 9,
RECV = 10,
BARRIER = 11,
ALLTOALL_SINGLE = 12,
UNKNOWN = 100,
};
......@@ -143,6 +144,15 @@ class ProcessGroup {
"ProcessGroup%s does not support AllToAll", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> AllToAll_Single(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<int64_t>&,
std::vector<int64_t>&) {
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support AllToAll_Single", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>&, // NOLINT
std::vector<phi::DenseTensor>&, // NOLINT
......@@ -159,6 +169,14 @@ class ProcessGroup {
"ProcessGroup%s does not support Scatter", GetBackendName()));
}
virtual std::shared_ptr<ProcessGroup::Task> _ReduceScatterBase(
phi::DenseTensor&, // NOLINT
phi::DenseTensor&, // NOLINT
const ReduceScatterOptions&) { // NOLINT
PADDLE_THROW(platform::errors::InvalidArgument(
"ProcessGroup%s does not support ReduceScatter", GetBackendName()));
}
protected:
const int rank_;
const int size_;
......
......@@ -85,6 +85,34 @@ bool ProcessGroupNCCL::NCCLTask::IsCompleted() {
return true;
}
void ProcessGroupNCCL::CheckSplitSizes(std::vector<int64_t>& split_sizes,
std::vector<int64_t> tensor_shape) {
int64_t len_size = split_sizes.size();
if (len_size == 0) {
PADDLE_ENFORCE_EQ(tensor_shape[0] % size_ == 0,
true,
platform::errors::InvalidArgument(
"Tensor's dim[0] must be divisible by group size "
"when split_sizes not given."));
split_sizes.insert(split_sizes.end(),
size_,
static_cast<int64_t>(tensor_shape[0] / size_));
} else {
PADDLE_ENFORCE_EQ(
len_size == size_,
true,
platform::errors::InvalidArgument(
"The length of split_sizes must be equal to group size."));
auto sum_size = std::accumulate(
split_sizes.begin(), split_sizes.end(), static_cast<int64_t>(0));
PADDLE_ENFORCE_EQ(
sum_size == tensor_shape[0],
true,
platform::errors::InvalidArgument(
"The sum of split_sizes must be equal to tensor's dim[0]."));
}
}
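For reference, a minimal Python sketch (not part of the patch) of the rule that CheckSplitSizes enforces for AllToAll_Single / alltoall_single; the helper name check_split_sizes below is hypothetical. When no split sizes are given, dim[0] is split evenly across the group; otherwise the list must have one entry per rank and sum to dim[0].

def check_split_sizes(split_sizes, dim0, group_size):
    # Hypothetical illustration of the validation done by CheckSplitSizes.
    if not split_sizes:
        # No sizes given: dim[0] must be evenly divisible by the group size.
        assert dim0 % group_size == 0, \
            "Tensor's dim[0] must be divisible by group size when split_sizes not given."
        return [dim0 // group_size] * group_size
    # Sizes given: one entry per rank, covering dim[0] exactly.
    assert len(split_sizes) == group_size, \
        "The length of split_sizes must be equal to group size."
    assert sum(split_sizes) == dim0, \
        "The sum of split_sizes must be equal to tensor's dim[0]."
    return list(split_sizes)

print(check_split_sizes([], 6, 2))      # [3, 3]
print(check_split_sizes([2, 4], 6, 2))  # [2, 4]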
// TODO(sheniang03): Add timeout for wait, now timeout unused
bool ProcessGroupNCCL::NCCLTask::Wait(std::chrono::milliseconds timeout) {
SynchronizeStreams();
......@@ -637,7 +665,69 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll(
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
},
CommType::ALLREDUCE);
CommType::ALLTOALL);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::AllToAll_Single(
std::vector<phi::DenseTensor>& in_tensors,
std::vector<phi::DenseTensor>& out_tensors,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes) {
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(in_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
PADDLE_ENFORCE_EQ(
CheckTensorsInCudaPlace(out_tensors),
true,
platform::errors::InvalidArgument("All inputs should be in CudaPlace."));
return Collective(
in_tensors,
out_tensors,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
PADDLE_ENFORCE_EQ(input.dtype() == output.dtype(),
true,
platform::errors::InvalidArgument(
"The dtypes of input and output must be equal."));
std::vector<int64_t> in_dims = phi::vectorize(input.dims());
std::vector<int64_t> out_dims = phi::vectorize(output.dims());
CheckSplitSizes(in_sizes, in_dims);
CheckSplitSizes(out_sizes, out_dims);
size_t in_offset = 0, out_offset = 0;
size_t in_length = 0, out_length = 0;
size_t in_row_size = input.numel() / in_dims[0];
size_t out_row_size = output.numel() / out_dims[0];
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
for (auto i = 0; i < size_; i++) {
in_length = in_sizes[i] * in_row_size;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclSend(
GetPointerByOffset(input.data(), in_offset, input.dtype()),
in_length,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
in_offset += in_length;
out_length = out_sizes[i] * out_row_size;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclRecv(
GetPointerByOffset(output.data(), out_offset, input.dtype()),
out_length,
platform::ToNCCLDataType(input.dtype()),
i,
comm,
stream));
out_offset += out_length;
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
},
CommType::ALLTOALL_SINGLE);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Reduce(
......@@ -721,5 +811,57 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::Scatter(
CommType::SCATTER);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupNCCL::_ReduceScatterBase(
phi::DenseTensor& out_tensor,
phi::DenseTensor& in_tensor,
const ReduceScatterOptions& opts) {
PADDLE_ENFORCE_EQ(
out_tensor.dtype(),
in_tensor.dtype(),
platform::errors::InvalidArgument(
"Input tensor and output tensor should be of the same dtype."));
PADDLE_ENFORCE_EQ(
out_tensor.numel() * size_,
in_tensor.numel(),
platform::errors::InvalidArgument("input tensor must be the same size as "
"output tensor size times world_size"));
auto inputs = std::vector<phi::DenseTensor>{in_tensor};
auto outputs = std::vector<phi::DenseTensor>{out_tensor};
return Collective(
inputs,
outputs,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
ncclComm_t comm,
const gpuStream_t& stream) {
if (FLAGS_use_stream_safe_cuda_allocator) {
platform::CUDADeviceGuard cuda_guard;
cuda_guard.SetDevice(output.place());
memory::RecordStream(output.Holder(), stream);
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclReduceScatter(
input.data(),
output.data(),
output.numel(),
platform::ToNCCLDataType(input.dtype()),
ToNCCLRedType(opts.reduce_op),
comm,
stream));
},
CommType::REDUCE_SCATTER);
}
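As a quick reference, a minimal Python sketch (not part of the patch) of the shape contract that _ReduceScatterBase checks above, and of how the high-level reduce_scatter added later in this change builds the flattened input by concatenating tensor_list along dim 0. The world_size value is assumed for illustration.

import paddle

world_size = 2  # assumed group size, for illustration only
tensor_list = [paddle.to_tensor([0, 1]), paddle.to_tensor([2, 3])]  # one chunk per rank
flat_input = paddle.concat(tensor_list, axis=0)           # what reduce_scatter passes as input
output = paddle.empty(shape=[2], dtype=flat_input.dtype)  # receives this rank's reduced chunk
# _ReduceScatterBase requires: output.numel() * world_size == input.numel()
assert output.shape[0] * world_size == flat_input.shape[0]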
void ProcessGroupNCCL::GroupStart() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupStart());
}
void ProcessGroupNCCL::GroupEnd() {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclGroupEnd());
}
} // namespace distributed
} // namespace paddle
......@@ -129,6 +129,12 @@ class ProcessGroupNCCL : public ProcessGroup {
std::vector<phi::DenseTensor>& in,
std::vector<phi::DenseTensor>& out) override;
std::shared_ptr<ProcessGroup::Task> AllToAll_Single(
std::vector<phi::DenseTensor>& in,
std::vector<phi::DenseTensor>& out,
std::vector<int64_t>& in_sizes,
std::vector<int64_t>& out_sizes) override;
std::shared_ptr<ProcessGroup::Task> Reduce(
std::vector<phi::DenseTensor>& tensors,
std::vector<phi::DenseTensor>& out_tensors,
......@@ -139,6 +145,15 @@ class ProcessGroupNCCL : public ProcessGroup {
std::vector<phi::DenseTensor>& out_tensors,
const ScatterOptions&) override;
std::shared_ptr<ProcessGroup::Task> _ReduceScatterBase(
phi::DenseTensor&, // NOLINT
phi::DenseTensor&, // NOLINT
const ReduceScatterOptions&) override;
static void GroupStart();
static void GroupEnd();
protected:
virtual std::shared_ptr<ProcessGroupNCCL::NCCLTask> CreateTask(
std::vector<Place> places,
......@@ -162,7 +177,7 @@ class ProcessGroupNCCL : public ProcessGroup {
std::set<int> used_place_ids_;
private:
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids,
void BcastNCCLId(std::vector<ncclUniqueId>& nccl_ids, // NOLINT
int root, // NOLINT
int server_fd);
......@@ -190,6 +205,9 @@ class ProcessGroupNCCL : public ProcessGroup {
void CreateNCCLManagerCache(const std::string& places_key,
const std::vector<Place>& places);
void CheckSplitSizes(std::vector<int64_t>& split_sizes,
std::vector<int64_t> tensor_shape);
};
} // namespace distributed
......
......@@ -45,5 +45,9 @@ struct ScatterOptions {
int root_rank = 0;
};
struct ReduceScatterOptions {
ReduceOp reduce_op = ReduceOp::SUM;
};
} // namespace distributed
} // namespace paddle
......@@ -225,6 +225,30 @@ void BindDistributed(py::module *m) {
py::arg("out"),
py::call_guard<py::gil_scoped_release>())
.def(
"alltoall_single",
[](distributed::ProcessGroup &self,
py::handle py_in_tensor,
py::handle py_out_tensor,
std::vector<int64_t> in_sizes,
std::vector<int64_t> out_sizes) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
auto in_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
auto out_dense = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
std::vector<phi::DenseTensor> in_tensors = {*in_dense};
std::vector<phi::DenseTensor> out_tensors = {*out_dense};
return self.AllToAll_Single(
in_tensors, out_tensors, in_sizes, out_sizes);
},
py::arg("in"),
py::arg("out"),
py::arg("in_sizes"),
py::arg("out_sizes"),
py::call_guard<py::gil_scoped_release>())
.def(
"reduce",
[](distributed::ProcessGroup &self,
......@@ -244,7 +268,6 @@ void BindDistributed(py::module *m) {
py::arg("dst"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>())
.def(
"scatter",
[](distributed::ProcessGroup &self,
......@@ -266,9 +289,30 @@ void BindDistributed(py::module *m) {
py::arg("in"),
py::arg("out"),
py::arg("src"),
py::call_guard<py::gil_scoped_release>())
.def(
"_reduce_scatter_base",
[](distributed::ProcessGroup &self,
py::handle py_out_tensor,
py::handle py_in_tensor,
distributed::ReduceOp op) {
auto in_tensor = CastPyArg2Tensor(py_in_tensor.ptr(), 0);
auto out_tensor = CastPyArg2Tensor(py_out_tensor.ptr(), 0);
distributed::ReduceScatterOptions opts;
opts.reduce_op = op;
auto dense_out = std::dynamic_pointer_cast<phi::DenseTensor>(
out_tensor.impl());
auto dense_in = std::dynamic_pointer_cast<phi::DenseTensor>(
in_tensor.impl());
return self._ReduceScatterBase(*dense_out, *dense_in, opts);
},
py::arg("out_tensor"),
py::arg("in_tensor"),
py::arg("op") = distributed::ReduceOp::SUM,
py::call_guard<py::gil_scoped_release>());
#if defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)
auto processGroupNCCL =
py::class_<distributed::ProcessGroupNCCL,
std::shared_ptr<distributed::ProcessGroupNCCL>>(
*m, "ProcessGroupNCCL", ProcessGroup)
......@@ -283,6 +327,12 @@ void BindDistributed(py::module *m) {
py::arg("place"),
py::arg("group_id") = 0,
py::call_guard<py::gil_scoped_release>());
processGroupNCCL.def_static(
"group_start", []() { distributed::ProcessGroupNCCL::GroupStart(); });
processGroupNCCL.def_static(
"group_end", []() { distributed::ProcessGroupNCCL::GroupEnd(); });
#endif
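The two static bindings above are what the Python-side batch P2P guard added later in this change relies on. A minimal usage sketch, assuming an eager-mode NCCL process group has been initialized via init_parallel_env:

# required: distributed (NCCL backend)
import paddle.distributed as dist
import paddle.fluid.core as core

dist.init_parallel_env()
# Fuse several point-to-point calls into one NCCL group call, which is what
# _with_batch_p2p_guard / batch_isend_irecv do under the hood.
core.ProcessGroupNCCL.group_start()
# ... issue group.process_group.send(...) / recv(...) calls here ...
core.ProcessGroupNCCL.group_end()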
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_PSCORE) && \
......
......@@ -41,6 +41,14 @@ from .collective import recv # noqa: F401
from .collective import get_group # noqa: F401
from .collective import send # noqa: F401
from .collective import wait # noqa: F401
from .collective import is_initialized # noqa: F401
from .collective import destroy_process_group # noqa: F401
from .collective import alltoall_single # noqa: F401
from .collective import isend # noqa: F401
from .collective import irecv # noqa: F401
from .collective import batch_isend_irecv # noqa: F401
from .collective import P2POp # noqa: F401
from .collective import reduce_scatter # noqa: F401
from .auto_parallel import shard_op # noqa: F401
from .auto_parallel import shard_tensor # noqa: F401
......@@ -59,33 +67,11 @@ from . import utils # noqa: F401
from .sharding import * # noqa: F401
__all__ = [ # noqa
"spawn",
"launch",
"scatter",
"broadcast",
"ParallelEnv",
"new_group",
"init_parallel_env",
"gloo_init_parallel_env",
"gloo_barrier",
"gloo_release",
"QueueDataset",
"split",
"CountFilterEntry",
"ShowClickEntry",
"get_world_size",
"get_group",
"all_gather",
"InMemoryDataset",
"barrier",
"all_reduce",
"alltoall",
"send",
"reduce",
"recv",
"ReduceOp",
"wait",
"get_rank",
"ProbabilityEntry",
"ParallelMode",
"spawn", "launch", "scatter", "broadcast", "ParallelEnv", "new_group",
"init_parallel_env", "gloo_init_parallel_env", "gloo_barrier",
"gloo_release", "QueueDataset", "split", "CountFilterEntry",
"ShowClickEntry", "get_world_size", "get_group", "all_gather",
"InMemoryDataset", "barrier", "all_reduce", "alltoall", "send", "reduce",
"recv", "ReduceOp", "wait", "get_rank", "ProbabilityEntry", "ParallelMode",
"is_initialized", "isend", "irecv", "reduce_scatter"
]
......@@ -36,6 +36,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle import _C_ops
import paddle.fluid.dygraph_utils as dygraph_utils
import contextlib
__all__ = []
......@@ -136,6 +137,10 @@ _group_map = {}
# Dict[name, Group]
_group_map_by_name = {}
# backend map by group : the map of all backend from their groups
# Dict[group, backend]
_group_map_backend = {}
# Name of the default group for init_parallel_env
_default_group_name = "_default_pg"
......@@ -175,8 +180,7 @@ def _get_group_map_by_name():
def _get_default_group():
global _group_map_by_name
assert _default_group_name in _group_map_by_name, (
"Call paddle.distributed.init_parallel_env first "
assert is_initialized(), ("Call paddle.distributed.init_parallel_env first "
"to initialize the distributed environment.")
return _get_group_map_by_name()[_default_group_name]
......@@ -193,10 +197,29 @@ def _set_group_map_by_name(name, group):
_group_map_by_name[name] = group
def _set_group_map_backend(group, backend):
global _group_map_backend
assert group not in _group_map_backend
_group_map_backend[group] = backend
def _new_ring_id():
return len(_get_group_map()) + max(_get_global_env().nrings, 9)
def _get_reduce_op(reduce_op, func_name):
if reduce_op == ReduceOp.SUM:
return core.ReduceOp.SUM
elif reduce_op == ReduceOp.MAX:
return core.ReduceOp.MAX
elif reduce_op == ReduceOp.MIN:
return core.ReduceOp.MIN
elif reduce_op == ReduceOp.PROD:
return core.ReduceOp.PRODUCT
else:
raise ValueError("Unknown reduce_op type for {}.".format(func_name))
def get_group(id=0):
"""
......@@ -400,6 +423,7 @@ def new_group(ranks=None, backend=None):
group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
_group_map_by_name[group_name] = group
_group_map[gid] = group
_group_map_backend[group] = backend
# TODO(shenliang03): This is a temporary solution to solve the problem of
# hang caused by tcp
......@@ -462,6 +486,75 @@ def new_group(ranks=None, backend=None):
return gp
def is_initialized():
"""
Check whether the distributed environment has been initialized.
Returns (bool): `True` if the distributed environment has been initialized, otherwise `False`.
Examples:
.. code-block:: python
# required: distributed
import paddle
print(paddle.distributed.is_initialized())
# False
paddle.distributed.init_parallel_env()
print(paddle.distributed.is_initialized())
# True
"""
global _group_map_by_name
return _default_group_name in _group_map_by_name
def destroy_process_group(group=None):
"""
Destroy a given group for communication.
Args:
group (ProcessGroup, optional): The group to be destroyed. If None (the default), all process groups, including
the default group, will be destroyed and the distributed
environment will be deinitialized.
Returns: None
Examples:
.. code-block:: python
# required: distributed
import paddle
paddle.distributed.init_parallel_env()
group = paddle.distributed.new_group([0, 1])
paddle.distributed.destroy_process_group(group)
print(paddle.distributed.is_initialized())
# True
paddle.distributed.destroy_process_group()
print(paddle.distributed.is_initialized())
# False
"""
global _group_map
global _group_map_by_name
pg = _get_default_group() if group is None else group
assert _group_map.get(pg.id, None) is not None, "Invalid group."
if group is None:
_group_map.clear()
_group_map_by_name.clear()
_group_map_backend.clear()
else:
del _group_map[pg.id]
del _group_map_by_name[pg.name]
del _group_map_backend[pg]
def wait(tensor, group=None, use_calc_stream=True):
"""
......@@ -663,16 +756,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, use_calc_stream=True):
return
if in_dygraph_mode():
if op == ReduceOp.SUM:
op_type = core.ReduceOp.SUM
elif op == ReduceOp.MAX:
op_type = core.ReduceOp.MAX
elif op == ReduceOp.MIN:
op_type = core.ReduceOp.MIN
elif op == ReduceOp.PROD:
op_type = core.ReduceOp.PRODUCT
else:
raise ValueError("Unknown reduce_op type for allreduce.")
op_type = _get_reduce_op(op, "all_reduce")
group = _get_default_group() if group is None else group
task = group.process_group.allreduce(tensor, op_type)
if use_calc_stream:
......@@ -768,16 +852,7 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, use_calc_stream=True):
return
if in_dygraph_mode():
if op == ReduceOp.SUM:
op_type = core.ReduceOp.SUM
elif op == ReduceOp.MAX:
op_type = core.ReduceOp.MAX
elif op == ReduceOp.MIN:
op_type = core.ReduceOp.MIN
elif op == ReduceOp.PROD:
op_type = core.ReduceOp.PRODUCT
else:
raise ValueError("Unknown reduce_op type for reduce.")
op_type = _get_reduce_op(op, "reduce")
group = _get_default_group() if group is None else group
gdst = group.get_group_rank(dst)
assert gdst >= 0, ("dst rank out of group, need global rank")
......@@ -1781,10 +1856,10 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
Args:
in_tensor_list (list): A list of input Tensors. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
out_tensor_list (Tensor): A list of output Tensors. The data type of its elements should be the same as the
out_tensor_list (list): A list of output Tensors. The data type of its elements should be the same as the
data type of the input Tensors.
group (Group, optional): The group instance return by new_group or None for global default group. Default: None.
use_calc_stream (bool, optional): Wether to use calculation stream (True) or communication stream. Default: True.
use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.
Returns:
None.
......@@ -1867,6 +1942,94 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, use_calc_stream=True):
out_tensor_list.extend(paddle.split(out, nranks, 0))
def alltoall_single(in_tensor,
out_tensor,
in_split_sizes=None,
out_split_sizes=None,
group=None,
use_calc_stream=True):
"""
Scatter a single input tensor to all participators and gather the received tensors in out_tensor.
.. note::
``alltoall_single`` is only supported in eager mode.
Args:
in_tensor (Tensor): Input tensor. The data type should be float16, float32, float64, int32 or int64.
out_tensor (Tensor): Output Tensor. The data type should be the same as the data type of the input Tensor.
in_split_sizes (list[int], optional): Split sizes of ``in_tensor`` for dim[0]. If not given, dim[0] of ``in_tensor``
must be divisible by group size and ``in_tensor`` will be scattered evenly to all participators. Default: None.
out_split_sizes (list[int], optional): Split sizes of ``out_tensor`` for dim[0]. If not given, dim[0] of ``out_tensor``
must be divisible by group size and ``out_tensor`` will be gathered evenly from all participators. Default: None.
group (Group, optional): The group instance returned by ``new_group`` or None for the global default group. Default: None.
use_calc_stream (bool, optional): Whether to use calculation stream (True) or communication stream. Default: True.
Returns:
None, if ``use_calc_stream`` is set to ``True``; ``Task`` of ``group``, if ``use_calc_stream`` is set to ``False``.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
rank = dist.get_rank()
size = dist.get_world_size()
# case 1
input = paddle.arange(2, dtype='int64') + rank * 2
# input for rank 0: [0, 1]
# input for rank 1: [2, 3]
output = paddle.empty([2], dtype='int64')
dist.alltoall_single(input, output)
# output for rank 0: [0, 2]
# output for rank 1: [1, 3]
# case 2
in_split_sizes = [i + 1 for i in range(size)]
# in_split_sizes for rank 0: [1, 2] and for rank 1: [1, 2]
out_split_sizes = [rank + 1 for i in range(size)]
# out_split_sizes for rank 0: [1, 1] and for rank 1: [2, 2]
input = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
# input for rank 0: [[0., 0.], [0., 0.], [0., 0.]]
# input for rank 1: [[1., 1.], [1., 1.], [1., 1.]]
output = paddle.empty([(rank + 1) * size, size], dtype='float32')
group = dist.new_group([0, 1])
task = dist.alltoall_single(input,
output,
in_split_sizes,
out_split_sizes,
use_calc_stream=False,
group=group)
task.wait()
# output for rank 0: [[0., 0.], [1., 1.]]
# output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]
"""
if group is not None and not group.is_member():
return
assert in_dygraph_mode(), "Only suppport alltoall_single in eager mode."
# _check_single_tensor
group = _get_default_group() if group is None else group
in_split_sizes = [] if in_split_sizes is None else in_split_sizes
out_split_sizes = [] if out_split_sizes is None else out_split_sizes
task = group.process_group.alltoall_single(in_tensor, out_tensor,
in_split_sizes, out_split_sizes)
if use_calc_stream:
task.wait()
return
else:
return task
def send(tensor, dst=0, group=None, use_calc_stream=True):
"""
Send a tensor to the receiver.
......@@ -1902,7 +2065,8 @@ def send(tensor, dst=0, group=None, use_calc_stream=True):
if in_dygraph_mode():
group = _get_default_group() if group is None else group
task = group.process_group.send(tensor, dst)
group_dst_rank = group.get_group_rank(dst)
task = group.process_group.send(tensor, group_dst_rank)
if use_calc_stream:
task.wait()
return None
......@@ -1964,7 +2128,8 @@ def recv(tensor, src=0, group=None, use_calc_stream=True):
if in_dygraph_mode():
group = _get_default_group() if group is None else group
task = group.process_group.recv(tensor, src)
group_src_rank = group.get_group_rank(src)
task = group.process_group.recv(tensor, group_src_rank)
if use_calc_stream:
task.wait()
return None
......@@ -1991,3 +2156,390 @@ def recv(tensor, src=0, group=None, use_calc_stream=True):
'dtype': tensor.dtype,
'use_calc_stream': use_calc_stream,
})
def _check_single_tensor(tensor, tensor_name):
if not isinstance(tensor, (core.eager.Tensor, paddle.Tensor)):
raise RuntimeError("Invalid function argument. Expected parameter {}"
"to be of type paddle.Tensor, but it's {}".format(
tensor_name, type(tensor)))
def _check_tensor_list(tensor_list, tensor_name):
if not isinstance(tensor_list, list) or \
not all(isinstance(t, (core.eager.Tensor, paddle.Tensor)) for t in tensor_list):
raise RuntimeError("Invalid function argument. Expected parameter {}"
"to be of type paddle.Tensor".format(tensor_name))
def isend(tensor, dst, group=None):
"""
Sends a tensor asynchronously
Args:
tensor (Tensor): The Tensor to send. Its data type
should be float16, float32, float64, int32 or int64.
dst (int): The destination rank.
group (Group, optional): The group instance returned by new_group or None for the global default group. Default: None.
Returns:
A distributed task object.
Warning:
This API only supports the dygraph mode.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
data = paddle.to_tensor([7, 8, 9])
task = paddle.distributed.isend(data, dst=1)
else:
data = paddle.to_tensor([1, 2, 3])
task = paddle.distributed.irecv(data, src=0)
task.wait()
print(data)
# paddle.tensor([7, 8, 9]) # Rank-0
# paddle.tensor([7, 8, 9]) # Rank-1
"""
_check_single_tensor(tensor, "tensor")
if group is not None and not group.is_member():
return
if in_dygraph_mode():
group = _get_default_group() if group is None else group
group_dst_rank = group.get_group_rank(dst)
assert group_dst_rank >= 0, ("dst rank out of group, need global rank")
return group.process_group.send(tensor, group_dst_rank)
else:
raise RuntimeError("Don't support static graph mode currently.")
def irecv(tensor, src=None, group=None):
"""
Receive a tensor asynchronously from the sender.
Args:
tensor (Tensor): The Tensor to receive. Its data type
should be float16, float32, float64, int32 or int64.
src (int): The source rank id.
group (Group, optional): The group instance returned by new_group or None for the global default group. Default: None.
Returns:
A distributed task object.
Warning:
This API only supports the dygraph mode.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
data = paddle.to_tensor([7, 8, 9])
task = paddle.distributed.isend(data, dst=1)
else:
data = paddle.to_tensor([1, 2, 3])
task = paddle.distributed.irecv(data, src=0)
task.wait()
print(data)
# paddle.tensor([7, 8, 9]) # Rank-0
# paddle.tensor([7, 8, 9]) # Rank-1
"""
_check_single_tensor(tensor, "tensor")
if group is not None and not group.is_member():
return
if in_dygraph_mode():
group = _get_default_group() if group is None else group
group_src_rank = group.get_group_rank(src)
assert group_src_rank >= 0, ("src rank out of group, need global rank")
return group.process_group.recv(tensor, group_src_rank)
else:
raise RuntimeError("Don't support static graph mode currently.")
class P2POp(object):
"""
A class that represents a point-to-point operation for ``batch_isend_irecv``.
This class holds the type of P2P operation, the communication buffer, the peer rank and the
group. Instances of this class are passed to
``paddle.distributed.batch_isend_irecv`` for point-to-point communication.
Args:
op (callable): A function to send data to or receive data from a peer process.
The type of ``op`` is either ``paddle.distributed.isend`` or ``paddle.distributed.irecv``.
tensor (Tensor): Tensor to send or receive.
peer (int): The destination or source rank.
group (Group, optional): The group instance returned by new_group or None for the global
default group. Default: None.
"""
def __init__(self, op, tensor, peer, group=None):
if op not in [isend, irecv]:
raise RuntimeError("Invalid ``op`` function. Expected ``op`` "
"to be of type ``paddle.distributed.isend`` or "
"``paddle.distributed.irecv``.")
_check_single_tensor(tensor, "tensor")
self.op = op
self.tensor = tensor
self.peer = peer
self.group = _get_default_group() if group is None else group
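A short usage sketch for ``P2POp`` (assuming a two-rank environment launched with ``init_parallel_env``): constructing the descriptors performs no communication; they only take effect once passed to ``batch_isend_irecv``, whose docstring below shows the same pattern in full.

# required: distributed
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
rank = dist.get_rank()
world_size = dist.get_world_size()

send_t = paddle.to_tensor([rank, rank + 1])
recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)
# Each P2POp just records (op, tensor, peer, group).
send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)
for task in dist.batch_isend_irecv([send_op, recv_op]):
    task.wait()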
@contextlib.contextmanager
def _with_batch_p2p_guard(backend):
if backend == "nccl":
core.ProcessGroupNCCL.group_start()
try:
yield
finally:
if backend == "nccl":
core.ProcessGroupNCCL.group_end()
def _check_p2p_op_list(p2p_op_list):
"""
Helper to check that the ``p2p_op_list`` is a list of P2POp instances and
all ops use the same backend.
"""
if not isinstance(p2p_op_list, list) or not all(
isinstance(p2p_op, P2POp) for p2p_op in p2p_op_list):
raise RuntimeError("Invalid ``p2p_op_list``. Each op is expected to "
"to be of type ``paddle.distributed.P2POp``.")
backend = _group_map_backend[p2p_op_list[0].group]
if not all(backend == _group_map_backend[p2p_op.group]
for p2p_op in p2p_op_list):
raise RuntimeError("All groups need to use the same backend.")
def batch_isend_irecv(p2p_op_list):
"""
Send or Receive a batch of tensors asynchronously and return a list of requests.
Process each of the point-to-point operations in ``p2p_op_list`` and return the
corresponding tasks. Only the NCCL backend is currently supported.
Args:
p2p_op_list: A list of point-to-point operations(type of each operator is
``paddle.distributed.P2POp``). The order of the isend/irecv in the list
matters and it needs to match with corresponding isend/irecv on the
remote end.
Returns:
A list of distributed tasks returned by calling the corresponding
op in the op_list.
Warning:
This API only supports the dygraph mode.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
rank = dist.get_rank()
world_size = dist.get_world_size()
send_t = paddle.arange(2) + rank
# paddle.tensor([0, 1]) # Rank-0
# paddle.tensor([1, 2]) # Rank-1
recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)
send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
recv_op = dist.P2POp(dist.irecv, recv_t, (rank - 1 + world_size) % world_size)
tasks = dist.batch_isend_irecv([send_op, recv_op])
for task in tasks:
task.wait()
print(recv_t)
# paddle.tensor([1, 2]) # Rank-0
# paddle.tensor([0, 1]) # Rank-1
"""
_check_p2p_op_list(p2p_op_list)
group = p2p_op_list[0].group
if group is not None and not group.is_member():
return
if in_dygraph_mode():
group = _get_default_group() if group is None else group
backend = _group_map_backend[group]
tasks = []
with _with_batch_p2p_guard(backend):
for p2p_op in p2p_op_list:
op = p2p_op.op
tensor = p2p_op.tensor
peer = p2p_op.peer
comm_group = p2p_op.group
task = op(tensor, peer, comm_group)
if task is not None:
tasks.append(task)
return tasks
else:
raise RuntimeError("Don't support static graph mode currently.")
def reduce_scatter(tensor,
tensor_list,
op=ReduceOp.SUM,
group=None,
use_calc_stream=True):
"""
Reduces, then scatters a list of tensors to all processes in a group.
Args:
tensor (Tensor): Output tensor.
tensor_list (list[Tensor]): List of tensors to reduce and scatter.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
group (Group, optional): The group instance returned by new_group or None for the global
default group. Default: None.
use_calc_stream (bool, optional): Whether to use the calculation stream (True) or the communication stream (False). Default: True.
Returns:
Async task handle, if use_calc_stream is set to False.
None, if use_calc_stream is True or if the calling process is not part of the group.
Warning:
This API only supports the dygraph mode.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
t1 = paddle.to_tensor([0, 1])
t2 = paddle.to_tensor([2, 3])
else:
t1 = paddle.to_tensor([4, 5])
t2 = paddle.to_tensor([6, 7])
tensor_list = [t1, t2]
output = paddle.empty(shape=[2], dtype=tensor_list[0].dtype)
dist.reduce_scatter(output, tensor_list)
print(output)
# [4, 6] # Rank-0
# [8, 10] # Rank-1
"""
_check_single_tensor(tensor, "tensor")
_check_tensor_list(tensor_list, "tensor_list")
if group is not None and not group.is_member():
return
if in_dygraph_mode():
op_type = _get_reduce_op(op, "reduce_scatter")
group = _get_default_group() if group is None else group
temp = paddle.concat(tensor_list, axis=0)
task = group.process_group._reduce_scatter_base(tensor, temp, op_type)
if use_calc_stream:
task.wait()
return None
else:
return task
else:
raise RuntimeError("Don't support static graph mode currently.")
def _reduce_scatter_base(output,
input,
op=ReduceOp.SUM,
group=None,
use_calc_stream=True):
"""
Reduces, then scatters a flattened tensor to all processes in a group.
Args:
output (Tensor): Output tensor.
input (Tensor): Input tensor whose size equals the output tensor's size times world_size.
op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD): Optional. The operation used. Default: ReduceOp.SUM.
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
use_calc_stream (bool, optional): Whether to use the calculation stream (True) or the communication stream (False).
Default: True.
Returns:
Async task handle, if use_calc_stream is set to False.
None, if use_calc_stream is True or if the calling process is not part of the group.
Examples:
.. code-block:: python
# required: distributed
import paddle
import paddle.distributed as dist
dist.init_parallel_env()
rank = dist.get_rank()
world_size = dist.get_world_size()
input = paddle.arange(4) + rank
# [0, 1, 2, 3] # Rank-0
# [1, 2, 3, 4] # Rank-1
output = paddle.empty(shape=[2], dtype=input.dtype)
paddle.distributed.collective._reduce_scatter_base(output, input)
print(output)
# [1, 3] # Rank-0
# [5, 7] # Rank-1
"""
_check_single_tensor(output, "output")
_check_single_tensor(input, "input")
if group is not None and not group.is_member():
return
if in_dygraph_mode():
op_type = _get_reduce_op(op, "_reduce_scatter_base")
group = _get_default_group() if group is None else group
task = group.process_group._reduce_scatter_base(output, input, op_type)
if use_calc_stream:
task.wait()
return None
else:
return task
else:
raise RuntimeError("Don't support static graph mode currently.")
......@@ -42,6 +42,7 @@ from paddle.distributed.collective import _set_default_backend
from paddle.distributed.collective import _set_default_store
from paddle.distributed.collective import _new_process_group_impl
from paddle.distributed.collective import Group
from paddle.distributed.collective import _set_group_map_backend
__all__ = []
......@@ -257,6 +258,7 @@ def init_parallel_env():
name=_default_group_name)
_set_group_map_by_name(_default_group_name, group)
_set_group_map(0, group)
_set_group_map_backend(group, backend)
parallel_helper._set_parallel_ctx(True)
paddle.distributed.barrier(group=group)
......
......@@ -72,7 +72,10 @@ list(APPEND DIST_TEST_OPS test_auto_parallel_data_unshard)
list(APPEND DIST_TEST_OPS test_auto_parallel_save_load)
list(APPEND DIST_TEST_OPS test_auto_parallel_autoconvert)
list(APPEND DIST_TEST_OPS test_collective_process_group)
list(APPEND DIST_TEST_OPS test_collective_alltoall_single)
list(APPEND DIST_TEST_OPS test_eager_dist_api)
list(APPEND DIST_TEST_OPS test_collective_batch_isend_irecv)
list(APPEND DIST_TEST_OPS test_collective_reduce_scatter)
set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
#remove distribute unittests.
list(APPEND MIXED_DIST_TEST_OPS test_dgc_op)
......@@ -334,7 +337,11 @@ if((NOT WITH_GPU) AND (NOT WITH_ROCM))
list(REMOVE_ITEM TEST_OPS test_auto_parallel_save_load)
list(REMOVE_ITEM TEST_OPS test_auto_parallel_autoconvert)
list(REMOVE_ITEM TEST_OPS test_collective_process_group)
list(REMOVE_ITEM TEST_OPS test_collective_alltoall_single)
list(REMOVE_ITEM TEST_OPS test_eager_dist_api)
list(REMOVE_ITEM TEST_OPS test_collective_batch_isend_irecv)
list(REMOVE_ITEM TEST_OPS test_collective_reduce_scatter)
elseif(WITH_GPU)
if(${CUDNN_VERSION} VERSION_LESS 7100)
list(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
......@@ -1569,8 +1576,10 @@ if(WITH_DISTRIBUTE
set_tests_properties(test_auto_parallel_save_load PROPERTIES TIMEOUT 120)
set_tests_properties(test_auto_parallel_autoconvert PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_process_group PROPERTIES TIMEOUT 120)
set_tests_properties(test_collective_alltoall_single PROPERTIES TIMEOUT 60)
set_tests_properties(test_eager_dist_api PROPERTIES TIMEOUT 100)
set_tests_properties(test_collective_batch_isend_irecv PROPERTIES TIMEOUT 100)
set_tests_properties(test_collective_reduce_scatter PROPERTIES TIMEOUT 100)
if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
set_tests_properties(test_parallel_dygraph_sparse_embedding
PROPERTIES TIMEOUT 200)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
from paddle import framework
class TestCollectiveAllToAllSingle(unittest.TestCase):
def setUp(self):
assert not paddle.distributed.is_initialized(), \
"The distributed environment should not be initialized before init_parallel_env."
dist.init_parallel_env()
assert paddle.distributed.is_initialized(), \
"The distributed environment should be initialized after init_parallel_env."
paddle.fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
def test_collective_alltoall_single(self):
rank = dist.get_rank()
size = dist.get_world_size()
# case 1
input = paddle.ones([size, size], dtype='int64') * rank
output = paddle.empty([size, size], dtype='int64')
expected_output = paddle.concat(
[paddle.ones([1, size], dtype='int64') * i for i in range(size)])
group = dist.new_group([0, 1])
dist.alltoall_single(input, output, group=group)
np.testing.assert_allclose(output.numpy(), expected_output.numpy())
dist.destroy_process_group(group)
# case 2
in_split_sizes = [i + 1 for i in range(size)]
out_split_sizes = [rank + 1 for i in range(size)]
input = paddle.ones([sum(in_split_sizes), size], dtype='float32') * rank
output = paddle.empty([(rank + 1) * size, size], dtype='float32')
expected_output = paddle.concat([
paddle.ones([rank + 1, size], dtype='float32') * i
for i in range(size)
])
group = dist.new_group([0, 1])
task = dist.alltoall_single(input,
output,
in_split_sizes,
out_split_sizes,
use_calc_stream=False,
group=group)
task.wait()
np.testing.assert_allclose(output.numpy(), expected_output.numpy())
dist.destroy_process_group(group)
def tearDown(self):
dist.destroy_process_group()
assert not paddle.distributed.is_initialized(), \
"The distributed environment should be deinitialized after destroy_process_group."
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
from paddle import framework
class TestCollectiveBatchIsendIrecv(unittest.TestCase):
def setUp(self):
dist.init_parallel_env()
paddle.fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
def test_collective_batch_isend_irecv(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
send_t = paddle.arange(2) + rank
# paddle.tensor([0, 1]) # Rank-0
# paddle.tensor([1, 2]) # Rank-1
recv_t = paddle.empty(shape=[2], dtype=send_t.dtype)
send_op = dist.P2POp(dist.isend, send_t, (rank + 1) % world_size)
recv_op = dist.P2POp(dist.irecv, recv_t,
(rank - 1 + world_size) % world_size)
tasks = dist.batch_isend_irecv([send_op, recv_op])
for task in tasks:
task.wait()
if rank == 0:
np.testing.assert_allclose(recv_t.numpy(), [1, 2])
elif rank == 1:
np.testing.assert_allclose(recv_t.numpy(), [0, 1])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
from paddle import framework
class TestCollectiveReduceScatter(unittest.TestCase):
def setUp(self):
dist.init_parallel_env()
paddle.fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
def test_collective_reduce_scatter_sum(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
t1 = paddle.to_tensor([0, 1])
t2 = paddle.to_tensor([2, 3])
else:
t1 = paddle.to_tensor([4, 5])
t2 = paddle.to_tensor([6, 7])
input_list = [t1, t2]
output = paddle.empty(shape=[2], dtype=input_list[0].dtype)
dist.reduce_scatter(output, input_list)
if rank == 0:
np.testing.assert_allclose(output.numpy(), [4, 6])
elif rank == 1:
np.testing.assert_allclose(output.numpy(), [8, 10])
def test_collective_reduce_scatter_max(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
if rank == 0:
t1 = paddle.to_tensor([0, 1], dtype="float16")
t2 = paddle.to_tensor([2, 3], dtype="float16")
else:
t1 = paddle.to_tensor([4, 5], dtype="float16")
t2 = paddle.to_tensor([6, 7], dtype="float16")
input_list = [t1, t2]
output = paddle.empty(shape=[2], dtype=input_list[0].dtype)
dist.reduce_scatter(output, input_list, op=dist.ReduceOp.MAX)
if rank == 0:
np.testing.assert_allclose(output.numpy(), [4, 5])
elif rank == 1:
np.testing.assert_allclose(output.numpy(), [6, 7])
def test_collective_reduce_scatter_base(self):
rank = dist.get_rank()
world_size = dist.get_world_size()
input = paddle.arange(4) + rank
# [0, 1, 2, 3] # Rank-0
# [1, 2, 3, 4] # Rank-1
output = paddle.empty(shape=[2], dtype=input.dtype)
task = paddle.distributed.collective._reduce_scatter_base(
output, input, use_calc_stream=False)
task.wait()
if rank == 0:
np.testing.assert_allclose(output.numpy(), [1, 3])
elif rank == 1:
np.testing.assert_allclose(output.numpy(), [5, 7])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestCollectiveAllToAllSingle(TestMultipleGpus):
def test_collective_alltoall_single(self):
self.run_mnist_2gpu('collective_alltoall_single.py', eager_mode=True)
if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestCollectiveBatchIsendIrecv(TestMultipleGpus):
def test_collective_batch_isend_irecv(self):
self.run_mnist_2gpu('collective_batch_isend_irecv.py', eager_mode=True)
if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestCollectiveReduceScatter(TestMultipleGpus):
def test_collective_reduce_scatter(self):
self.run_mnist_2gpu('collective_reduce_scatter.py', eager_mode=True)
if __name__ == "__main__":
os.environ["FLAGS_enable_eager_mode"] = "1"
unittest.main()