提交 6bd66e79 编写于 作者: H Hoai Linh Tran

Fix memcpy calls. Add ut tests for arithmetic_simplify. Split long...

Fix memcpy calls. Add ut tests for arithmetic_simplify. Split long arithmetic_simplify.h to arithmetic_simplify.cc

Code checking
上级 bf699955
此差异已折叠。
......@@ -549,6 +549,122 @@ def test_zeros():
assert res == Tensor(np.zeros([2, 3]).astype(np.int32))
@ms_function
def arithmetic_simplify_01(x, y):
""" arithmetic_simplify_01 """
return C.zeros_like(x) * y
def test_arithmetic_simplify_01():
""" test_arithmetic_simplify_01 """
x = Tensor(np.ones([2, 3]).astype(np.int32))
y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
res = arithmetic_simplify_01(x, y)
expect = np.zeros([2, 3]).astype(np.int32)
assert np.all(res.asnumpy() == expect)
@ms_function
def arithmetic_simplify_02(x, y):
""" arithmetic_simplify_02 """
return C.ones_like(x) * y
def test_arithmetic_simplify_02():
""" test_arithmetic_simplify_02 """
x = Tensor(np.ones([2, 3]).astype(np.int32))
y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
res = arithmetic_simplify_02(x, y)
expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
assert np.all(res.asnumpy() == expect)
@ms_function
def arithmetic_simplify_03(x, y):
""" arithmetic_simplify_03 """
return x * C.ones_like(y)
def test_arithmetic_simplify_03():
""" test_arithmetic_simplify_03 """
x = Tensor(np.ones([2, 3]).astype(np.int32))
y = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
res = arithmetic_simplify_03(x, y)
expect = np.ones([2, 3]).astype(np.int32)
assert np.all(res.asnumpy() == expect)
@ms_function
def arithmetic_simplify_04(x):
""" arithmetic_simplify_04 """
return x + 0
def test_arithmetic_simplify_04():
""" test_arithmetic_simplify_04 """
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
res = arithmetic_simplify_04(x)
expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
assert np.all(res.asnumpy() == expect)
@ms_function
def arithmetic_simplify_05(x):
""" arithmetic_simplify_05 """
return x * 1
def test_arithmetic_simplify_05():
""" test_arithmetic_simplify_05 """
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
res = arithmetic_simplify_05(x)
expect = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32)
assert np.all(res.asnumpy() == expect)
@ms_function
def arithmetic_simplify_06(x):
""" arithmetic_simplify_06 """
return x * 2 * 5
def test_arithmetic_simplify_06():
""" test_arithmetic_simplify_06 """
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
res = arithmetic_simplify_06(x)
expect = np.array([[10, 20, 30], [40, 50, 60]]).astype(np.int32)
assert np.all(res.asnumpy() == expect)
@ms_function
def arithmetic_simplify_07(x):
""" arithmetic_simplify_07 """
return (x + 1) * 2 * 5
def test_arithmetic_simplify_07():
""" test_arithmetic_simplify_07 """
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
res = arithmetic_simplify_07(x)
expect = np.array([[20, 30, 40], [50, 60, 70]]).astype(np.int32)
assert np.all(res.asnumpy() == expect)
@ms_function
def arithmetic_simplify_08(x, y):
""" arithmetic_simplify_08 """
return 1 * x * 1 * 1 + 1 * 0 * 1 + 0 + y * 1
def test_arithmetic_simplify_08():
""" test_arithmetic_simplify_08 """
x = Tensor(np.array([[1, 2, 3], [4, 5, 6]]).astype(np.int32))
y = Tensor(np.ones([2, 3]).astype(np.int32))
res = arithmetic_simplify_08(x, y)
expect = np.array([[2, 3, 4], [5, 6, 7]]).astype(np.int32)
assert np.all(res.asnumpy() == expect)
def test_ScalarGradChecker():
""" test_ScalarGradChecker """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册