提交 bb92ee26 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mge/functional): refine doc and add test for linspace

GitOrigin-RevId: d6c7eea2c23ea37fa3cda668132aa6360fb24e51
上级 ce30d045
......@@ -515,7 +515,7 @@ def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor:
def linspace(
start: Union[int, float, Tensor],
stop: Union[int, float, Tensor],
num: int = 100,
num: Union[int, Tensor],
dtype=np.float32,
device: Optional[CompNode] = None,
comp_graph: Optional[CompGraph] = None,
......
......@@ -157,6 +157,28 @@ def test_broadcast_to():
opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
def test_linspace():
cases = [
{"input": [1, 9, 9]},
{"input": [3, 10, 8]},
]
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
)
cases = [
{"input": [9, 1, 9]},
{"input": [10, 3, 8]},
]
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
)
def test_arange():
cases = [
{"input": [1, 9, 1]},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册