提交 906f5e8a 编写于 作者: F fengjiayi

Fix unittest bugs

上级 e71b836f
......@@ -160,7 +160,7 @@ class Operator(object):
(in_proto.name, len(in_argus)))
in_argu_names = []
for argu in in_argus:
in_argu_names.append(argu.name())
in_argu_names.append(argu.name)
self.desc.set_input(in_proto.name, in_argu_names)
if outputs is not None:
......@@ -174,7 +174,7 @@ class Operator(object):
(out_proto.name, len(out_argus)))
out_argu_names = []
for argu in out_argus:
out_argu_names.append(argu.name())
out_argu_names.append(argu.name)
self.desc.set_output(out_proto.name, out_argu_names)
if attrs is not None:
......
......@@ -14,7 +14,7 @@ class TestOperator(unittest.TestCase):
err.message,
"Operator with type \"no_such_op\" has not been registered.")
def test_input_output(self):
def test_op_desc_creation(self):
block = g_program.current_block()
mul_x = block.create_var(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
......@@ -26,12 +26,18 @@ class TestOperator(unittest.TestCase):
type="mul",
inputs={"X": [mul_x],
"Y": mul_y},
outputs={"Out": [mul_out]})
outputs={"Out": [mul_out]},
attrs={"x_num_col_dims": 1})
self.assertEqual(mul_op.type, "mul")
self.assertEqual(mul_op.input_names, ["X", "Y"])
self.assertEqual(mul_op.input("X"), ["x"])
self.assertEqual(mul_op.input("X"), ["mul.x"])
self.assertEqual(mul_op.input("Y"), ["mul.y"])
self.assertEqual(mul_op.output_names, ["Out"])
self.assertEqual(mul_op.output("Out"), ["out"])
self.assertEqual(mul_op.output("Out"), ["mul.out"])
self.assertEqual(mul_op.attr_names, ["x_num_col_dims"])
self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
def test_mult_input(self):
block = g_program.current_block()
......@@ -49,9 +55,9 @@ class TestOperator(unittest.TestCase):
outputs={"Out": sum_out})
self.assertEqual(sum_op.type, "sum")
self.assertEqual(sum_op.input_names, ["X"])
self.assertEqual(sum_op.input("X"), ["x1", "x2", "x3"])
self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"])
self.assertEqual(sum_op.output_names, ["Out"])
self.assertEqual(sum_op.output("Out"), ["out"])
self.assertEqual(sum_op.output("Out"), ["sum.out"])
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册