From 906f5e8a269144150e6132da0ac100f49df22980 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 10 Oct 2017 16:49:48 -0700 Subject: [PATCH] Fix unittest bugs --- python/paddle/v2/framework/graph.py | 4 ++-- .../v2/framework/tests/test_operator_desc.py | 18 ++++++++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index 0fbb373f2d..c53869c888 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -160,7 +160,7 @@ class Operator(object): (in_proto.name, len(in_argus))) in_argu_names = [] for argu in in_argus: - in_argu_names.append(argu.name()) + in_argu_names.append(argu.name) self.desc.set_input(in_proto.name, in_argu_names) if outputs is not None: @@ -174,7 +174,7 @@ class Operator(object): (out_proto.name, len(out_argus))) out_argu_names = [] for argu in out_argus: - out_argu_names.append(argu.name()) + out_argu_names.append(argu.name) self.desc.set_output(out_proto.name, out_argu_names) if attrs is not None: diff --git a/python/paddle/v2/framework/tests/test_operator_desc.py b/python/paddle/v2/framework/tests/test_operator_desc.py index 5ee1409c8f..62f3a05d15 100644 --- a/python/paddle/v2/framework/tests/test_operator_desc.py +++ b/python/paddle/v2/framework/tests/test_operator_desc.py @@ -14,7 +14,7 @@ class TestOperator(unittest.TestCase): err.message, "Operator with type \"no_such_op\" has not been registered.") - def test_input_output(self): + def test_op_desc_creation(self): block = g_program.current_block() mul_x = block.create_var( dtype="float32", shape=[5, 10], lod_level=0, name="mul.x") @@ -26,12 +26,18 @@ class TestOperator(unittest.TestCase): type="mul", inputs={"X": [mul_x], "Y": mul_y}, - outputs={"Out": [mul_out]}) + outputs={"Out": [mul_out]}, + attrs={"x_num_col_dims": 1}) self.assertEqual(mul_op.type, "mul") self.assertEqual(mul_op.input_names, ["X", "Y"]) - self.assertEqual(mul_op.input("X"), ["x"]) + self.assertEqual(mul_op.input("X"), ["mul.x"]) + self.assertEqual(mul_op.input("Y"), ["mul.y"]) self.assertEqual(mul_op.output_names, ["Out"]) - self.assertEqual(mul_op.output("Out"), ["out"]) + self.assertEqual(mul_op.output("Out"), ["mul.out"]) + self.assertEqual(mul_op.attr_names, ["x_num_col_dims"]) + self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) + self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) + self.assertEqual(mul_op.attr("x_num_col_dims"), 1) def test_mult_input(self): block = g_program.current_block() @@ -49,9 +55,9 @@ class TestOperator(unittest.TestCase): outputs={"Out": sum_out}) self.assertEqual(sum_op.type, "sum") self.assertEqual(sum_op.input_names, ["X"]) - self.assertEqual(sum_op.input("X"), ["x1", "x2", "x3"]) + self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"]) self.assertEqual(sum_op.output_names, ["Out"]) - self.assertEqual(sum_op.output("Out"), ["out"]) + self.assertEqual(sum_op.output("Out"), ["sum.out"]) if __name__ == '__main__': -- GitLab