diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index 3c42607918209af7bba9c2ba94aab5d521944588..81efd01d5073184477dbeb9b9ed6b3cff741bef4 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -50,11 +50,16 @@ class TestSelectedRowsSumOp(OpTest): self.check_input_and_optput(scope, place, False, False, True) self.check_input_and_optput(scope, place, False, False, False) - def check_input_and_optput(self, scope, place, w1=False, w2=False, - w3=False): - W1 = self.create_selected_rows(scope, place, "W1", w1) - W2 = self.create_selected_rows(scope, place, "W2", w2) - W3 = self.create_selected_rows(scope, place, "W3", w3) + def check_input_and_optput(self, + scope, + place, + w1_has_data=False, + w2_has_data=False, + w3_has_data=False): + + self.create_selected_rows(scope, place, "W1", w1_has_data) + self.create_selected_rows(scope, place, "W2", w2_has_data) + self.create_selected_rows(scope, place, "W3", w3_has_data) # create Out Variable out = scope.var('Out').get_selected_rows() @@ -63,12 +68,12 @@ class TestSelectedRowsSumOp(OpTest): sum_op = Operator("sum", X=["W1", "W2", "W3"], Out='Out') sum_op.run(scope, place) - trues = 0 - for w in [w1, w2, w3]: + has_data_w_num = 0 + for w in [w1_has_data, w2_has_data, w3_has_data]: if not w: - trues += 1 + has_data_w_num += 1 - self.assertEqual(7 * trues, len(out.rows())) + self.assertEqual(7 * has_data_w_num, len(out.rows())) def create_selected_rows(self, scope, place, var_name, isEmpty): # create and initialize W Variable