diff --git a/python/paddle/v2/framework/create_op_creation_methods.py b/python/paddle/v2/framework/create_op_creation_methods.py index b034efffb69030cb09e09ea545e9bff6f1744671..6fd33b366b6d376cc51ba5d663bb04d45ab8c933 100644 --- a/python/paddle/v2/framework/create_op_creation_methods.py +++ b/python/paddle/v2/framework/create_op_creation_methods.py @@ -1,7 +1,7 @@ import paddle.v2.framework.core as core import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2 -import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2 +import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2 import cStringIO @@ -57,7 +57,7 @@ class OpDescCreationMethod(object): op_desc.attrs.extend([out_format]) if len(tmp_index) != 0: tmp_index_attr = op_desc.attrs.add() - tmp_index_attr.type = attr_type_pb2.INTS + tmp_index_attr.type = attribute_pb2.INTS tmp_index_attr.name = "temporary_index" tmp_index_attr.ints.extend(tmp_index) @@ -73,17 +73,17 @@ class OpDescCreationMethod(object): new_attr = op_desc.attrs.add() new_attr.name = attr.name new_attr.type = attr.type - if attr.type == attr_type_pb2.INT: + if attr.type == attribute_pb2.INT: new_attr.i = user_defined_attr - elif attr.type == attr_type_pb2.FLOAT: + elif attr.type == attribute_pb2.FLOAT: new_attr.f = user_defined_attr - elif attr.type == attr_type_pb2.STRING: + elif attr.type == attribute_pb2.STRING: new_attr.s = user_defined_attr - elif attr.type == attr_type_pb2.INTS: + elif attr.type == attribute_pb2.INTS: new_attr.ints.extend(user_defined_attr) - elif attr.type == attr_type_pb2.FLOATS: + elif attr.type == attribute_pb2.FLOATS: new_attr.floats.extend(user_defined_attr) - elif attr.type == attr_type_pb2.STRINGS: + elif attr.type == attribute_pb2.STRINGS: new_attr.strings.extend(user_defined_attr) else: raise NotImplementedError("Not support attribute type " + @@ -109,7 +109,7 @@ class OpDescCreationMethod(object): retv = [] if multiple: var_format = op_desc_pb2.AttrDesc() - var_format.type = attr_type_pb2.INTS + var_format.type = attribute_pb2.INTS var_format.name = "%s_format" % in_out var_format.ints.append(0) @@ -185,17 +185,17 @@ def get_docstring_from_op_proto(op_proto): for attr in op_proto.attrs: attr_type = None - if attr.type == attr_type_pb2.INT: + if attr.type == attribute_pb2.INT: attr_type = "int" - elif attr.type == attr_type_pb2.FLOAT: + elif attr.type == attribute_pb2.FLOAT: attr_type = "float" - elif attr.type == attr_type_pb2.STRING: + elif attr.type == attribute_pb2.STRING: attr_type = "basestr" - elif attr.type == attr_type_pb2.INTS: + elif attr.type == attribute_pb2.INTS: attr_type = "list of int" - elif attr.type == attr_type_pb2.FLOATS: + elif attr.type == attribute_pb2.FLOATS: attr_type = "list of float" - elif attr.type == attr_type_pb2.STRINGS: + elif attr.type == attribute_pb2.STRINGS: attr_type = "list of basestr" if attr_type is None: diff --git a/python/paddle/v2/framework/tests/test_op_creation_methods.py b/python/paddle/v2/framework/tests/test_op_creation_methods.py index 41db7c0d535aa920b34d6cc346090a8c15bfb110..1d2ce6d0717bfb45355fe0cabc516a598492d518 100644 --- a/python/paddle/v2/framework/tests/test_op_creation_methods.py +++ b/python/paddle/v2/framework/tests/test_op_creation_methods.py @@ -3,7 +3,7 @@ import paddle.v2.framework.create_op_creation_methods as creation import paddle.v2.framework.core as core import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2 -import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2 +import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2 class TestGetAllProtos(unittest.TestCase): @@ -76,7 +76,7 @@ class TestOpDescCreationMethod(unittest.TestCase): expected1.type = 'fc' attr = expected1.attrs.add() attr.name = 'input_format' - attr.type = attr_type_pb2.INTS + attr.type = attribute_pb2.INTS attr.ints.extend([0, 1, 2, 3]) self.assertEqual(expected1, generated1) @@ -88,7 +88,7 @@ class TestOpDescCreationMethod(unittest.TestCase): expected2.type = 'fc' attr = expected2.attrs.add() attr.name = 'input_format' - attr.type = attr_type_pb2.INTS + attr.type = attribute_pb2.INTS attr.ints.extend([0, 3, 6, 7]) self.assertEqual(expected2, generated2) @@ -105,12 +105,12 @@ class TestOpDescCreationMethod(unittest.TestCase): attr.comment = "" attr.type = type - __add_attr__("int_attr", attr_type_pb2.INT) - __add_attr__("float_attr", attr_type_pb2.FLOAT) - __add_attr__("string_attr", attr_type_pb2.STRING) - __add_attr__("ints_attr", attr_type_pb2.INTS) - __add_attr__("floats_attr", attr_type_pb2.FLOATS) - __add_attr__("strings_attr", attr_type_pb2.STRINGS) + __add_attr__("int_attr", attribute_pb2.INT) + __add_attr__("float_attr", attribute_pb2.FLOAT) + __add_attr__("string_attr", attribute_pb2.STRING) + __add_attr__("ints_attr", attribute_pb2.INTS) + __add_attr__("floats_attr", attribute_pb2.FLOATS) + __add_attr__("strings_attr", attribute_pb2.STRINGS) op.comment = "" self.assertTrue(op.IsInitialized()) @@ -131,32 +131,32 @@ class TestOpDescCreationMethod(unittest.TestCase): expected.inputs.extend(['a']) attr = expected.attrs.add() attr.name = "int_attr" - attr.type = attr_type_pb2.INT + attr.type = attribute_pb2.INT attr.i = 10 attr = expected.attrs.add() attr.name = "float_attr" - attr.type = attr_type_pb2.FLOAT + attr.type = attribute_pb2.FLOAT attr.f = 3.2 attr = expected.attrs.add() attr.name = "string_attr" - attr.type = attr_type_pb2.STRING + attr.type = attribute_pb2.STRING attr.s = "test_str" attr = expected.attrs.add() attr.name = "ints_attr" - attr.type = attr_type_pb2.INTS + attr.type = attribute_pb2.INTS attr.ints.extend([0, 1, 2, 3, 4]) attr = expected.attrs.add() attr.name = 'floats_attr' - attr.type = attr_type_pb2.FLOATS + attr.type = attribute_pb2.FLOATS attr.floats.extend([0.2, 3.2, 4.5]) attr = expected.attrs.add() attr.name = 'strings_attr' - attr.type = attr_type_pb2.STRINGS + attr.type = attribute_pb2.STRINGS attr.strings.extend(['a', 'b', 'c']) self.assertEqual(expected, generated) @@ -185,7 +185,7 @@ class TestOpDescCreationMethod(unittest.TestCase): desc.type = "test" attr = desc.attrs.add() attr.name = "temporary_index" - attr.type = attr_type_pb2.INTS + attr.type = attribute_pb2.INTS attr.ints.append(2) self.assertEqual(generated, desc) @@ -219,7 +219,7 @@ This op is used for unit test, not a real op. test_str = op.attrs.add() test_str.name = "str_attr" - test_str.type = attr_type_pb2.STRING + test_str.type = attribute_pb2.STRING test_str.comment = "A string attribute for test op" actual = creation.get_docstring_from_op_proto(op) diff --git a/python/paddle/v2/framework/tests/test_protobuf.py b/python/paddle/v2/framework/tests/test_protobuf.py index b8702477e64203e735bff05b115eafbb2a52172d..69e98e2f250a9df23b25e7e2043af29f87c996a0 100644 --- a/python/paddle/v2/framework/tests/test_protobuf.py +++ b/python/paddle/v2/framework/tests/test_protobuf.py @@ -1,12 +1,10 @@ -import paddle.v2.framework.proto.op_proto_pb2 -import paddle.v2.framework.proto.attr_type_pb2 +import paddle.v2.framework.proto.op_proto_pb2 as op_proto_lib +import paddle.v2.framework.proto.attribute_pb2 as attr_type_lib import unittest class TestFrameworkProto(unittest.TestCase): def test_all(self): - op_proto_lib = paddle.v2.framework.proto.op_proto_pb2 - attr_type_lib = paddle.v2.framework.proto.attr_type_pb2 op_proto = op_proto_lib.OpProto() ipt0 = op_proto.inputs.add() ipt0.name = "a"