未验证 提交 2010bdc3 编写于 作者: W Wen Sun 提交者: GitHub

Fix collective APIs cannot be recognized when building docs (#46962)

上级 a1b99978
...@@ -76,7 +76,7 @@ __all__ = [ # noqa ...@@ -76,7 +76,7 @@ __all__ = [ # noqa
"gloo_release", "QueueDataset", "split", "CountFilterEntry", "gloo_release", "QueueDataset", "split", "CountFilterEntry",
"ShowClickEntry", "get_world_size", "get_group", "all_gather", "ShowClickEntry", "get_world_size", "get_group", "all_gather",
"all_gather_object", "InMemoryDataset", "barrier", "all_reduce", "alltoall", "all_gather_object", "InMemoryDataset", "barrier", "all_reduce", "alltoall",
"send", "reduce", "recv", "ReduceOp", "wait", "get_rank", "alltoall_single", "send", "reduce", "recv", "ReduceOp", "wait", "get_rank",
"ProbabilityEntry", "ParallelMode", "is_initialized", "isend", "irecv", "ProbabilityEntry", "ParallelMode", "is_initialized",
"reduce_scatter", "rpc" "destroy_process_group", "isend", "irecv", "reduce_scatter", "rpc", "stream"
] ]
...@@ -18,12 +18,12 @@ from .alltoall import alltoall ...@@ -18,12 +18,12 @@ from .alltoall import alltoall
from .alltoall_single import alltoall_single from .alltoall_single import alltoall_single
from .broadcast import broadcast from .broadcast import broadcast
from .reduce import reduce from .reduce import reduce
from .reduce_scatter import _reduce_scatter_base, reduce_scatter from .reduce_scatter import reduce_scatter
from .recv import recv from .recv import recv
from .scatter import scatter from .scatter import scatter
from .send import send from .send import send
__all__ = [ __all__ = [
"_reduce_scatter_base", "all_reduce", "alltoall", "alltoall_single", "all_gather", "all_reduce", "alltoall", "alltoall_single", "broadcast",
"broadcast", "reduce", "reduce_scatter", "recv", "scatter", "send" "reduce", "reduce_scatter", "recv", "scatter", "send"
] ]
...@@ -17,6 +17,7 @@ import numpy as np ...@@ -17,6 +17,7 @@ import numpy as np
import paddle import paddle
import paddle.distributed as dist import paddle.distributed as dist
import test_collective_api_base as test_collective_base import test_collective_api_base as test_collective_base
from paddle.distributed.communication.stream.reduce_scatter import _reduce_scatter_base
class StreamReduceScatterTestCase(): class StreamReduceScatterTestCase():
...@@ -77,11 +78,10 @@ class StreamReduceScatterTestCase(): ...@@ -77,11 +78,10 @@ class StreamReduceScatterTestCase():
# case 3: test the legacy API # case 3: test the legacy API
result_tensor = paddle.empty_like(t1) result_tensor = paddle.empty_like(t1)
task = dist.stream._reduce_scatter_base( task = _reduce_scatter_base(result_tensor,
result_tensor, tensor,
tensor, sync_op=self._sync_op,
sync_op=self._sync_op, use_calc_stream=self._use_calc_stream)
use_calc_stream=self._use_calc_stream)
if not self._sync_op: if not self._sync_op:
task.wait() task.wait()
if rank == 0: if rank == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册