diff --git a/imperative/python/megengine/functional/tensor.py b/imperative/python/megengine/functional/tensor.py index ba729c68818c4be951ffb70f12f59c99bb1d8d60..0266da1336153325fe66ff44b8e0e2b4f7feab8a 100644 --- a/imperative/python/megengine/functional/tensor.py +++ b/imperative/python/megengine/functional/tensor.py @@ -910,7 +910,7 @@ def linspace( import numpy as np import megengine.functional as F - a = F.linspace(3,10,5) + a = F.linspace(3, 10, 5) print(a.numpy()) Outputs: @@ -920,9 +920,20 @@ def linspace( [ 3. 4.75 6.5 8.25 10. ] """ - start = Tensor(start, device=device) - stop = Tensor(stop, device=device) - num = Tensor(num, device=device) + for item in (start, stop, num): + cur_device = getattr(item, "device", None) + if device is None: + device = cur_device + else: + if not (cur_device is None or device == cur_device): + raise ("ambiguous device for linspace opr") + + if not isinstance(start, Tensor): + start = Tensor(start, device=device) + if not isinstance(stop, Tensor): + stop = Tensor(stop, device=device) + if not isinstance(num, Tensor): + num = Tensor(num, device=device) op = builtin.Linspace(comp_node=device) (result,) = apply(op, start, stop, num) diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 21d22c777cc3fa1336f3321dec9b446a0140554a..d6df37361c5194f04b6298a54a065e0475e4c3fb 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -114,18 +114,12 @@ def test_matmul(): {"input": [data3, data4]}, {"input": [data4, data5]}, ] - for _ in range(0, batch_size): - # FIXME: remove test_trace=False in the future - opr_test( - cases, F.matmul, test_trace=False, ref_fn=np.matmul, - ) + opr_test(cases, F.matmul, ref_fn=np.matmul) - # FIXME: remove test_trace=False in the future opr_test( [{"input": [data1, data4]}], F.matmul, ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)), - test_trace=False, transpose_b=True, ) diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index c24cafc37bcebc286c44d92a2eda960b99d9d75c..b4039fc7c73aeb99d02d5c4f97c85ca5b1bc9e95 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -162,24 +162,30 @@ def test_linspace(): {"input": [1, 9, 9]}, {"input": [3, 10, 8]}, ] - # FIXME: remove test_trace=False in the future opr_test( cases, F.linspace, ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), - test_trace=False, ) cases = [ {"input": [9, 1, 9]}, {"input": [10, 3, 8]}, ] - # FIXME: remove test_trace=False in the future opr_test( cases, F.linspace, ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), - test_trace=False, + ) + + cases = [ + {"input": [1, tensor(9), 9]}, + {"input": [tensor(1), 9, tensor(9)]}, + ] + opr_test( + cases, + F.linspace, + ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32), ) @@ -188,36 +194,30 @@ def test_arange(): {"input": [1, 9, 1]}, {"input": [2, 10, 2]}, ] - # FIXME: remove test_trace=False in the future opr_test( cases, F.arange, ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), - test_trace=False, ) cases = [ {"input": [9, 1, -1]}, {"input": [10, 2, -2]}, ] - # FIXME: remove test_trace=False in the future opr_test( cases, F.arange, ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), - test_trace=False, ) cases = [ {"input": [9.3, 1.2, -0.5]}, {"input": [10.3, 2.1, -1.7]}, ] - # FIXME: remove test_trace=False in the future opr_test( cases, F.arange, ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), - test_trace=False, ) @@ -289,8 +289,7 @@ def test_broadcast(): {"input": [data1, output1_shape], "output": output1_shape}, {"input": [data2, output2_shape], "output": output2_shape}, ] - # FIXME: remove test_trace=False in the future - opr_test(cases, F.broadcast_to, compare_fn=compare_fn, test_trace=False) + opr_test(cases, F.broadcast_to, compare_fn=compare_fn) x = F.ones((2, 1, 3)) with pytest.raises(RuntimeError):