提交 64406706 编写于 作者: Q Qiao Longfei

update test_split_selected_rows_op.py

上级 bd2b6d7f
...@@ -99,7 +99,6 @@ class TestSpliteSelectedRows(unittest.TestCase): ...@@ -99,7 +99,6 @@ class TestSpliteSelectedRows(unittest.TestCase):
out0_grad.set_height(height) out0_grad.set_height(height)
out0_grad_tensor = out0_grad.get_tensor() out0_grad_tensor = out0_grad.get_tensor()
np_array = np.ones((len(rows0), row_numel)).astype("float32") np_array = np.ones((len(rows0), row_numel)).astype("float32")
np_array[0, 0] = 2.0
out0_grad_tensor.set(np_array, place) out0_grad_tensor.set(np_array, place)
out1_grad = scope.var("out1@GRAD").get_selected_rows() out1_grad = scope.var("out1@GRAD").get_selected_rows()
...@@ -108,7 +107,6 @@ class TestSpliteSelectedRows(unittest.TestCase): ...@@ -108,7 +107,6 @@ class TestSpliteSelectedRows(unittest.TestCase):
out1_grad.set_height(height) out1_grad.set_height(height)
out1_grad_tensor = out1_grad.get_tensor() out1_grad_tensor = out1_grad.get_tensor()
np_array = np.ones((len(rows1), row_numel)).astype("float32") np_array = np.ones((len(rows1), row_numel)).astype("float32")
np_array[0, 1] = 4.0
out1_grad_tensor.set(np_array, place) out1_grad_tensor.set(np_array, place)
x_grad = scope.var("X@GRAD").get_selected_rows() x_grad = scope.var("X@GRAD").get_selected_rows()
...@@ -121,11 +119,13 @@ class TestSpliteSelectedRows(unittest.TestCase): ...@@ -121,11 +119,13 @@ class TestSpliteSelectedRows(unittest.TestCase):
grad_op.run(scope, place) grad_op.run(scope, place)
self.assertEqual(x_grad.rows(), rows0 + rows1) merged_rows = set(rows0 + rows1)
self.assertEqual(set(x_grad.rows()), set(rows0 + rows1))
self.assertEqual(x_grad.height(), height) self.assertEqual(x_grad.height(), height)
print(np.array(x_grad.get_tensor()))
self.assertAlmostEqual(2.0, np.array(x_grad.get_tensor())[0, 0]) self.assertAlmostEqual(2.0, np.array(x_grad.get_tensor())[0, 0])
self.assertAlmostEqual(4.0, np.array(x_grad.get_tensor())[2, 1]) self.assertAlmostEqual(1.0, np.array(x_grad.get_tensor())[2, 1])
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册