提交 479a443f 编写于 作者: T tangwei12

name optimized

上级 26b228e4
...@@ -50,11 +50,16 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -50,11 +50,16 @@ class TestSelectedRowsSumOp(OpTest):
self.check_input_and_optput(scope, place, False, False, True) self.check_input_and_optput(scope, place, False, False, True)
self.check_input_and_optput(scope, place, False, False, False) self.check_input_and_optput(scope, place, False, False, False)
def check_input_and_optput(self, scope, place, w1=False, w2=False, def check_input_and_optput(self,
w3=False): scope,
W1 = self.create_selected_rows(scope, place, "W1", w1) place,
W2 = self.create_selected_rows(scope, place, "W2", w2) w1_has_data=False,
W3 = self.create_selected_rows(scope, place, "W3", w3) w2_has_data=False,
w3_has_data=False):
self.create_selected_rows(scope, place, "W1", w1_has_data)
self.create_selected_rows(scope, place, "W2", w2_has_data)
self.create_selected_rows(scope, place, "W3", w3_has_data)
# create Out Variable # create Out Variable
out = scope.var('Out').get_selected_rows() out = scope.var('Out').get_selected_rows()
...@@ -63,12 +68,12 @@ class TestSelectedRowsSumOp(OpTest): ...@@ -63,12 +68,12 @@ class TestSelectedRowsSumOp(OpTest):
sum_op = Operator("sum", X=["W1", "W2", "W3"], Out='Out') sum_op = Operator("sum", X=["W1", "W2", "W3"], Out='Out')
sum_op.run(scope, place) sum_op.run(scope, place)
trues = 0 has_data_w_num = 0
for w in [w1, w2, w3]: for w in [w1_has_data, w2_has_data, w3_has_data]:
if not w: if not w:
trues += 1 has_data_w_num += 1
self.assertEqual(7 * trues, len(out.rows())) self.assertEqual(7 * has_data_w_num, len(out.rows()))
def create_selected_rows(self, scope, place, var_name, isEmpty): def create_selected_rows(self, scope, place, var_name, isEmpty):
# create and initialize W Variable # create and initialize W Variable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册