提交 4485e780 编写于 作者: M Megvii Engine Team

fix(mge/functional): fix linspace device and open other trace tests

GitOrigin-RevId: 4667c4adec0867f09cd4c030c457e0174fa6f908
上级 334eda87
......@@ -910,7 +910,7 @@ def linspace(
import numpy as np
import megengine.functional as F
a = F.linspace(3,10,5)
a = F.linspace(3, 10, 5)
print(a.numpy())
Outputs:
......@@ -920,9 +920,20 @@ def linspace(
[ 3. 4.75 6.5 8.25 10. ]
"""
start = Tensor(start, device=device)
stop = Tensor(stop, device=device)
num = Tensor(num, device=device)
for item in (start, stop, num):
cur_device = getattr(item, "device", None)
if device is None:
device = cur_device
else:
if not (cur_device is None or device == cur_device):
raise ("ambiguous device for linspace opr")
if not isinstance(start, Tensor):
start = Tensor(start, device=device)
if not isinstance(stop, Tensor):
stop = Tensor(stop, device=device)
if not isinstance(num, Tensor):
num = Tensor(num, device=device)
op = builtin.Linspace(comp_node=device)
(result,) = apply(op, start, stop, num)
......
......@@ -114,18 +114,12 @@ def test_matmul():
{"input": [data3, data4]},
{"input": [data4, data5]},
]
for _ in range(0, batch_size):
# FIXME: remove test_trace=False in the future
opr_test(
cases, F.matmul, test_trace=False, ref_fn=np.matmul,
)
opr_test(cases, F.matmul, ref_fn=np.matmul)
# FIXME: remove test_trace=False in the future
opr_test(
[{"input": [data1, data4]}],
F.matmul,
ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
test_trace=False,
transpose_b=True,
)
......
......@@ -162,24 +162,30 @@ def test_linspace():
{"input": [1, 9, 9]},
{"input": [3, 10, 8]},
]
# FIXME: remove test_trace=False in the future
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
test_trace=False,
)
cases = [
{"input": [9, 1, 9]},
{"input": [10, 3, 8]},
]
# FIXME: remove test_trace=False in the future
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(start, end, step, dtype=np.float32),
test_trace=False,
)
cases = [
{"input": [1, tensor(9), 9]},
{"input": [tensor(1), 9, tensor(9)]},
]
opr_test(
cases,
F.linspace,
ref_fn=lambda start, end, step: np.linspace(1, 9, 9, dtype=np.float32),
)
......@@ -188,36 +194,30 @@ def test_arange():
{"input": [1, 9, 1]},
{"input": [2, 10, 2]},
]
# FIXME: remove test_trace=False in the future
opr_test(
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
)
cases = [
{"input": [9, 1, -1]},
{"input": [10, 2, -2]},
]
# FIXME: remove test_trace=False in the future
opr_test(
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
)
cases = [
{"input": [9.3, 1.2, -0.5]},
{"input": [10.3, 2.1, -1.7]},
]
# FIXME: remove test_trace=False in the future
opr_test(
cases,
F.arange,
ref_fn=lambda start, end, step: np.arange(start, end, step, dtype=np.float32),
test_trace=False,
)
......@@ -289,8 +289,7 @@ def test_broadcast():
{"input": [data1, output1_shape], "output": output1_shape},
{"input": [data2, output2_shape], "output": output2_shape},
]
# FIXME: remove test_trace=False in the future
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, test_trace=False)
opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
x = F.ones((2, 1, 3))
with pytest.raises(RuntimeError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册