提交 f551d44e 编写于 作者: M Megvii Engine Team

feat(mge): rename arange paramter end -> stop

GitOrigin-RevId: ea9064e351985e3b8a43d620be5986b1bd5830d1
上级 9c5ea689
...@@ -929,15 +929,15 @@ def linspace( ...@@ -929,15 +929,15 @@ def linspace(
def arange( def arange(
start: Union[int, float, Tensor] = 0, start: Union[int, float, Tensor] = 0,
end: Optional[Union[int, float, Tensor]] = None, stop: Optional[Union[int, float, Tensor]] = None,
step: Union[int, float, Tensor] = 1, step: Union[int, float, Tensor] = 1,
dtype="float32", dtype="float32",
device: Optional[CompNode] = None, device: Optional[CompNode] = None,
) -> Tensor: ) -> Tensor:
r"""Returns a tensor with values from start to end with adjacent interval step. r"""Returns a tensor with values from start to stop with adjacent interval step.
:param start: starting value of the squence, shoule be scalar. :param start: starting value of the squence, shoule be scalar.
:param end: ending value of the squence, shoule be scalar. :param stop: ending value of the squence, shoule be scalar.
:param step: gap between each pair of adjacent values. Default: 1 :param step: gap between each pair of adjacent values. Default: 1
:param dtype: result data type. :param dtype: result data type.
:return: generated tensor. :return: generated tensor.
...@@ -961,16 +961,16 @@ def arange( ...@@ -961,16 +961,16 @@ def arange(
[0. 1. 2. 3. 4.] [0. 1. 2. 3. 4.]
""" """
if end is None: if stop is None:
start, end = 0, start start, stop = 0, start
if isinstance(start, Tensor): if isinstance(start, Tensor):
start = start.astype("float32") start = start.astype("float32")
if isinstance(end, Tensor): if isinstance(stop, Tensor):
end = end.astype("float32") stop = stop.astype("float32")
if isinstance(step, Tensor): if isinstance(step, Tensor):
step = step.astype("float32") step = step.astype("float32")
num = ceil(Tensor((end - start) / step, device=device)) num = ceil(Tensor((stop - start) / step, device=device))
stop = start + step * (num - 1) stop = start + step * (num - 1)
result = linspace(start, stop, num, device=device) result = linspace(start, stop, num, device=device)
if np.dtype(dtype) == np.int32: if np.dtype(dtype) == np.int32:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册