提交 f13ae131 编写于 作者: Q Qiao Longfei

fix test_sum_op

test=develop
上级 575f2271
...@@ -46,16 +46,18 @@ class TestSumOp(OpTest): ...@@ -46,16 +46,18 @@ class TestSumOp(OpTest):
class TestSelectedRowsSumOp(OpTest): class TestSelectedRowsSumOp(OpTest):
def check_with_place(self, place, inplace): def check_with_place(self, place, inplace):
scope = core.Scope()
self.height = 10 self.height = 10
self.row_numel = 12 self.row_numel = 12
self.rows = [0, 1, 2, 3, 4, 5, 6] self.rows = [0, 1, 2, 3, 4, 5, 6]
self.check_input_and_optput(scope, place, inplace, True, True, True) self.check_input_and_optput(core.Scope(), place, inplace, True, True,
self.check_input_and_optput(scope, place, inplace, False, True, True) True)
self.check_input_and_optput(scope, place, inplace, False, False, True) self.check_input_and_optput(core.Scope(), place, inplace, False, True,
self.check_input_and_optput(scope, place, inplace, False, False, False) True)
self.check_input_and_optput(core.Scope(), place, inplace, False, False,
True)
self.check_input_and_optput(core.Scope(), place, inplace, False, False,
False)
def _get_array(self, row_num, row_numel): def _get_array(self, row_num, row_numel):
array = np.ones((row_num, row_numel)).astype("float32") array = np.ones((row_num, row_numel)).astype("float32")
...@@ -100,10 +102,6 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -100,10 +102,6 @@ class TestSelectedRowsSumOp(OpTest):
has_data_w_num)) has_data_w_num))
else: else:
self.assertEqual(len(out.rows()), 0) self.assertEqual(len(out.rows()), 0)
self.assertTrue(
np.array_equal(
np.array(out.get_tensor()),
self._get_array(0, self.row_numel) * has_data_w_num))
def create_selected_rows(self, scope, place, var_name, has_data): def create_selected_rows(self, scope, place, var_name, has_data):
# create and initialize W Variable # create and initialize W Variable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册