提交 b4ebb3c8 编写于 作者: Y Yu Yang

Change attr_type_pb2 to attribute_pb2

Make ci pass
上级 b8ff8275
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2 import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2 import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2
import cStringIO import cStringIO
...@@ -57,7 +57,7 @@ class OpDescCreationMethod(object): ...@@ -57,7 +57,7 @@ class OpDescCreationMethod(object):
op_desc.attrs.extend([out_format]) op_desc.attrs.extend([out_format])
if len(tmp_index) != 0: if len(tmp_index) != 0:
tmp_index_attr = op_desc.attrs.add() tmp_index_attr = op_desc.attrs.add()
tmp_index_attr.type = attr_type_pb2.INTS tmp_index_attr.type = attribute_pb2.INTS
tmp_index_attr.name = "temporary_index" tmp_index_attr.name = "temporary_index"
tmp_index_attr.ints.extend(tmp_index) tmp_index_attr.ints.extend(tmp_index)
...@@ -73,17 +73,17 @@ class OpDescCreationMethod(object): ...@@ -73,17 +73,17 @@ class OpDescCreationMethod(object):
new_attr = op_desc.attrs.add() new_attr = op_desc.attrs.add()
new_attr.name = attr.name new_attr.name = attr.name
new_attr.type = attr.type new_attr.type = attr.type
if attr.type == attr_type_pb2.INT: if attr.type == attribute_pb2.INT:
new_attr.i = user_defined_attr new_attr.i = user_defined_attr
elif attr.type == attr_type_pb2.FLOAT: elif attr.type == attribute_pb2.FLOAT:
new_attr.f = user_defined_attr new_attr.f = user_defined_attr
elif attr.type == attr_type_pb2.STRING: elif attr.type == attribute_pb2.STRING:
new_attr.s = user_defined_attr new_attr.s = user_defined_attr
elif attr.type == attr_type_pb2.INTS: elif attr.type == attribute_pb2.INTS:
new_attr.ints.extend(user_defined_attr) new_attr.ints.extend(user_defined_attr)
elif attr.type == attr_type_pb2.FLOATS: elif attr.type == attribute_pb2.FLOATS:
new_attr.floats.extend(user_defined_attr) new_attr.floats.extend(user_defined_attr)
elif attr.type == attr_type_pb2.STRINGS: elif attr.type == attribute_pb2.STRINGS:
new_attr.strings.extend(user_defined_attr) new_attr.strings.extend(user_defined_attr)
else: else:
raise NotImplementedError("Not support attribute type " + raise NotImplementedError("Not support attribute type " +
...@@ -109,7 +109,7 @@ class OpDescCreationMethod(object): ...@@ -109,7 +109,7 @@ class OpDescCreationMethod(object):
retv = [] retv = []
if multiple: if multiple:
var_format = op_desc_pb2.AttrDesc() var_format = op_desc_pb2.AttrDesc()
var_format.type = attr_type_pb2.INTS var_format.type = attribute_pb2.INTS
var_format.name = "%s_format" % in_out var_format.name = "%s_format" % in_out
var_format.ints.append(0) var_format.ints.append(0)
...@@ -185,17 +185,17 @@ def get_docstring_from_op_proto(op_proto): ...@@ -185,17 +185,17 @@ def get_docstring_from_op_proto(op_proto):
for attr in op_proto.attrs: for attr in op_proto.attrs:
attr_type = None attr_type = None
if attr.type == attr_type_pb2.INT: if attr.type == attribute_pb2.INT:
attr_type = "int" attr_type = "int"
elif attr.type == attr_type_pb2.FLOAT: elif attr.type == attribute_pb2.FLOAT:
attr_type = "float" attr_type = "float"
elif attr.type == attr_type_pb2.STRING: elif attr.type == attribute_pb2.STRING:
attr_type = "basestr" attr_type = "basestr"
elif attr.type == attr_type_pb2.INTS: elif attr.type == attribute_pb2.INTS:
attr_type = "list of int" attr_type = "list of int"
elif attr.type == attr_type_pb2.FLOATS: elif attr.type == attribute_pb2.FLOATS:
attr_type = "list of float" attr_type = "list of float"
elif attr.type == attr_type_pb2.STRINGS: elif attr.type == attribute_pb2.STRINGS:
attr_type = "list of basestr" attr_type = "list of basestr"
if attr_type is None: if attr_type is None:
......
...@@ -3,7 +3,7 @@ import paddle.v2.framework.create_op_creation_methods as creation ...@@ -3,7 +3,7 @@ import paddle.v2.framework.create_op_creation_methods as creation
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2 import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2 import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2
class TestGetAllProtos(unittest.TestCase): class TestGetAllProtos(unittest.TestCase):
...@@ -76,7 +76,7 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -76,7 +76,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
expected1.type = 'fc' expected1.type = 'fc'
attr = expected1.attrs.add() attr = expected1.attrs.add()
attr.name = 'input_format' attr.name = 'input_format'
attr.type = attr_type_pb2.INTS attr.type = attribute_pb2.INTS
attr.ints.extend([0, 1, 2, 3]) attr.ints.extend([0, 1, 2, 3])
self.assertEqual(expected1, generated1) self.assertEqual(expected1, generated1)
...@@ -88,7 +88,7 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -88,7 +88,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
expected2.type = 'fc' expected2.type = 'fc'
attr = expected2.attrs.add() attr = expected2.attrs.add()
attr.name = 'input_format' attr.name = 'input_format'
attr.type = attr_type_pb2.INTS attr.type = attribute_pb2.INTS
attr.ints.extend([0, 3, 6, 7]) attr.ints.extend([0, 3, 6, 7])
self.assertEqual(expected2, generated2) self.assertEqual(expected2, generated2)
...@@ -105,12 +105,12 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -105,12 +105,12 @@ class TestOpDescCreationMethod(unittest.TestCase):
attr.comment = "" attr.comment = ""
attr.type = type attr.type = type
__add_attr__("int_attr", attr_type_pb2.INT) __add_attr__("int_attr", attribute_pb2.INT)
__add_attr__("float_attr", attr_type_pb2.FLOAT) __add_attr__("float_attr", attribute_pb2.FLOAT)
__add_attr__("string_attr", attr_type_pb2.STRING) __add_attr__("string_attr", attribute_pb2.STRING)
__add_attr__("ints_attr", attr_type_pb2.INTS) __add_attr__("ints_attr", attribute_pb2.INTS)
__add_attr__("floats_attr", attr_type_pb2.FLOATS) __add_attr__("floats_attr", attribute_pb2.FLOATS)
__add_attr__("strings_attr", attr_type_pb2.STRINGS) __add_attr__("strings_attr", attribute_pb2.STRINGS)
op.comment = "" op.comment = ""
self.assertTrue(op.IsInitialized()) self.assertTrue(op.IsInitialized())
...@@ -131,32 +131,32 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -131,32 +131,32 @@ class TestOpDescCreationMethod(unittest.TestCase):
expected.inputs.extend(['a']) expected.inputs.extend(['a'])
attr = expected.attrs.add() attr = expected.attrs.add()
attr.name = "int_attr" attr.name = "int_attr"
attr.type = attr_type_pb2.INT attr.type = attribute_pb2.INT
attr.i = 10 attr.i = 10
attr = expected.attrs.add() attr = expected.attrs.add()
attr.name = "float_attr" attr.name = "float_attr"
attr.type = attr_type_pb2.FLOAT attr.type = attribute_pb2.FLOAT
attr.f = 3.2 attr.f = 3.2
attr = expected.attrs.add() attr = expected.attrs.add()
attr.name = "string_attr" attr.name = "string_attr"
attr.type = attr_type_pb2.STRING attr.type = attribute_pb2.STRING
attr.s = "test_str" attr.s = "test_str"
attr = expected.attrs.add() attr = expected.attrs.add()
attr.name = "ints_attr" attr.name = "ints_attr"
attr.type = attr_type_pb2.INTS attr.type = attribute_pb2.INTS
attr.ints.extend([0, 1, 2, 3, 4]) attr.ints.extend([0, 1, 2, 3, 4])
attr = expected.attrs.add() attr = expected.attrs.add()
attr.name = 'floats_attr' attr.name = 'floats_attr'
attr.type = attr_type_pb2.FLOATS attr.type = attribute_pb2.FLOATS
attr.floats.extend([0.2, 3.2, 4.5]) attr.floats.extend([0.2, 3.2, 4.5])
attr = expected.attrs.add() attr = expected.attrs.add()
attr.name = 'strings_attr' attr.name = 'strings_attr'
attr.type = attr_type_pb2.STRINGS attr.type = attribute_pb2.STRINGS
attr.strings.extend(['a', 'b', 'c']) attr.strings.extend(['a', 'b', 'c'])
self.assertEqual(expected, generated) self.assertEqual(expected, generated)
...@@ -185,7 +185,7 @@ class TestOpDescCreationMethod(unittest.TestCase): ...@@ -185,7 +185,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
desc.type = "test" desc.type = "test"
attr = desc.attrs.add() attr = desc.attrs.add()
attr.name = "temporary_index" attr.name = "temporary_index"
attr.type = attr_type_pb2.INTS attr.type = attribute_pb2.INTS
attr.ints.append(2) attr.ints.append(2)
self.assertEqual(generated, desc) self.assertEqual(generated, desc)
...@@ -219,7 +219,7 @@ This op is used for unit test, not a real op. ...@@ -219,7 +219,7 @@ This op is used for unit test, not a real op.
test_str = op.attrs.add() test_str = op.attrs.add()
test_str.name = "str_attr" test_str.name = "str_attr"
test_str.type = attr_type_pb2.STRING test_str.type = attribute_pb2.STRING
test_str.comment = "A string attribute for test op" test_str.comment = "A string attribute for test op"
actual = creation.get_docstring_from_op_proto(op) actual = creation.get_docstring_from_op_proto(op)
......
import paddle.v2.framework.proto.op_proto_pb2 import paddle.v2.framework.proto.op_proto_pb2 as op_proto_lib
import paddle.v2.framework.proto.attr_type_pb2 import paddle.v2.framework.proto.attribute_pb2 as attr_type_lib
import unittest import unittest
class TestFrameworkProto(unittest.TestCase): class TestFrameworkProto(unittest.TestCase):
def test_all(self): def test_all(self):
op_proto_lib = paddle.v2.framework.proto.op_proto_pb2
attr_type_lib = paddle.v2.framework.proto.attr_type_pb2
op_proto = op_proto_lib.OpProto() op_proto = op_proto_lib.OpProto()
ipt0 = op_proto.inputs.add() ipt0 = op_proto.inputs.add()
ipt0.name = "a" ipt0.name = "a"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册