提交 b4ebb3c8 编写于 作者: Y Yu Yang

Change attr_type_pb2 to attribute_pb2

Make ci pass
上级 b8ff8275
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2
import cStringIO
......@@ -57,7 +57,7 @@ class OpDescCreationMethod(object):
op_desc.attrs.extend([out_format])
if len(tmp_index) != 0:
tmp_index_attr = op_desc.attrs.add()
tmp_index_attr.type = attr_type_pb2.INTS
tmp_index_attr.type = attribute_pb2.INTS
tmp_index_attr.name = "temporary_index"
tmp_index_attr.ints.extend(tmp_index)
......@@ -73,17 +73,17 @@ class OpDescCreationMethod(object):
new_attr = op_desc.attrs.add()
new_attr.name = attr.name
new_attr.type = attr.type
if attr.type == attr_type_pb2.INT:
if attr.type == attribute_pb2.INT:
new_attr.i = user_defined_attr
elif attr.type == attr_type_pb2.FLOAT:
elif attr.type == attribute_pb2.FLOAT:
new_attr.f = user_defined_attr
elif attr.type == attr_type_pb2.STRING:
elif attr.type == attribute_pb2.STRING:
new_attr.s = user_defined_attr
elif attr.type == attr_type_pb2.INTS:
elif attr.type == attribute_pb2.INTS:
new_attr.ints.extend(user_defined_attr)
elif attr.type == attr_type_pb2.FLOATS:
elif attr.type == attribute_pb2.FLOATS:
new_attr.floats.extend(user_defined_attr)
elif attr.type == attr_type_pb2.STRINGS:
elif attr.type == attribute_pb2.STRINGS:
new_attr.strings.extend(user_defined_attr)
else:
raise NotImplementedError("Not support attribute type " +
......@@ -109,7 +109,7 @@ class OpDescCreationMethod(object):
retv = []
if multiple:
var_format = op_desc_pb2.AttrDesc()
var_format.type = attr_type_pb2.INTS
var_format.type = attribute_pb2.INTS
var_format.name = "%s_format" % in_out
var_format.ints.append(0)
......@@ -185,17 +185,17 @@ def get_docstring_from_op_proto(op_proto):
for attr in op_proto.attrs:
attr_type = None
if attr.type == attr_type_pb2.INT:
if attr.type == attribute_pb2.INT:
attr_type = "int"
elif attr.type == attr_type_pb2.FLOAT:
elif attr.type == attribute_pb2.FLOAT:
attr_type = "float"
elif attr.type == attr_type_pb2.STRING:
elif attr.type == attribute_pb2.STRING:
attr_type = "basestr"
elif attr.type == attr_type_pb2.INTS:
elif attr.type == attribute_pb2.INTS:
attr_type = "list of int"
elif attr.type == attr_type_pb2.FLOATS:
elif attr.type == attribute_pb2.FLOATS:
attr_type = "list of float"
elif attr.type == attr_type_pb2.STRINGS:
elif attr.type == attribute_pb2.STRINGS:
attr_type = "list of basestr"
if attr_type is None:
......
......@@ -3,7 +3,7 @@ import paddle.v2.framework.create_op_creation_methods as creation
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
import paddle.v2.framework.proto.attribute_pb2 as attribute_pb2
class TestGetAllProtos(unittest.TestCase):
......@@ -76,7 +76,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
expected1.type = 'fc'
attr = expected1.attrs.add()
attr.name = 'input_format'
attr.type = attr_type_pb2.INTS
attr.type = attribute_pb2.INTS
attr.ints.extend([0, 1, 2, 3])
self.assertEqual(expected1, generated1)
......@@ -88,7 +88,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
expected2.type = 'fc'
attr = expected2.attrs.add()
attr.name = 'input_format'
attr.type = attr_type_pb2.INTS
attr.type = attribute_pb2.INTS
attr.ints.extend([0, 3, 6, 7])
self.assertEqual(expected2, generated2)
......@@ -105,12 +105,12 @@ class TestOpDescCreationMethod(unittest.TestCase):
attr.comment = ""
attr.type = type
__add_attr__("int_attr", attr_type_pb2.INT)
__add_attr__("float_attr", attr_type_pb2.FLOAT)
__add_attr__("string_attr", attr_type_pb2.STRING)
__add_attr__("ints_attr", attr_type_pb2.INTS)
__add_attr__("floats_attr", attr_type_pb2.FLOATS)
__add_attr__("strings_attr", attr_type_pb2.STRINGS)
__add_attr__("int_attr", attribute_pb2.INT)
__add_attr__("float_attr", attribute_pb2.FLOAT)
__add_attr__("string_attr", attribute_pb2.STRING)
__add_attr__("ints_attr", attribute_pb2.INTS)
__add_attr__("floats_attr", attribute_pb2.FLOATS)
__add_attr__("strings_attr", attribute_pb2.STRINGS)
op.comment = ""
self.assertTrue(op.IsInitialized())
......@@ -131,32 +131,32 @@ class TestOpDescCreationMethod(unittest.TestCase):
expected.inputs.extend(['a'])
attr = expected.attrs.add()
attr.name = "int_attr"
attr.type = attr_type_pb2.INT
attr.type = attribute_pb2.INT
attr.i = 10
attr = expected.attrs.add()
attr.name = "float_attr"
attr.type = attr_type_pb2.FLOAT
attr.type = attribute_pb2.FLOAT
attr.f = 3.2
attr = expected.attrs.add()
attr.name = "string_attr"
attr.type = attr_type_pb2.STRING
attr.type = attribute_pb2.STRING
attr.s = "test_str"
attr = expected.attrs.add()
attr.name = "ints_attr"
attr.type = attr_type_pb2.INTS
attr.type = attribute_pb2.INTS
attr.ints.extend([0, 1, 2, 3, 4])
attr = expected.attrs.add()
attr.name = 'floats_attr'
attr.type = attr_type_pb2.FLOATS
attr.type = attribute_pb2.FLOATS
attr.floats.extend([0.2, 3.2, 4.5])
attr = expected.attrs.add()
attr.name = 'strings_attr'
attr.type = attr_type_pb2.STRINGS
attr.type = attribute_pb2.STRINGS
attr.strings.extend(['a', 'b', 'c'])
self.assertEqual(expected, generated)
......@@ -185,7 +185,7 @@ class TestOpDescCreationMethod(unittest.TestCase):
desc.type = "test"
attr = desc.attrs.add()
attr.name = "temporary_index"
attr.type = attr_type_pb2.INTS
attr.type = attribute_pb2.INTS
attr.ints.append(2)
self.assertEqual(generated, desc)
......@@ -219,7 +219,7 @@ This op is used for unit test, not a real op.
test_str = op.attrs.add()
test_str.name = "str_attr"
test_str.type = attr_type_pb2.STRING
test_str.type = attribute_pb2.STRING
test_str.comment = "A string attribute for test op"
actual = creation.get_docstring_from_op_proto(op)
......
import paddle.v2.framework.proto.op_proto_pb2
import paddle.v2.framework.proto.attr_type_pb2
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_lib
import paddle.v2.framework.proto.attribute_pb2 as attr_type_lib
import unittest
class TestFrameworkProto(unittest.TestCase):
def test_all(self):
op_proto_lib = paddle.v2.framework.proto.op_proto_pb2
attr_type_lib = paddle.v2.framework.proto.attr_type_pb2
op_proto = op_proto_lib.OpProto()
ipt0 = op_proto.inputs.add()
ipt0.name = "a"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册