未验证 提交 595e9c5a 编写于 作者: W wawltor 提交者: GitHub

Fix the test case bug for matmul, test=develop (#23674)

Fix the bug of matmul test_case
上级 a1a95f81
...@@ -246,26 +246,26 @@ for dim in [4]: ...@@ -246,26 +246,26 @@ for dim in [4]:
class API_TestMm(unittest.TestCase): class API_TestMm(unittest.TestCase):
def test_out(self): def test_out(self):
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[3, 2], dtype="float32") x = fluid.data(name="x", shape=[3, 2], dtype="float64")
y = fluid.data(name='y', shape=[2, 3], dtype='float32') y = fluid.data(name='y', shape=[2, 3], dtype='float64')
res = fluid.data(name="output", shape=[3, 3], dtype="float32") res = fluid.data(name="output", shape=[3, 3], dtype="float64")
y_1 = paddle.mm(x, y, out=res) y_1 = paddle.mm(x, y, out=res)
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
data1 = np.random.rand(3, 2).astype('float32') data1 = np.random.rand(3, 2)
data2 = np.random.rand(2, 3).astype('float32') data2 = np.random.rand(2, 3)
np_res, np_y_1 = exe.run(feed={'x': data1, np_res, np_y_1 = exe.run(feed={'x': data1,
'y': data2}, 'y': data2},
fetch_list=[res, y_1]) fetch_list=[res, y_1])
self.assertEqual((np_res == np_y_1).all(), True) self.assertEqual((np_res == np_y_1).all(), True)
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
x = fluid.data(name="x", shape=[2], dtype="float32") x = fluid.data(name="x", shape=[2], dtype="float64")
y = fluid.data(name='y', shape=[2], dtype='float32') y = fluid.data(name='y', shape=[2], dtype='float64')
res = fluid.data(name="output", shape=[1], dtype="float32") res = fluid.data(name="output", shape=[1], dtype="float64")
result = paddle.mm(x, y) result = paddle.mm(x, y)
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
data1 = np.random.rand(2).astype('float32') data1 = np.random.rand(2)
data2 = np.random.rand(2).astype('float32') data2 = np.random.rand(2)
np_res = exe.run(feed={'x': data1, 'y': data2}, fetch_list=[result]) np_res = exe.run(feed={'x': data1, 'y': data2}, fetch_list=[result])
expected_result = np.matmul( expected_result = np.matmul(
data1.reshape(1, 2), data2.reshape(2, 1)) data1.reshape(1, 2), data2.reshape(2, 1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册