diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py index 0fbb373f2dc131020d54093008458063626e74b4..c53869c888267eeadaad0be317eeef5e9fad0cab 100644 --- a/python/paddle/v2/framework/graph.py +++ b/python/paddle/v2/framework/graph.py @@ -160,7 +160,7 @@ class Operator(object): (in_proto.name, len(in_argus))) in_argu_names = [] for argu in in_argus: - in_argu_names.append(argu.name()) + in_argu_names.append(argu.name) self.desc.set_input(in_proto.name, in_argu_names) if outputs is not None: @@ -174,7 +174,7 @@ class Operator(object): (out_proto.name, len(out_argus))) out_argu_names = [] for argu in out_argus: - out_argu_names.append(argu.name()) + out_argu_names.append(argu.name) self.desc.set_output(out_proto.name, out_argu_names) if attrs is not None: diff --git a/python/paddle/v2/framework/tests/test_operator_desc.py b/python/paddle/v2/framework/tests/test_operator_desc.py index 5ee1409c8f1abdf79873f5c1bf9860cac16fee8a..62f3a05d15d2a34bd061a4d05c07f377eddc9a43 100644 --- a/python/paddle/v2/framework/tests/test_operator_desc.py +++ b/python/paddle/v2/framework/tests/test_operator_desc.py @@ -14,7 +14,7 @@ class TestOperator(unittest.TestCase): err.message, "Operator with type \"no_such_op\" has not been registered.") - def test_input_output(self): + def test_op_desc_creation(self): block = g_program.current_block() mul_x = block.create_var( dtype="float32", shape=[5, 10], lod_level=0, name="mul.x") @@ -26,12 +26,18 @@ class TestOperator(unittest.TestCase): type="mul", inputs={"X": [mul_x], "Y": mul_y}, - outputs={"Out": [mul_out]}) + outputs={"Out": [mul_out]}, + attrs={"x_num_col_dims": 1}) self.assertEqual(mul_op.type, "mul") self.assertEqual(mul_op.input_names, ["X", "Y"]) - self.assertEqual(mul_op.input("X"), ["x"]) + self.assertEqual(mul_op.input("X"), ["mul.x"]) + self.assertEqual(mul_op.input("Y"), ["mul.y"]) self.assertEqual(mul_op.output_names, ["Out"]) - self.assertEqual(mul_op.output("Out"), ["out"]) + self.assertEqual(mul_op.output("Out"), ["mul.out"]) + self.assertEqual(mul_op.attr_names, ["x_num_col_dims"]) + self.assertEqual(mul_op.has_attr("x_num_col_dims"), True) + self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT) + self.assertEqual(mul_op.attr("x_num_col_dims"), 1) def test_mult_input(self): block = g_program.current_block() @@ -49,9 +55,9 @@ class TestOperator(unittest.TestCase): outputs={"Out": sum_out}) self.assertEqual(sum_op.type, "sum") self.assertEqual(sum_op.input_names, ["X"]) - self.assertEqual(sum_op.input("X"), ["x1", "x2", "x3"]) + self.assertEqual(sum_op.input("X"), ["sum.x1", "sum.x2", "sum.x3"]) self.assertEqual(sum_op.output_names, ["Out"]) - self.assertEqual(sum_op.output("Out"), ["out"]) + self.assertEqual(sum_op.output("Out"), ["sum.out"]) if __name__ == '__main__':