提交 4485e780 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix linspace device and open other trace tests

GitOrigin-RevId: 4667c4adec0867f09cd4c030c457e0174fa6f908
上级 334eda87
...@@ -910,7 +910,7 @@ def linspace( ...@@ -910,7 +910,7 @@ def linspace(
import numpy as np import numpy as np
import megengine.functional as F import megengine.functional as F
a = F.linspace(3,10,5) a = F.linspace(3, 10, 5)
print(a.numpy()) print(a.numpy())
Outputs: Outputs:
...@@ -920,8 +920,19 @@ def linspace( ...@@ -920,8 +920,19 @@ def linspace(
[ 3. 4.75 6.5 8.25 10. ] [ 3. 4.75 6.5 8.25 10. ]
""" """
for item in (start, stop, num):
cur_device = getattr(item, "device", None)
if device is None:
device = cur_device
else:
if not (cur_device is None or device == cur_device):
raise ("ambiguous device for linspace opr")
if not isinstance(start, Tensor):
start = Tensor(start, device=device) start = Tensor(start, device=device)
if not isinstance(stop, Tensor):
stop = Tensor(stop, device=device) stop = Tensor(stop, device=device)
if not isinstance(num, Tensor):
num = Tensor(num, device=device) num = Tensor(num, device=device)
op = builtin.Linspace(comp_node=device) op = builtin.Linspace(comp_node=device)
......
...@@ -114,18 +114,12 @@ def test_matmul(): ...@@ -114,18 +114,12 @@ def test_matmul():
{"input": [data3, data4]}, {"input": [data3, data4]},
{"input": [data4, data5]}, {"input": [data4, data5]},
] ]
for _ in range(0, batch_size): opr_test(cases, F.matmul, ref_fn=np.matmul)
# FIXME: remove test_trace=False in the future
opr_test(
cases, F.matmul, test_trace=False, ref_fn=np.matmul,
)
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
[{"input": [data1, data4]}], [{"input": [data1, data4]}],
F.matmul, F.matmul,
ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)), ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
test_trace=False,
transpose_b=True, transpose_b=True,
) )
......
...@@ -162,24 +162,30 @@ def test_linspace(): ...@@ -162,24 +162,30 @@ def test_linspace():
{"input": [1, 9, 9]}, {"input": [1, 9, 9]},
{"input": [3, 10, 8]}, {"input": [3, 10, 8]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.linspace, F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
test_trace=False,
) )
cases = [ cases = [
{"input": [9, 1, 9]}, {"input": [9, 1, 9]},
{"input": [10, 3, 8]}, {"input": [10, 3, 8]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.linspace, F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
test_trace=False, )
cases = [
{"input": [1, tensor(9), 9]},
{"input": [tensor(1), 9, tensor(9)]},
]
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
) )
...@@ -188,36 +194,30 @@ def test_arange(): ...@@ -188,36 +194,30 @@ def test_arange():
{"input": [1, 9, 1]}, {"input": [1, 9, 1]},
{"input": [2, 10, 2]}, {"input": [2, 10, 2]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.arange, F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
) )
cases = [ cases = [
{"input": [9, 1, -1]}, {"input": [9, 1, -1]},
{"input": [10, 2, -2]}, {"input": [10, 2, -2]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.arange, F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
) )
cases = [ cases = [
{"input": [9.3, 1.2, -0.5]}, {"input": [9.3, 1.2, -0.5]},
{"input": [10.3, 2.1, -1.7]}, {"input": [10.3, 2.1, -1.7]},
] ]
# FIXME: remove test_trace=False in the future
opr_test( opr_test(
cases, cases,
F.arange, F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32), ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
) )
...@@ -289,8 +289,7 @@ def test_broadcast(): ...@@ -289,8 +289,7 @@ def test_broadcast():
{"input": [data1, output1_shape], "output": output1_shape}, {"input": [data1, output1_shape], "output": output1_shape},
{"input": [data2, output2_shape], "output": output2_shape}, {"input": [data2, output2_shape], "output": output2_shape},
] ]
# FIXME: remove test_trace=False in the future opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, test_trace=False)
x = F.ones((2, 1, 3)) x = F.ones((2, 1, 3))
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册