提交 e8d0f9db 编写于 作者: M Megvii Engine Team

docs(example): replace dist.help testcode with doctest format

GitOrigin-RevId: 79a7b72ca3b33d6566d7b169b040fe73b540223a
上级 f5f9249a
......@@ -28,14 +28,14 @@ from .group import WORLD, Group, group_barrier, is_distributed, override_backend
def param_pack_split(inp: Tensor, offsets: list, shapes: list):
r"""Returns split tensor to tensor list as offsets and shapes described,
r"""Returns split tensor to list of tensors as offsets and shapes described,
only used for ``parampack``.
Args:
inp: input tensor.
offsets: offsets of outputs, length of `2 * n`,
while n is tensor nums you want to split,
format `[begin0, end0, begin1, end1]`.
offsets: offsets of outputs, length of ``2 * n``,
where ``n`` is the number of tensor you want to split,
format ``[begin0, end0, begin1, end1]``.
shapes: tensor shapes of outputs.
Returns:
......@@ -43,25 +43,14 @@ def param_pack_split(inp: Tensor, offsets: list, shapes: list):
Examples:
.. testcode::
import numpy as np
from megengine import tensor
from megengine.distributed.helper import param_pack_split
a = tensor(np.ones((10,), np.int32))
b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
print(b.numpy())
print(c.numpy())
Outputs:
.. testoutput::
[1]
[[1 1 1]
[1 1 1]
[1 1 1]]
>>> a = F.ones(10)
>>> b, c = dist.helper.param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
>>> b
Tensor([1.], device=xpux:0)
>>> c
Tensor([[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]], device=xpux:0)
"""
op = ParamPackSplit()
op.offsets = offsets
......@@ -74,34 +63,22 @@ def param_pack_concat(inps: list, offsets: Tensor, offsets_val: list):
r"""Returns concated tensor, only used for ``parampack``.
Args:
inps: input tensors.
inps: list of input tensors.
offsets: device value of offsets.
offsets_val: offsets of inputs, length of `2 * n`,
format `[begin0, end0, begin1, end1]`.
offsets_val: offsets of inputs, length of ``2 * n``,
format ``[begin0, end0, begin1, end1]``.
Returns:
concated tensor.
Examples:
.. testcode::
import numpy as np
from megengine import tensor
from megengine.distributed.helper import param_pack_concat
a = tensor(np.ones((1,), np.int32))
b = tensor(np.ones((3, 3), np.int32))
offsets_val = [0, 1, 1, 10]
offsets = tensor(offsets_val, np.int32)
c = param_pack_concat([a, b], offsets, offsets_val)
print(c.numpy())
Outputs:
.. testoutput::
[1 1 1 1 1 1 1 1 1 1]
>>> a = F.ones(1)
>>> b = F.ones((3, 3))
>>> offsets_val = [0, 1, 1, 10]
>>> offsets = Tensor(offsets_val)
>>> c = dist.helper.param_pack_concat([a, b], offsets, offsets_val) # doctest: +SKIP
Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], device=xpux:0)
"""
op = ParamPackConcat()
op.offsets = offsets_val
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册