提交 27346b0b 编写于 作者: M Megvii Engine Team

test(opr): add scalar check for opr_test

GitOrigin-RevId: dcfd7ad5d6b8a85027df796051a0521c7be48575
上级 22504523
......@@ -11,12 +11,12 @@ from megengine.utils.network_node import VarNode
def _default_compare_fn(x, y):
if isinstance(x, np.ndarray):
np.testing.assert_allclose(x, y, rtol=1e-6)
elif isinstance(x, tensor):
np.testing.assert_allclose(x.numpy(), y, rtol=1e-6)
else:
np.testing.assert_allclose(get_var_value(x), y, rtol=1e-6)
if isinstance(x, tensor):
x = x.numpy()
elif not isinstance(x, np.ndarray):
x = get_var_value(x)
assert isinstance(x, np.ndarray)
np.testing.assert_allclose(x, y, rtol=1e-6)
def make_tensor(x, network=None, device=None):
......@@ -69,12 +69,16 @@ def opr_test(
"""
def check_results(results, expected):
def check_results(results, expected, check_shape=True):
if not isinstance(results, (tuple, list)):
results = (results,)
for r, e in zip(results, expected):
if not isinstance(r, (tensor, VarNode)):
r = tensor(r)
if check_shape:
r_shape = r.numpy().shape
e_shape = e.shape if isinstance(e, np.ndarray) else ()
assert r_shape == e_shape
compare_fn(r, e)
def get_param(cases, idx):
......@@ -127,10 +131,10 @@ def opr_test(
# assume #outputs == 1
loaded_results = list(infer_cg.run(inp_dict=inp_dict).values())[0]
check_results(loaded_results, outp)
check_results(loaded_results, outp, check_shape=False) # scalar info lost
results = func(*inp_tensor, **kwargs)
check_results(results, outp)
check_results(results, outp, check_shape=(network is None))
if len(cases) == 0:
raise ValueError("should give one case at least")
......
......@@ -39,12 +39,6 @@ def test_where():
xv1 = np.array([[1, np.inf, 2], [0, np.nan, 4], [1, 5, 7]], dtype=np.float32)
yv1 = np.array([[5, 6, 9], [2, 7, 8], [2, 1, 9]], dtype=np.float32)
cases = [
{"input": [maskv0, xv0, yv0]},
{"input": [maskv1, xv1, yv1]},
]
opr_test(cases, F.where, ref_fn=np.where, test_trace=False)
maskv2 = np.array([1, 1, 1], dtype=np.bool_)
xv2 = np.array([1, 3, 2], dtype=np.float32)
yv2 = np.array([5, 6, 9], dtype=np.float32)
......@@ -53,11 +47,18 @@ def test_where():
xv3 = np.array([1, 3, 2], dtype=np.float32)
yv3 = np.array([5, 6, 9], dtype=np.float32)
maskv4 = np.array(1, dtype=np.bool_)
xv4 = np.array(1, dtype=np.float32)
yv4 = np.array(0, dtype=np.float32)
cases = [
{"input": [maskv0, xv0, yv0]},
{"input": [maskv1, xv1, yv1]},
{"input": [maskv2, xv2, yv2]},
{"input": [maskv3, xv3, yv3]},
{"input": [maskv4, xv4, yv4]},
]
opr_test(cases, F.where, ref_fn=np.where, test_trace=False)
opr_test(cases, F.where, ref_fn=np.where, test_trace=True)
def test_dropout():
......@@ -618,12 +619,12 @@ def test_binary_cross_entropy():
np.random.seed(123)
data1 = np.random.uniform(size=data1_shape).astype(np.float32)
label1 = np.random.uniform(size=label1_shape).astype(np.float32)
expect1 = np.array([0.6361], dtype=np.float32)
expect1 = np.array(0.6361, dtype=np.float32)
np.random.seed(123)
data2 = np.random.uniform(size=data2_shape).astype(np.float32)
label2 = np.random.uniform(size=label2_shape).astype(np.float32)
expect2 = np.array([0.6750], dtype=np.float32)
expect2 = np.array(0.6750, dtype=np.float32)
cases = [
{"input": [data1, label1], "output": expect1,},
......
......@@ -335,18 +335,18 @@ def test_reshape_shape_inference(is_varnode):
source = output.shape
if isinstance(source, tensor):
source = source.numpy()
np.testing.assert_equal(source, target)
np.testing.assert_equal(source, target.shape)
def func(x, target_shape):
return x.reshape(target_shape)
cases = [
{"input": [x_shape_known, tshp_unknown], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_unknown], "output": [(2, 2),]},
{"input": [x_shape_known, tshp_known], "output": [(2, 2),]},
{"input": [x_shape_known, tshp_known_unspec], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_known], "output": [(2, 2),]},
{"input": [x_shape_unknown, tshp_known_unspec], "output": [(2, 2),]},
{"input": [x_shape_known, tshp_unknown], "output": [np.zeros((2, 2)),]},
{"input": [x_shape_unknown, tshp_unknown], "output": [np.zeros((2, 2)),]},
{"input": [x_shape_known, tshp_known], "output": [np.zeros((2, 2)),]},
{"input": [x_shape_known, tshp_known_unspec], "output": [np.zeros((2, 2)),]},
{"input": [x_shape_unknown, tshp_known], "output": [np.zeros((2, 2)),]},
{"input": [x_shape_unknown, tshp_known_unspec], "output": [np.zeros((2, 2)),]},
]
opr_test(cases, func, compare_fn=check_shape, test_trace=True, network=network)
if is_varnode:
......@@ -533,46 +533,30 @@ def test_flatten(is_varnode):
data0 = np.random.random(data0_shape).astype(np.float32)
data1 = np.random.random(data1_shape).astype(np.float32)
def compare_fn(x, y):
assert x._tuple_shape[0] == y
output0 = (2 * 3 * 4 * 5,)
output1 = (4 * 5 * 6 * 7,)
cases = [
{"input": data0, "output": output0},
{"input": data1, "output": output1},
{"input": data0, "output": data0.flatten()},
{"input": data1, "output": data1.flatten()},
]
opr_test(cases, F.flatten, compare_fn=compare_fn, network=network)
opr_test(cases, F.flatten, network=network)
output0 = (2, 3 * 4 * 5)
output1 = (4, 5 * 6 * 7)
cases = [
{"input": data0, "output": output0},
{"input": data1, "output": output1},
{"input": data0, "output": data0.reshape(2, -1)},
{"input": data1, "output": data1.reshape(4, -1)},
]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=1, network=network)
opr_test(cases, F.flatten, start_axis=1, network=network)
output0 = (2, 3, 4 * 5)
output1 = (4, 5, 6 * 7)
cases = [
{"input": data0, "output": output0},
{"input": data1, "output": output1},
{"input": data0, "output": data0.reshape(2, 3, -1)},
{"input": data1, "output": data1.reshape(4, 5, -1)},
]
opr_test(cases, F.flatten, compare_fn=compare_fn, start_axis=2, network=network)
opr_test(cases, F.flatten, start_axis=2, network=network)
output0 = (2, 3 * 4, 5)
output1 = (4, 5 * 6, 7)
cases = [
{"input": data0, "output": output0},
{"input": data1, "output": output1},
{"input": data0, "output": data0.reshape(2, -1, 5)},
{"input": data1, "output": data1.reshape(4, -1, 7)},
]
opr_test(
cases,
F.flatten,
compare_fn=compare_fn,
start_axis=1,
end_axis=2,
network=network,
cases, F.flatten, start_axis=1, end_axis=2, network=network,
)
......@@ -595,15 +579,22 @@ def test_broadcast(is_varnode):
output3_shape = (10, 10)
data3 = np.random.random(input3_shape).astype(np.float32)
def compare_fn(x, y):
assert x._tuple_shape[0] == y
cases = [
{"input": [data1, output1_shape], "output": output1_shape},
{"input": [data2, output2_shape], "output": output2_shape},
{"input": [data3, output3_shape], "output": output3_shape},
{
"input": [data1, output1_shape],
"output": np.broadcast_to(data1, output1_shape),
},
{
"input": [data2, output2_shape],
"output": np.broadcast_to(data2, output2_shape),
},
{
"input": [data3, output3_shape],
"output": np.broadcast_to(data3, output3_shape),
},
]
opr_test(cases, F.broadcast_to, compare_fn=compare_fn, network=network)
opr_test(cases, F.broadcast_to, network=network)
x = F.ones((2, 1, 3))
with pytest.raises(RuntimeError):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册