diff --git a/python_module/megengine/functional/tensor.py b/python_module/megengine/functional/tensor.py index 31aedc270df6deff739936760ad656aaf8efe03c..2cc58a57edf3174189ec9294e48ff30c0d64b9a2 100644 --- a/python_module/megengine/functional/tensor.py +++ b/python_module/megengine/functional/tensor.py @@ -515,7 +515,7 @@ def remove_axis(inp: Tensor, axis: Union[int, Iterable[int]]) -> Tensor: def linspace( start: Union[int, float, Tensor], stop: Union[int, float, Tensor], - num: int = 100, + num: Union[int, Tensor], dtype=np.float32, device: Optional[CompNode] = None, comp_graph: Optional[CompGraph] = None, diff --git a/python_module/test/unit/functional/test_functional.py b/python_module/test/unit/functional/test_functional.py index 3b88cd8c15b7c7e57bbc7f6b43108083da62af75..b9f0cebf466c0214089cba4da3ae1d0b0d2d7251 100644 --- a/python_module/test/unit/functional/test_functional.py +++ b/python_module/test/unit/functional/test_functional.py @@ -157,6 +157,28 @@ def test_broadcast_to(): opr_test(cases, F.broadcast_to, compare_fn=compare_fn) +def test_linspace(): + cases = [ + {"input": [1, 9, 9]}, + {"input": [3, 10, 8]}, + ] + opr_test( + cases, + F.linspace, + ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), + ) + + cases = [ + {"input": [9, 1, 9]}, + {"input": [10, 3, 8]}, + ] + opr_test( + cases, + F.linspace, + ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), + ) + + def test_arange(): cases = [ {"input": [1, 9, 1]},