提交 479a443f 编写于 作者: T tangwei12

name optimized

上级 26b228e4
......@@ -50,11 +50,16 @@ class TestSelectedRowsSumOp(OpTest):
self.check_input_and_optput(scope, place, False, False, True)
self.check_input_and_optput(scope, place, False, False, False)
def check_input_and_optput(self, scope, place, w1=False, w2=False,
w3=False):
W1 = self.create_selected_rows(scope, place, "W1", w1)
W2 = self.create_selected_rows(scope, place, "W2", w2)
W3 = self.create_selected_rows(scope, place, "W3", w3)
def check_input_and_optput(self,
scope,
place,
w1_has_data=False,
w2_has_data=False,
w3_has_data=False):
self.create_selected_rows(scope, place, "W1", w1_has_data)
self.create_selected_rows(scope, place, "W2", w2_has_data)
self.create_selected_rows(scope, place, "W3", w3_has_data)
# create Out Variable
out = scope.var('Out').get_selected_rows()
......@@ -63,12 +68,12 @@ class TestSelectedRowsSumOp(OpTest):
sum_op = Operator("sum", X=["W1", "W2", "W3"], Out='Out')
sum_op.run(scope, place)
trues = 0
for w in [w1, w2, w3]:
has_data_w_num = 0
for w in [w1_has_data, w2_has_data, w3_has_data]:
if not w:
trues += 1
has_data_w_num += 1
self.assertEqual(7 * trues, len(out.rows()))
self.assertEqual(7 * has_data_w_num, len(out.rows()))
def create_selected_rows(self, scope, place, var_name, isEmpty):
# create and initialize W Variable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册