From 479a443f68c4295ab7f04a925a925f75c94dc94a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 16 Aug 2018 20:27:58 +0800 Subject: [PATCH] name optimized --- .../fluid/tests/unittests/test_sum_op.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_sum_op.py b/python/paddle/fluid/tests/unittests/test_sum_op.py index 3c4260791..81efd01d5 100644 --- a/python/paddle/fluid/tests/unittests/test_sum_op.py +++ b/python/paddle/fluid/tests/unittests/test_sum_op.py @@ -50,11 +50,16 @@ class TestSelectedRowsSumOp(OpTest): self.check_input_and_optput(scope, place, False, False, True) self.check_input_and_optput(scope, place, False, False, False) - def check_input_and_optput(self, scope, place, w1=False, w2=False, - w3=False): - W1 = self.create_selected_rows(scope, place, "W1", w1) - W2 = self.create_selected_rows(scope, place, "W2", w2) - W3 = self.create_selected_rows(scope, place, "W3", w3) + def check_input_and_optput(self, + scope, + place, + w1_has_data=False, + w2_has_data=False, + w3_has_data=False): + + self.create_selected_rows(scope, place, "W1", w1_has_data) + self.create_selected_rows(scope, place, "W2", w2_has_data) + self.create_selected_rows(scope, place, "W3", w3_has_data) # create Out Variable out = scope.var('Out').get_selected_rows() @@ -63,12 +68,12 @@ class TestSelectedRowsSumOp(OpTest): sum_op = Operator("sum", X=["W1", "W2", "W3"], Out='Out') sum_op.run(scope, place) - trues = 0 - for w in [w1, w2, w3]: + has_data_w_num = 0 + for w in [w1_has_data, w2_has_data, w3_has_data]: if not w: - trues += 1 + has_data_w_num += 1 - self.assertEqual(7 * trues, len(out.rows())) + self.assertEqual(7 * has_data_w_num, len(out.rows())) def create_selected_rows(self, scope, place, var_name, isEmpty): # create and initialize W Variable -- GitLab