diff --git a/TensorFlow2Paddle/README.md b/TensorFlow2Paddle/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2f3901aefe7a49bba014c47d4eb35cee4f05b2ee --- /dev/null +++ b/TensorFlow2Paddle/README.md @@ -0,0 +1 @@ +Warning: TensorFlow2Paddle is not stable yet diff --git a/TensorFlow2Paddle/framework_pb2.py b/TensorFlow2Paddle/framework_pb2.py new file mode 100644 index 0000000000000000000000000000000000000000..35115d775d19ff28327c65eb78fc50369f5b35a1 --- /dev/null +++ b/TensorFlow2Paddle/framework_pb2.py @@ -0,0 +1,1165 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: framework.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='framework.proto', + package='paddle.framework.proto', + syntax='proto2', + serialized_pb=_b('\n\x0f\x66ramework.proto\x12\x16paddle.framework.proto\"\x1d\n\x07Version\x12\x12\n\x07version\x18\x01 \x01(\x03:\x01\x30\"\xec\x03\n\x06OpDesc\x12\x0c\n\x04type\x18\x03 \x02(\t\x12\x32\n\x06inputs\x18\x01 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x33\n\x07outputs\x18\x02 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x32\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32#.paddle.framework.proto.OpDesc.Attr\x12\x18\n\tis_target\x18\x05 \x01(\x08:\x05\x66\x61lse\x1a\xef\x01\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\t\n\x01i\x18\x03 \x01(\x05\x12\t\n\x01\x66\x18\x04 \x01(\x02\x12\t\n\x01s\x18\x05 \x01(\t\x12\x0c\n\x04ints\x18\x06 \x03(\x05\x12\x0e\n\x06\x66loats\x18\x07 \x03(\x02\x12\x0f\n\x07strings\x18\x08 \x03(\t\x12\t\n\x01\x62\x18\n \x01(\x08\x12\r\n\x05\x62ools\x18\x0b \x03(\x08\x12\x11\n\tblock_idx\x18\x0c \x01(\x05\x12\t\n\x01l\x18\r \x01(\x03\x12\x12\n\nblocks_idx\x18\x0e \x03(\x05\x12\r\n\x05longs\x18\x0f \x03(\x03\x1a+\n\x03Var\x12\x11\n\tparameter\x18\x01 \x02(\t\x12\x11\n\targuments\x18\x02 \x03(\t\"\xb3\x03\n\x07OpProto\x12\x0c\n\x04type\x18\x01 \x02(\t\x12\x33\n\x06inputs\x18\x02 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x34\n\x07outputs\x18\x03 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x33\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32$.paddle.framework.proto.OpProto.Attr\x12\x0f\n\x07\x63omment\x18\x05 \x02(\t\x1ax\n\x03Var\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0f\n\x07\x63omment\x18\x02 \x02(\t\x12\x19\n\nduplicable\x18\x03 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0cintermediate\x18\x04 \x01(\x08:\x05\x66\x61lse\x12\x1a\n\x0b\x64ispensable\x18\x05 \x01(\x08:\x05\x66\x61lse\x1ao\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\x0f\n\x07\x63omment\x18\x03 \x02(\t\x12\x18\n\tgenerated\x18\x04 \x01(\x08:\x05\x66\x61lse\"\xda\x08\n\x07VarType\x12\x32\n\x04type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x41\n\rselected_rows\x18\x02 \x01(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x41\n\nlod_tensor\x18\x03 \x01(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x12H\n\x0ctensor_array\x18\x04 \x01(\x0b\x32\x32.paddle.framework.proto.VarType.LoDTensorArrayDesc\x12:\n\x06reader\x18\x05 \x01(\x0b\x32*.paddle.framework.proto.VarType.ReaderDesc\x12\x34\n\x05tuple\x18\x07 \x01(\x0b\x32%.paddle.framework.proto.VarType.Tuple\x1aS\n\nTensorDesc\x12\x37\n\tdata_type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x0c\n\x04\x64ims\x18\x02 \x03(\x03\x1a\x61\n\rLoDTensorDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1a\x66\n\x12LoDTensorArrayDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1aO\n\nReaderDesc\x12\x41\n\nlod_tensor\x18\x01 \x03(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x1a\x43\n\x05Tuple\x12:\n\x0c\x65lement_type\x18\x01 \x03(\x0e\x32$.paddle.framework.proto.VarType.Type\"\xa2\x02\n\x04Type\x12\x08\n\x04\x42OOL\x10\x00\x12\t\n\x05INT16\x10\x01\x12\t\n\x05INT32\x10\x02\x12\t\n\x05INT64\x10\x03\x12\x08\n\x04\x46P16\x10\x04\x12\x08\n\x04\x46P32\x10\x05\x12\x08\n\x04\x46P64\x10\x06\x12\n\n\x06SIZE_T\x10\x13\x12\t\n\x05UINT8\x10\x14\x12\x08\n\x04INT8\x10\x15\x12\x0e\n\nLOD_TENSOR\x10\x07\x12\x11\n\rSELECTED_ROWS\x10\x08\x12\x12\n\x0e\x46\x45\x45\x44_MINIBATCH\x10\t\x12\x0e\n\nFETCH_LIST\x10\n\x12\x0f\n\x0bSTEP_SCOPES\x10\x0b\x12\x12\n\x0eLOD_RANK_TABLE\x10\x0c\x12\x14\n\x10LOD_TENSOR_ARRAY\x10\r\x12\x0e\n\nPLACE_LIST\x10\x0e\x12\n\n\x06READER\x10\x0f\x12\x07\n\x03RAW\x10\x11\x12\t\n\x05TUPLE\x10\x12\"b\n\x07VarDesc\x12\x0c\n\x04name\x18\x01 \x02(\t\x12-\n\x04type\x18\x02 \x02(\x0b\x32\x1f.paddle.framework.proto.VarType\x12\x1a\n\x0bpersistable\x18\x03 \x01(\x08:\x05\x66\x61lse\"\xa7\x01\n\tBlockDesc\x12\x0b\n\x03idx\x18\x01 \x02(\x05\x12\x12\n\nparent_idx\x18\x02 \x02(\x05\x12-\n\x04vars\x18\x03 \x03(\x0b\x32\x1f.paddle.framework.proto.VarDesc\x12+\n\x03ops\x18\x04 \x03(\x0b\x32\x1e.paddle.framework.proto.OpDesc\x12\x1d\n\x11\x66orward_block_idx\x18\x05 \x01(\x05:\x02-1\"r\n\x0bProgramDesc\x12\x31\n\x06\x62locks\x18\x01 \x03(\x0b\x32!.paddle.framework.proto.BlockDesc\x12\x30\n\x07version\x18\x02 \x01(\x0b\x32\x1f.paddle.framework.proto.Version*\x94\x01\n\x08\x41ttrType\x12\x07\n\x03INT\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x08\n\x04INTS\x10\x03\x12\n\n\x06\x46LOATS\x10\x04\x12\x0b\n\x07STRINGS\x10\x05\x12\x0b\n\x07\x42OOLEAN\x10\x06\x12\x0c\n\x08\x42OOLEANS\x10\x07\x12\t\n\x05\x42LOCK\x10\x08\x12\x08\n\x04LONG\x10\t\x12\n\n\x06\x42LOCKS\x10\n\x12\t\n\x05LONGS\x10\x0b\x42\x02H\x03') +) +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +_ATTRTYPE = _descriptor.EnumDescriptor( + name='AttrType', + full_name='paddle.framework.proto.AttrType', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='INT', index=0, number=0, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='FLOAT', index=1, number=1, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='STRING', index=2, number=2, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='INTS', index=3, number=3, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='FLOATS', index=4, number=4, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='STRINGS', index=5, number=5, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='BOOLEAN', index=6, number=6, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='BOOLEANS', index=7, number=7, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='BLOCK', index=8, number=8, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='LONG', index=9, number=9, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='BLOCKS', index=10, number=10, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='LONGS', index=11, number=11, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=2511, + serialized_end=2659, +) +_sym_db.RegisterEnumDescriptor(_ATTRTYPE) + +AttrType = enum_type_wrapper.EnumTypeWrapper(_ATTRTYPE) +INT = 0 +FLOAT = 1 +STRING = 2 +INTS = 3 +FLOATS = 4 +STRINGS = 5 +BOOLEAN = 6 +BOOLEANS = 7 +BLOCK = 8 +LONG = 9 +BLOCKS = 10 +LONGS = 11 + + +_VARTYPE_TYPE = _descriptor.EnumDescriptor( + name='Type', + full_name='paddle.framework.proto.VarType.Type', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='BOOL', index=0, number=0, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='INT16', index=1, number=1, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='INT32', index=2, number=2, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='INT64', index=3, number=3, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='FP16', index=4, number=4, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='FP32', index=5, number=5, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='FP64', index=6, number=6, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='SIZE_T', index=7, number=19, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='UINT8', index=8, number=20, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='INT8', index=9, number=21, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='LOD_TENSOR', index=10, number=7, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='SELECTED_ROWS', index=11, number=8, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='FEED_MINIBATCH', index=12, number=9, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='FETCH_LIST', index=13, number=10, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='STEP_SCOPES', index=14, number=11, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='LOD_RANK_TABLE', index=15, number=12, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='LOD_TENSOR_ARRAY', index=16, number=13, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='PLACE_LIST', index=17, number=14, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='READER', index=18, number=15, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='RAW', index=19, number=17, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='TUPLE', index=20, number=18, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=1832, + serialized_end=2122, +) +_sym_db.RegisterEnumDescriptor(_VARTYPE_TYPE) + + +_VERSION = _descriptor.Descriptor( + name='Version', + full_name='paddle.framework.proto.Version', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='version', full_name='paddle.framework.proto.Version.version', index=0, + number=1, type=3, cpp_type=2, label=1, + has_default_value=True, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=43, + serialized_end=72, +) + + +_OPDESC_ATTR = _descriptor.Descriptor( + name='Attr', + full_name='paddle.framework.proto.OpDesc.Attr', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='paddle.framework.proto.OpDesc.Attr.name', index=0, + number=1, type=9, cpp_type=9, label=2, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='type', full_name='paddle.framework.proto.OpDesc.Attr.type', index=1, + number=2, type=14, cpp_type=8, label=2, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='i', full_name='paddle.framework.proto.OpDesc.Attr.i', index=2, + number=3, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='f', full_name='paddle.framework.proto.OpDesc.Attr.f', index=3, + number=4, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='s', full_name='paddle.framework.proto.OpDesc.Attr.s', index=4, + number=5, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='ints', full_name='paddle.framework.proto.OpDesc.Attr.ints', index=5, + number=6, type=5, cpp_type=1, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='floats', full_name='paddle.framework.proto.OpDesc.Attr.floats', index=6, + number=7, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='strings', full_name='paddle.framework.proto.OpDesc.Attr.strings', index=7, + number=8, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='b', full_name='paddle.framework.proto.OpDesc.Attr.b', index=8, + number=10, type=8, cpp_type=7, label=1, + has_default_value=False, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='bools', full_name='paddle.framework.proto.OpDesc.Attr.bools', index=9, + number=11, type=8, cpp_type=7, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='block_idx', full_name='paddle.framework.proto.OpDesc.Attr.block_idx', index=10, + number=12, type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='l', full_name='paddle.framework.proto.OpDesc.Attr.l', index=11, + number=13, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='blocks_idx', full_name='paddle.framework.proto.OpDesc.Attr.blocks_idx', index=12, + number=14, type=5, cpp_type=1, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='longs', full_name='paddle.framework.proto.OpDesc.Attr.longs', index=13, + number=15, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=283, + serialized_end=522, +) + +_OPDESC_VAR = _descriptor.Descriptor( + name='Var', + full_name='paddle.framework.proto.OpDesc.Var', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='parameter', full_name='paddle.framework.proto.OpDesc.Var.parameter', index=0, + number=1, type=9, cpp_type=9, label=2, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='arguments', full_name='paddle.framework.proto.OpDesc.Var.arguments', index=1, + number=2, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=524, + serialized_end=567, +) + +_OPDESC = _descriptor.Descriptor( + name='OpDesc', + full_name='paddle.framework.proto.OpDesc', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='type', full_name='paddle.framework.proto.OpDesc.type', index=0, + number=3, type=9, cpp_type=9, label=2, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='inputs', full_name='paddle.framework.proto.OpDesc.inputs', index=1, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='outputs', full_name='paddle.framework.proto.OpDesc.outputs', index=2, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='attrs', full_name='paddle.framework.proto.OpDesc.attrs', index=3, + number=4, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='is_target', full_name='paddle.framework.proto.OpDesc.is_target', index=4, + number=5, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[_OPDESC_ATTR, _OPDESC_VAR, ], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=75, + serialized_end=567, +) + + +_OPPROTO_VAR = _descriptor.Descriptor( + name='Var', + full_name='paddle.framework.proto.OpProto.Var', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='paddle.framework.proto.OpProto.Var.name', index=0, + number=1, type=9, cpp_type=9, label=2, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='comment', full_name='paddle.framework.proto.OpProto.Var.comment', index=1, + number=2, type=9, cpp_type=9, label=2, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='duplicable', full_name='paddle.framework.proto.OpProto.Var.duplicable', index=2, + number=3, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='intermediate', full_name='paddle.framework.proto.OpProto.Var.intermediate', index=3, + number=4, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='dispensable', full_name='paddle.framework.proto.OpProto.Var.dispensable', index=4, + number=5, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=772, + serialized_end=892, +) + +_OPPROTO_ATTR = _descriptor.Descriptor( + name='Attr', + full_name='paddle.framework.proto.OpProto.Attr', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='paddle.framework.proto.OpProto.Attr.name', index=0, + number=1, type=9, cpp_type=9, label=2, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='type', full_name='paddle.framework.proto.OpProto.Attr.type', index=1, + number=2, type=14, cpp_type=8, label=2, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='comment', full_name='paddle.framework.proto.OpProto.Attr.comment', index=2, + number=3, type=9, cpp_type=9, label=2, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='generated', full_name='paddle.framework.proto.OpProto.Attr.generated', index=3, + number=4, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=894, + serialized_end=1005, +) + +_OPPROTO = _descriptor.Descriptor( + name='OpProto', + full_name='paddle.framework.proto.OpProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='type', full_name='paddle.framework.proto.OpProto.type', index=0, + number=1, type=9, cpp_type=9, label=2, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='inputs', full_name='paddle.framework.proto.OpProto.inputs', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='outputs', full_name='paddle.framework.proto.OpProto.outputs', index=2, + number=3, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='attrs', full_name='paddle.framework.proto.OpProto.attrs', index=3, + number=4, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='comment', full_name='paddle.framework.proto.OpProto.comment', index=4, + number=5, type=9, cpp_type=9, label=2, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[_OPPROTO_VAR, _OPPROTO_ATTR, ], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=570, + serialized_end=1005, +) + + +_VARTYPE_TENSORDESC = _descriptor.Descriptor( + name='TensorDesc', + full_name='paddle.framework.proto.VarType.TensorDesc', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='data_type', full_name='paddle.framework.proto.VarType.TensorDesc.data_type', index=0, + number=1, type=14, cpp_type=8, label=2, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='dims', full_name='paddle.framework.proto.VarType.TensorDesc.dims', index=1, + number=2, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1393, + serialized_end=1476, +) + +_VARTYPE_LODTENSORDESC = _descriptor.Descriptor( + name='LoDTensorDesc', + full_name='paddle.framework.proto.VarType.LoDTensorDesc', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='tensor', full_name='paddle.framework.proto.VarType.LoDTensorDesc.tensor', index=0, + number=1, type=11, cpp_type=10, label=2, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='lod_level', full_name='paddle.framework.proto.VarType.LoDTensorDesc.lod_level', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1478, + serialized_end=1575, +) + +_VARTYPE_LODTENSORARRAYDESC = _descriptor.Descriptor( + name='LoDTensorArrayDesc', + full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='tensor', full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc.tensor', index=0, + number=1, type=11, cpp_type=10, label=2, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='lod_level', full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc.lod_level', index=1, + number=2, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1577, + serialized_end=1679, +) + +_VARTYPE_READERDESC = _descriptor.Descriptor( + name='ReaderDesc', + full_name='paddle.framework.proto.VarType.ReaderDesc', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='lod_tensor', full_name='paddle.framework.proto.VarType.ReaderDesc.lod_tensor', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1681, + serialized_end=1760, +) + +_VARTYPE_TUPLE = _descriptor.Descriptor( + name='Tuple', + full_name='paddle.framework.proto.VarType.Tuple', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='element_type', full_name='paddle.framework.proto.VarType.Tuple.element_type', index=0, + number=1, type=14, cpp_type=8, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1762, + serialized_end=1829, +) + +_VARTYPE = _descriptor.Descriptor( + name='VarType', + full_name='paddle.framework.proto.VarType', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='type', full_name='paddle.framework.proto.VarType.type', index=0, + number=1, type=14, cpp_type=8, label=2, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='selected_rows', full_name='paddle.framework.proto.VarType.selected_rows', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='lod_tensor', full_name='paddle.framework.proto.VarType.lod_tensor', index=2, + number=3, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='tensor_array', full_name='paddle.framework.proto.VarType.tensor_array', index=3, + number=4, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='reader', full_name='paddle.framework.proto.VarType.reader', index=4, + number=5, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='tuple', full_name='paddle.framework.proto.VarType.tuple', index=5, + number=7, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[_VARTYPE_TENSORDESC, _VARTYPE_LODTENSORDESC, _VARTYPE_LODTENSORARRAYDESC, _VARTYPE_READERDESC, _VARTYPE_TUPLE, ], + enum_types=[ + _VARTYPE_TYPE, + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=1008, + serialized_end=2122, +) + + +_VARDESC = _descriptor.Descriptor( + name='VarDesc', + full_name='paddle.framework.proto.VarDesc', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='paddle.framework.proto.VarDesc.name', index=0, + number=1, type=9, cpp_type=9, label=2, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='type', full_name='paddle.framework.proto.VarDesc.type', index=1, + number=2, type=11, cpp_type=10, label=2, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='persistable', full_name='paddle.framework.proto.VarDesc.persistable', index=2, + number=3, type=8, cpp_type=7, label=1, + has_default_value=True, default_value=False, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2124, + serialized_end=2222, +) + + +_BLOCKDESC = _descriptor.Descriptor( + name='BlockDesc', + full_name='paddle.framework.proto.BlockDesc', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='idx', full_name='paddle.framework.proto.BlockDesc.idx', index=0, + number=1, type=5, cpp_type=1, label=2, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='parent_idx', full_name='paddle.framework.proto.BlockDesc.parent_idx', index=1, + number=2, type=5, cpp_type=1, label=2, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='vars', full_name='paddle.framework.proto.BlockDesc.vars', index=2, + number=3, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='ops', full_name='paddle.framework.proto.BlockDesc.ops', index=3, + number=4, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='forward_block_idx', full_name='paddle.framework.proto.BlockDesc.forward_block_idx', index=4, + number=5, type=5, cpp_type=1, label=1, + has_default_value=True, default_value=-1, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2225, + serialized_end=2392, +) + + +_PROGRAMDESC = _descriptor.Descriptor( + name='ProgramDesc', + full_name='paddle.framework.proto.ProgramDesc', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='blocks', full_name='paddle.framework.proto.ProgramDesc.blocks', index=0, + number=1, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='version', full_name='paddle.framework.proto.ProgramDesc.version', index=1, + number=2, type=11, cpp_type=10, label=1, + has_default_value=False, default_value=None, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=2394, + serialized_end=2508, +) + +_OPDESC_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE +_OPDESC_ATTR.containing_type = _OPDESC +_OPDESC_VAR.containing_type = _OPDESC +_OPDESC.fields_by_name['inputs'].message_type = _OPDESC_VAR +_OPDESC.fields_by_name['outputs'].message_type = _OPDESC_VAR +_OPDESC.fields_by_name['attrs'].message_type = _OPDESC_ATTR +_OPPROTO_VAR.containing_type = _OPPROTO +_OPPROTO_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE +_OPPROTO_ATTR.containing_type = _OPPROTO +_OPPROTO.fields_by_name['inputs'].message_type = _OPPROTO_VAR +_OPPROTO.fields_by_name['outputs'].message_type = _OPPROTO_VAR +_OPPROTO.fields_by_name['attrs'].message_type = _OPPROTO_ATTR +_VARTYPE_TENSORDESC.fields_by_name['data_type'].enum_type = _VARTYPE_TYPE +_VARTYPE_TENSORDESC.containing_type = _VARTYPE +_VARTYPE_LODTENSORDESC.fields_by_name['tensor'].message_type = _VARTYPE_TENSORDESC +_VARTYPE_LODTENSORDESC.containing_type = _VARTYPE +_VARTYPE_LODTENSORARRAYDESC.fields_by_name['tensor'].message_type = _VARTYPE_TENSORDESC +_VARTYPE_LODTENSORARRAYDESC.containing_type = _VARTYPE +_VARTYPE_READERDESC.fields_by_name['lod_tensor'].message_type = _VARTYPE_LODTENSORDESC +_VARTYPE_READERDESC.containing_type = _VARTYPE +_VARTYPE_TUPLE.fields_by_name['element_type'].enum_type = _VARTYPE_TYPE +_VARTYPE_TUPLE.containing_type = _VARTYPE +_VARTYPE.fields_by_name['type'].enum_type = _VARTYPE_TYPE +_VARTYPE.fields_by_name['selected_rows'].message_type = _VARTYPE_TENSORDESC +_VARTYPE.fields_by_name['lod_tensor'].message_type = _VARTYPE_LODTENSORDESC +_VARTYPE.fields_by_name['tensor_array'].message_type = _VARTYPE_LODTENSORARRAYDESC +_VARTYPE.fields_by_name['reader'].message_type = _VARTYPE_READERDESC +_VARTYPE.fields_by_name['tuple'].message_type = _VARTYPE_TUPLE +_VARTYPE_TYPE.containing_type = _VARTYPE +_VARDESC.fields_by_name['type'].message_type = _VARTYPE +_BLOCKDESC.fields_by_name['vars'].message_type = _VARDESC +_BLOCKDESC.fields_by_name['ops'].message_type = _OPDESC +_PROGRAMDESC.fields_by_name['blocks'].message_type = _BLOCKDESC +_PROGRAMDESC.fields_by_name['version'].message_type = _VERSION +DESCRIPTOR.message_types_by_name['Version'] = _VERSION +DESCRIPTOR.message_types_by_name['OpDesc'] = _OPDESC +DESCRIPTOR.message_types_by_name['OpProto'] = _OPPROTO +DESCRIPTOR.message_types_by_name['VarType'] = _VARTYPE +DESCRIPTOR.message_types_by_name['VarDesc'] = _VARDESC +DESCRIPTOR.message_types_by_name['BlockDesc'] = _BLOCKDESC +DESCRIPTOR.message_types_by_name['ProgramDesc'] = _PROGRAMDESC +DESCRIPTOR.enum_types_by_name['AttrType'] = _ATTRTYPE + +Version = _reflection.GeneratedProtocolMessageType('Version', (_message.Message,), dict( + DESCRIPTOR = _VERSION, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.Version) + )) +_sym_db.RegisterMessage(Version) + +OpDesc = _reflection.GeneratedProtocolMessageType('OpDesc', (_message.Message,), dict( + + Attr = _reflection.GeneratedProtocolMessageType('Attr', (_message.Message,), dict( + DESCRIPTOR = _OPDESC_ATTR, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Attr) + )) + , + + Var = _reflection.GeneratedProtocolMessageType('Var', (_message.Message,), dict( + DESCRIPTOR = _OPDESC_VAR, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Var) + )) + , + DESCRIPTOR = _OPDESC, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc) + )) +_sym_db.RegisterMessage(OpDesc) +_sym_db.RegisterMessage(OpDesc.Attr) +_sym_db.RegisterMessage(OpDesc.Var) + +OpProto = _reflection.GeneratedProtocolMessageType('OpProto', (_message.Message,), dict( + + Var = _reflection.GeneratedProtocolMessageType('Var', (_message.Message,), dict( + DESCRIPTOR = _OPPROTO_VAR, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Var) + )) + , + + Attr = _reflection.GeneratedProtocolMessageType('Attr', (_message.Message,), dict( + DESCRIPTOR = _OPPROTO_ATTR, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Attr) + )) + , + DESCRIPTOR = _OPPROTO, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto) + )) +_sym_db.RegisterMessage(OpProto) +_sym_db.RegisterMessage(OpProto.Var) +_sym_db.RegisterMessage(OpProto.Attr) + +VarType = _reflection.GeneratedProtocolMessageType('VarType', (_message.Message,), dict( + + TensorDesc = _reflection.GeneratedProtocolMessageType('TensorDesc', (_message.Message,), dict( + DESCRIPTOR = _VARTYPE_TENSORDESC, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.TensorDesc) + )) + , + + LoDTensorDesc = _reflection.GeneratedProtocolMessageType('LoDTensorDesc', (_message.Message,), dict( + DESCRIPTOR = _VARTYPE_LODTENSORDESC, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorDesc) + )) + , + + LoDTensorArrayDesc = _reflection.GeneratedProtocolMessageType('LoDTensorArrayDesc', (_message.Message,), dict( + DESCRIPTOR = _VARTYPE_LODTENSORARRAYDESC, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorArrayDesc) + )) + , + + ReaderDesc = _reflection.GeneratedProtocolMessageType('ReaderDesc', (_message.Message,), dict( + DESCRIPTOR = _VARTYPE_READERDESC, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.ReaderDesc) + )) + , + + Tuple = _reflection.GeneratedProtocolMessageType('Tuple', (_message.Message,), dict( + DESCRIPTOR = _VARTYPE_TUPLE, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.Tuple) + )) + , + DESCRIPTOR = _VARTYPE, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType) + )) +_sym_db.RegisterMessage(VarType) +_sym_db.RegisterMessage(VarType.TensorDesc) +_sym_db.RegisterMessage(VarType.LoDTensorDesc) +_sym_db.RegisterMessage(VarType.LoDTensorArrayDesc) +_sym_db.RegisterMessage(VarType.ReaderDesc) +_sym_db.RegisterMessage(VarType.Tuple) + +VarDesc = _reflection.GeneratedProtocolMessageType('VarDesc', (_message.Message,), dict( + DESCRIPTOR = _VARDESC, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc) + )) +_sym_db.RegisterMessage(VarDesc) + +BlockDesc = _reflection.GeneratedProtocolMessageType('BlockDesc', (_message.Message,), dict( + DESCRIPTOR = _BLOCKDESC, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.BlockDesc) + )) +_sym_db.RegisterMessage(BlockDesc) + +ProgramDesc = _reflection.GeneratedProtocolMessageType('ProgramDesc', (_message.Message,), dict( + DESCRIPTOR = _PROGRAMDESC, + __module__ = 'framework_pb2' + # @@protoc_insertion_point(class_scope:paddle.framework.proto.ProgramDesc) + )) +_sym_db.RegisterMessage(ProgramDesc) + + +DESCRIPTOR.has_options = True +DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('H\003')) +# @@protoc_insertion_point(module_scope) diff --git a/TensorFlow2Paddle/graph.py b/TensorFlow2Paddle/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..b47737124634a4c9d2e986ab462525d36dd3ca03 --- /dev/null +++ b/TensorFlow2Paddle/graph.py @@ -0,0 +1,82 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from name_generator import NameGenerator + + +class GraphNode(object): + def __init__(self, layer): + self.in_edges = list() + self.out_edges = list() + self.layer = layer + self.ref_name = None + + +class Graph(object): + def __init__(self, model): + self.node_map = dict() + self.input_nodes = list() + self.output_nodes = list() + self.topological_sort = list() + self.model = model + self.name_generator = NameGenerator() + + def build(self): + self._make_input_nodes() + self._make_output_nodes() + self._get_topological_sort() + self._gen_newname_for_nodes() + + def _make_input_nodes(self): + for name, node in self.node_map.items(): + node.left_in_edges = len(node.in_edges) + if len(node.in_edges) == 0: + self.input_nodes.append(name) + + def _make_output_nodes(self): + for name, node in self.node_map.items(): + if len(node.out_edges) == 0: + self.output_nodes.append(name) + + def _get_topological_sort(self): + self.topological_sort = self.input_nodes[:] + idx = 0 + while idx < len(self.topological_sort): + current_node = self.node_map[self.topological_sort[idx]] + for next_node in current_node.out_edges: + next_node_info = self.node_map[next_node] + next_node_info.left_in_edges -= 1 + if next_node_info.left_in_edges == 0: + self.topological_sort.append(next_node) + idx += 1 + + def _gen_newname_for_nodes(self): + for node_name in self.topological_sort: + node = self.node_map[node_name] + ref_name = self.name_generator.get_name(node) + self.node_map[node.layer.name].ref_name = ref_name + + def get_node(self, name): + if name not in self.node_map: + raise Exception("Graph doesn't have node [%s]." % name) + else: + return self.node_map[name] + + def _make_connection(self, src, dst): + if src == dst or src not in self.node_map or dst not in self.node_map: + raise Exception('Warning: Node not exist or there is a self-loop') + if src not in self.node_map[dst].in_edges: + self.node_map[dst].in_edges.append(src) + if dst not in self.node_map[src].out_edges: + self.node_map[src].out_edges.append(dst) diff --git a/TensorFlow2Paddle/name_generator.py b/TensorFlow2Paddle/name_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..b521d20559c3a8b88a59bfcad39ddd06d41db070 --- /dev/null +++ b/TensorFlow2Paddle/name_generator.py @@ -0,0 +1,46 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class NameGenerator(object): + def __init__(self): + self.param_index = 0 + self.input_index = 0 + self.net_index = 0 + self.const_index = 0 + self.names = dict() + + def get_name(self, node): + ref_name = None + op_name = node.type + + if node.layer.name in self.names: + return self.names[node.layer.name] + + if op_name == "variablev2": + ref_name = "param_" + str(self.param_index) + self.param_index += 1 + elif op_name == "placeholder": + ref_name = "input_" + str(self.input_index) + self.input_index += 1 + elif op_name == "const": + ref_name = "const_" + str(self.const_index) + self.const_index += 1 + elif op_name.lower() == "identity": + ref_name = self.names[node.layer.input[0]] + else: + ref_name = "net_" + str(self.net_index) + self.net_index += 1 + self.names[node.layer.name] = ref_name + return ref_name diff --git a/TensorFlow2Paddle/paddle_emitter.py b/TensorFlow2Paddle/paddle_emitter.py new file mode 100644 index 0000000000000000000000000000000000000000..429826f64522b6e2807f92ff50d6c1571e593a4f --- /dev/null +++ b/TensorFlow2Paddle/paddle_emitter.py @@ -0,0 +1,548 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import tensor_util +from six import string_types as _string_types + + +class PaddleEmitter(object): + skip_op = set(['const', 'identity']) + dtype_map = {1: "float32", 3: "int32", 9: "int64"} + + def __init__(self, graph): + self.graph = graph + self.body_code = "" + self.tab = " " * 4 + + @staticmethod + def tensor_shape_to_list(shapes): + if isinstance(shapes, attr_value_pb2.AttrValue): + return [dim.size for dim in shapes.shape.dim] + else: + ret = [] + for shape in shapes: + this_one = [dim.size for dim in shape.dim] + ret.append(this_one) + return ret + + @property + def header_code(self): + code = ["import paddle.fluid as fluid", "", "def KitModel():"] + return code + + def add_codes(self, indent, codes): + if isinstance(codes, _string_types): + codes = codes.strip().split("\n") + for code in codes: + self.body_code += (self.tab * indent) + code + "\n" + + def gen_code(self): + self.add_codes(0, self.header_code) + + for node in self.graph.topological_sort: + current_node = self.graph.get_node(node) + op = current_node.type + if op in self.skip_op: + continue + if hasattr(self, "emit_" + op): + func = getattr(self, "emit_" + op) + codes = func(current_node) + if not isinstance(codes, list): + codes = [codes] + self.graph.get_node(node).codes = codes + else: + raise Exception("Unknow node op: {}".format(op)) + + for node in self.graph.topological_sort: + codes = self.graph.get_node(node).codes + self.add_codes(1, codes) + + outs = [] + for node in self.graph.output_nodes: + outs.append(self.graph.get_node(node).ref_name) + self.add_codes(1, "return {}".format(", ".join(outs))) + + return self.body_code + + def gen_weight(self, weight_dict, dirname): + import struct + import framework_pb2 as framework + import numpy + + for var_name, var in weight_dict.items(): + if var_name not in self.graph.node_map: + continue + shape = var.shape + paddle_var_name = self.graph.get_node(var_name).ref_name + paddle_var = var + dataformat = self.graph.get_node(var_name).dataformat + + filew = open(dirname + '/' + paddle_var_name, 'wb') + filew.write(struct.pack('i', 0)) + filew.write(struct.pack('L', 0)) + filew.write(struct.pack('i', 0)) + tensor_desc = framework.VarType.TensorDesc() + tensor_desc.data_type = framework.VarType.FP32 + + if len(shape) == 4 and dataformat == "NHWC": + paddle_var = numpy.transpose(var, (3, 2, 0, 1)) + shape = paddle_var.shape + + tensor_desc.dims.extend(shape) + desc_size = tensor_desc.ByteSize() + filew.write(struct.pack('i', desc_size)) + filew.write(tensor_desc.SerializeToString()) + + tensor_size = reduce(lambda x, y: x * y, shape) + paddle_var = paddle_var.flatten() + for i in range(0, tensor_size): + filew.write(struct.pack('f', paddle_var[i])) + filew.close() + + def emit_variablev2(self, node): + shape = self.tensor_shape_to_list(node.get_attr('_output_shapes'))[0] + if node.dataformat == 'NHWC' and len(shape) == 4: + shape = [shape[3], shape[2], shape[0], shape[1]] + + dtype = node.get_attr('dtype') + if dtype in self.dtype_map: + dtype = self.dtype_map[dtype] + else: + raise Exception('Unknow dtype : {}'.format(dtype)) + + code = ["# variable[{}]:\t{}".format(node.ref_name, node.name)] + code.append( + '{} = fluid.layers.create_parameter(name=\'{}\', shape={}, dtype=\'{}\')' + .format(node.ref_name, node.ref_name, shape, dtype)) + return code + + def emit_placeholder(self, node): + shape = self.tensor_shape_to_list(node.get_attr('shape'))[0] + + if node.dataformat == 'NHWC' and len(shape) == 4: + shape = [shape[0], shape[3], shape[1], shape[2]] + if shape[0] < 0: + shape = shape[1:] + + dtype = node.get_attr('dtype') + if dtype in self.dtype_map: + dtype = self.dtype_map[dtype] + else: + raise Exception('Unknow dtype : {}'.format(dtype)) + + code = ["# placeholder[{}]:\t{}".format(node.ref_name, node.name)] + code.append( + '{} = fluid.layers.data(name=\'{}\', shape={}, dtype=\'{}\')'. + format(node.ref_name, node.ref_name, shape, dtype)) + return code + + def emit_conv2d(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + kernel = self.graph.get_node(node.layer.input[1]) + + dataformat = node.dataformat + padding_mode = node.get_attr('padding') + strides = node.get_attr('strides')[1:3] + k_shape = self.tensor_shape_to_list( + kernel.get_attr('_output_shapes'))[0] + + kernel_num, channel, height, width = k_shape + if dataformat == "NHWC": + height, width, channle, kernel_num = k_shape + + padding = [0, 0] + if padding_mode == 'SAME': + padding = map(int, [(height - 1) / 2, (width - 1) / 2]) + + code = '{} = fluid.layers.conv2d({}, {}, {}, padding={}, stride={}, param_attr=\'{}\', bias_attr=False)'.format( + node.ref_name, inputs.ref_name, kernel_num, [height, width], + padding, strides, kernel.ref_name) + + return code + + def emit_biasadd(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + bias = self.graph.get_node(node.layer.input[1]) + # TODO more validations + axis = 1 + code = '{} = fluid.layers.elementwise_add({}, {}, axis={})'.format( + node.ref_name, inputs.ref_name, bias.ref_name, axis) + return code + + def emit_relu(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + code = '{} = fluid.layers.relu({})'.format(node.ref_name, + inputs.ref_name) + return code + + def emit_maxpool(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + padding_mode = node.get_attr('padding') + strides = node.get_attr('strides')[1:3] + pool_size = node.get_attr('ksize')[1:3] + padding = [0, 0] + if padding_mode == 'SAME': + pad_0 = (pool_size[0] - 1) / 2 + pad_1 = (pool_size[1] - 1) / 2 + padding = [0, pad_0 * 2, 0, pad_1 * 2] + code = [ + 'pad_net = fluid.layers.pad2d({}, paddings={})'.format( + inputs.ref_name, padding) + ] + code.append( + '{} = fluid.layers.pool2d(pad_net, {}, \'max\', {})'.format( + node.ref_name, pool_size, strides)) + else: + code = '{} = fluid.layers.pool2d({}, {}, \'max\', {})'.format( + node.ref_name, inputs.ref_name, pool_size, strides) + return code + + def emit_pad(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + padding = self.graph.get_node(node.layer.input[1]) + assert padding.type == 'const' + padding = padding.layer.attr['value'].tensor + padding = tensor_util.MakeNdarray(padding) + if node.dataformat == "NHWC" and padding.shape[0] == 4: + padding = padding[[0, 3, 1, 2]] + code = '{} = fluid.layers.pad({}, {})'.format(node.ref_name, + inputs.ref_name, + list(padding.flatten())) + return code + + def emit_fusedbatchnorm(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + gamma = self.graph.get_node(node.layer.input[1]) + beta = self.graph.get_node(node.layer.input[2]) + mv_mean = self.graph.get_node(node.layer.input[3]) + mv_variance = self.graph.get_node(node.layer.input[4]) + + is_training = node.get_attr("is_training") + if is_training: + raise Exception( + "FusedBatchNorm: is_training=True, not support yet, please set is_training=False in your tensorflow code, then dump model again" + ) + epsilon = round(node.get_attr('epsilon'), 6) + + if gamma.type == 'const': + value = gamma.get_attr('value') + shape = value.tensor_shape + assert len(shape.dim) == 1 + shape = shape.dim[0].size + + assert len(value.float_val) == 1 + value = value.float_val[0] + code = "{} = fluid.layers.batch_norm({}, epsilon={}, param_attr=fluid.ParamAttr(\'{}\', fluid.initializer.Constant({})), bias_attr=\'{}\', moving_mean_name=\'{}\', moving_variance_name=\'{}\', is_test=True)".format( + node.ref_name, inputs.ref_name, epsilon, gamma.ref_name, value, + beta.ref_name, mv_mean.ref_name, mv_variance.ref_name) + else: + code = '{} = fluid.layers.batch_norm({}, epsilon={}, param_attr=\'{}\', bias_attr=\'{}\', moving_mean_name=\'{}\', moving_variance_name=\'{}\', is_test=True)'.format( + node.ref_name, inputs.ref_name, epsilon, gamma.ref_name, + beta.ref_name, mv_mean.ref_name, mv_variance.ref_name) + return code + + def emit_assign(self, node): + ref = self.graph.get_node(node.layer.input[0]) + value = self.graph.get_node(node.layer.input[1]) + code = 'fluid.layers.assign(input={}, output={})'.format( + value.ref_name, ref.ref_name) + return code + + def emit_add(self, node): + input1 = self.graph.get_node(node.layer.input[0]) + input2 = self.graph.get_node(node.layer.input[1]) + code = '{} = fluid.layers.elementwise_add({}, {})'.format( + node.ref_name, input1.ref_name, input2.ref_name) + return code + + def emit_mean(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + reduce_idx = self.graph.get_node(node.layer.input[1]) + idxs = reduce_idx.layer.attr['value'].tensor + idxs = tensor_util.MakeNdarray(idxs) + shape = idxs.shape + if len(shape) != 1: + raise Exception('Unexpected situation[mean_op]') + + input_shape = self.tensor_shape_to_list( + inputs.get_attr('_output_shapes'))[0] + if node.dataformat == "NHWC" and len(input_shape) == 4: + for i in range(0, shape[0]): + if idxs[i] == 1: + idxs[i] = 2 + elif idxs[i] == 2: + idxs[i] = 3 + elif idxs[i] == 3: + idxs[i] = 1 + + code = '{} = fluid.layers.reduce_mean({}, {}, keep_dim=True)'.format( + node.ref_name, inputs.ref_name, list(idxs)) + return code + + def emit_squeeze(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + axis = node.get_attr('squeeze_dims') + input_shape = self.tensor_shape_to_list( + inputs.get_attr('_output_shapes'))[0] + + if node.dataformat == "NHWC" and len(input_shape) == 4: + for i in range(0, len(axis)): + if axis[i] == 1: + axis[i] = 2 + elif axis[i] == 2: + axis[i] = 3 + elif axis[i] == 3: + axis[i] = 1 + code = '{} = fluid.layers.squeeze({}, {})'.format( + node.ref_name, inputs.ref_name, axis) + return code + + def emit_const(self, node): + shape = self.tensor_shape_to_list(node.get_attr('_output_shapes'))[0] + dtype = node.get_attr('dtype') + # TODO dtype need more validation + value = node.layer.attr['value'].tensor.int_val[0] + if dtype in self.dtype_map: + dtype = self.dtype_map[dtype] + else: + raise Exception('Unknow dtype : {}'.format(dtype)) + + if node.dataformat == 'NHWC': + raise Exception("Const: NHWC format not support yet") + + code = "{} = fluid.layers.fill_constant({}, \'{}\', {})".format( + node.ref_name, shape, dtype, value) + return code + + def emit_concatv2(self, node): + inputs = node.layer.input + inputs_vars = [] + code = [] + for i in range(0, len(inputs) - 1): + tmp = self.graph.get_node(inputs[i]) + if tmp.type == 'const': + code.append(self.emit_const(tmp)) + inputs_vars.append(tmp.ref_name) + axis = self.graph.get_node( + inputs[-1]).layer.attr['value'].tensor.int_val[0] + + output_shape = self.tensor_shape_to_list( + node.get_attr('_output_shapes'))[0] + if node.dataformat == "NHWC" and len(output_shape) == 4: + if axis == 1: + axis = 2 + elif axis == 2: + axis = 3 + elif axis == 3: + axis = 1 + code.append('{} = fluid.layers.concat([{}], {})'.format( + node.ref_name, ', '.join(inputs_vars), axis)) + return code + + def emit_avgpool(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + padding_mode = node.get_attr('padding') + # TODO need more validation in nlp + strides = node.get_attr('strides')[1:3] + pool_size = node.get_attr('ksize')[1:3] + padding = [0, 0] + if padding_mode == 'SAME': + pad_0 = (pool_size[0] - 1) / 2 + pad_1 = (pool_size[1] - 1) / 2 + padding = [pad_0, pad_1] + code = '{} = fluid.layers.pool2d({}, {}, \'avg\', {}, {})'.format( + node.ref_name, inputs.ref_name, pool_size, strides, padding) + return code + + def emit_sub(self, node): + input1 = self.graph.get_node(node.layer.input[0]) + input2 = self.graph.get_node(node.layer.input[1]) + code = '{} = fluid.layers.elementwise_sub({}, {})'.format( + node.ref_name, input1.ref_name, input2.ref_name) + return code + + def emit_mul(self, node): + input1 = self.graph.get_node(node.layer.input[0]) + input2 = self.graph.get_node(node.layer.input[1]) + code = '{} = fluid.layers.elementwise_mul({}, {})'.format( + node.ref_name, input1.ref_name, input2.ref_name) + return code + + def emit_floor(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + code = '{} = fluid.layers.floor({})'.format(node.ref_name, + inputs.ref_name) + return code + + def emit_realdiv(self, node): + input1 = self.graph.get_node(node.layer.input[0]) + input2 = self.graph.get_node(node.layer.input[1]) + code = '{} = fluid.layers.elementwise_div({})'.format( + node.ref_name, input1.ref_name, input2.ref_name) + return code + + def emit_shape(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + if "num_split" in inputs.layer.attr: + code = '{} = fluid.layers.shape({}[0])'.format( + node.ref_name, inputs.ref_name) + else: + code = '{} = fluid.layers.shape({})'.format( + node.ref_name, inputs.ref_name) + # TODO there's dtype problem of PaddlePaddle's OP[fluid.layers.shape] + # https://github.com/PaddlePaddle/Paddle/issues/15267 + # tensorflow2paddle fix problem temporary + code = [code] + code.append("{} = fluid.layers.cast({}, dtype='int32')".format( + node.ref_name, node.ref_name)) + return code + + def emit_stridedslice(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + begin = self.graph.get_node(node.layer.input[1]) + end = self.graph.get_node(node.layer.input[2]) + strides = self.graph.get_node(node.layer.input[3]) + + begin = list(tensor_util.MakeNdarray(begin.layer.attr['value'].tensor)) + end = list(tensor_util.MakeNdarray(end.layer.attr['value'].tensor)) + strides = list( + tensor_util.MakeNdarray(strides.layer.attr['value'].tensor)) + + if len(begin) != len(strides) or len(end) != len(strides): + raise Exception("length of begin/end/strides must be equl") + + for i in range(0, len(strides)): + if strides[i] != 1: + raise Exception( + "strides must be 1 for all axis, other situation not supported yet" + ) + + code = "{} = fluid.layers.slice({}, axes={}, starts={}, ends={})".format( + node.ref_name, inputs.ref_name, [i for i in range(0, len(begin))], + begin, end) + return code + + def emit_gather(self, node): + embedding = self.graph.get_node(node.layer.input[0]) + idxs = self.graph.get_node(node.layer.input[1]) + + idxs_shape = self.tensor_shape_to_list( + idxs.get_attr('_output_shapes'))[0] + embedding_shape = self.tensor_shape_to_list( + embedding.get_attr('_output_shapes'))[0] + + if len(embedding_shape) != 2: + raise Exception("rank of input[0] must be equal to 2 in Gather OP") + + code = [] + if idxs_shape[-1] != 1: + code.append( + "reshape = fluid.layers.reshape({}, shape=[-1, 1])".format( + idxs.ref_name)) + code.append("gather = fluid.layers.gather({}, reshape)".format( + embedding.ref_name)) + code.append("{} = fluid.layers.reshape(gather,{})".format( + node.ref_name, idxs_shape + [embedding_shape[-1]])) + else: + code = "{} = fluid.layers.gather({}, {})".format( + node.ref_name, embedding.ref_name, idxs.ref_name) + return code + + def emit_transpose(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + perm = self.graph.get_node(node.layer.input[1]) + assert perm.type == "const" + perm = perm.layer.attr['value'].tensor + perm = tensor_util.MakeNdarray(perm) + + # TODO + if node.dataformat == "NHWC" and perm.shape[0] == 4: + raise Exception( + "Unsupported situation for op Transpose, NHWC not supported yet" + ) + + perm = list(perm) + code = "{} = fluid.layers.transpose({}, {})".format( + node.ref_name, inputs.ref_name, perm) + return code + + def emit_reshape(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + shape = self.graph.get_node(node.layer.input[1]) + assert shape.type == "const" + + # TODO + if node.dataformat == "NHWC": + raise Exception( + "Unsupported situation for reshape, NHWC not supported yet") + + shape = shape.layer.attr['value'].tensor + shape = list(tensor_util.MakeNdarray(shape)) + code = "{} = fluid.layers.reshape({}, {})".format( + node.ref_name, inputs.ref_name, shape) + return code + + def emit_split(self, node): + inputs = self.graph.get_node(node.layer.input[1]) + split_dim = self.graph.get_node(node.layer.input[0]) + inputs_shape = self.tensor_shape_to_list( + inputs.get_attr('_output_shapes'))[0] + assert split_dim.type == 'const' and len(inputs_shape) > 1 + axis = split_dim.layer.attr['value'].tensor.int_val[0] + num_split = node.get_attr('num_split') + + code = list() + if inputs_shape[axis] < 0: + tmp_shape = [-1, num_split + ] + inputs_shape[:axis] + inputs_shape[axis + 1:] + code.append("reshape = fluid.layers.reshape({}, {})".format( + inputs.ref_name, tmp_shape)) + code.append( + "split = fluid.layers.split(reshape, {}, 1)".format(num_split)) + code.append( + "{} = [fluid.layers.squeeze(s, [1]) for s in split]".format( + node.ref_name)) + else: + code = "{} = fluid.layers.split({}, {}, {})".format( + inputs.ref_name, num_split, axis) + + return code + + def emit_expanddims(self, node): + inputs = self.graph.get_node(node.layer.input[0]) + dim = self.graph.get_node(node.layer.input[1]) + dim = tensor_util.MakeNdarray(dim.layer.attr['value'].tensor) + inputs_shape = self.tensor_shape_to_list( + inputs.get_attr('_output_shapes'))[0] + + inputs_shape.insert(dim, 1) + code = "{} = fluid.layers.reshape({}, {})".format( + node.ref_name, inputs.ref_name, inputs_shape) + return code + + def emit_fill(self, node): + value = self.graph.get_node(node.layer.input[1]) + value = value.layer.attr['value'].tensor.float_val[0] + dtype = node.layer.attr['T'].type + if dtype in self.dtype_map: + dtype = self.dtype_map[dtype] + else: + raise Exception('Unknow dtype : {}'.format(dtype)) + + output_shape = self.tensor_shape_to_list( + node.get_attr('_output_shapes'))[0] + code = "{} = fluid.layers.create_parameter({}, {}, default_initializer=fluid.initializer.Constant({}))".format( + node.ref_name, output_shape, dtype, value) + return code diff --git a/TensorFlow2Paddle/tensorflow_graph.py b/TensorFlow2Paddle/tensorflow_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..236267c27d91aa78fff7dd78f1338956ba71a26d --- /dev/null +++ b/TensorFlow2Paddle/tensorflow_graph.py @@ -0,0 +1,83 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from graph import GraphNode, Graph +from tensorflow.core.framework import attr_value_pb2 + + +class TensorflowGraphNode(GraphNode): + def __init__(self, layer): + super(TensorflowGraphNode, self).__init__(layer) + self.codes = list() + self.dataformat = 'NCHW' + + @property + def type(self): + return self.layer.op.lower() + + @property + def name(self): + return self.layer.name + + # TODO + def get_attr(self, name, default_value=None): + if name in self.layer.attr: + attr = self.layer.attr[name] + field = attr.WhichOneof('value') + val = getattr(attr, field) if field else default_value + if isinstance(val, attr_value_pb2.AttrValue.ListValue): + return list(val.ListFields()[0][1]) + else: + return val.decode('utf-8') if isinstance(val, bytes) else val + else: + return default_value + + +class TensorflowGraph(Graph): + def __init__(self, model): + super(TensorflowGraph, self).__init__(model) + self.model = model + + def build(self): + for i, layer in enumerate(self.model.node): + self.node_map[layer.name] = TensorflowGraphNode(layer) + for pred in layer.input: + if pred not in self.node_map: + raise Exception('input: {} not in node_map'.format(pred)) + + self._make_connection(pred, layer.name) + + super(TensorflowGraph, self).build() + self._check_dataformat() + + # check the dataformat of network + def _check_dataformat(self): + ss = list() + for i in range(0, len(self.topological_sort)): + current_node = self.node_map[self.topological_sort[i]] + if current_node.type == 'conv2d': + s = current_node.layer.attr['data_format'].s + if s != 'NHWC' and s != 'NCHW': + raise Exception('Unkown dataformat {}'.format(s)) + ss.append(s) + + if len(set(ss)) > 1: + raise Exception("Two type of dataformat exist in this model") + + if len(set(ss)) == 0: + return + + for i in range(0, len(self.topological_sort)): + current_node = self.node_map[self.topological_sort[i]] + current_node.dataformat = ss[0] diff --git a/TensorFlow2Paddle/tensorflow_parser.py b/TensorFlow2Paddle/tensorflow_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..718f30689fd6a49840f1d183d937a503301ae49b --- /dev/null +++ b/TensorFlow2Paddle/tensorflow_parser.py @@ -0,0 +1,66 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow +from tensorflow_graph import TensorflowGraph + + +class TensorflowParser(object): + def __init__(self, + meta_file, + checkpoint_file, + dest_nodes, + input_shape=None, + in_nodes=None): + graph_def = None + self.weights = dict() + with tensorflow.Session() as sess: + if meta_file is None: + raise Exception("meta_file must be provided") + new_saver = tensorflow.train.import_meta_graph(meta_file) + if checkpoint_file is not None: + new_saver.restore( + sess, tensorflow.train.latest_checkpoint(checkpoint_file)) + for var in tensorflow.global_variables(): + value = var.eval(sess) + self.weights[var.name.split(':')[0]] = value + + graph_def, ver = tensorflow.get_default_graph()._as_graph_def( + add_shapes=True) + + if in_nodes is not None and input_shape is not None: + from tensorflow.python.tools import strip_unused_lib + from tensorflow.python.framework import dtypes + graph_def = strip_unused_lib.strip_unused( + input_graph_def=graph_def, + input_node_names=in_nodes, + output_node_names=dest_nodes, + placeholder_type_enum=dtypes.float32.as_datatype_enum) + + input_list = [None] + for i in range(len(input_shape)): + input_list.append(tensorflow.Dimension(input_shape[i])) + tensor_input = tensorflow.TensorShape(input_list) + + self.tf_graph = TensorflowGraph(graph_def) + for node in self.tf_graph.model.node: + if node.name in in_nodes: + node.attr['shape'].list.shape.extend( + [tensor_input.as_proto()]) + node.attr['_output_shapes'].list.shape.pop() + node.attr['_output_shapes'].list.shape.extend( + [tensor_input.as_proto()]) + else: + raise Exception('in_nodes and output_nodes need be provided') + self.tf_graph.build() diff --git a/TensorFlow2Paddle/transformer.py b/TensorFlow2Paddle/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..0db75de200668e6b9e6b23adc97c126d6f0f6254 --- /dev/null +++ b/TensorFlow2Paddle/transformer.py @@ -0,0 +1,31 @@ +from paddle_emitter import PaddleEmitter +from tensorflow_parser import TensorflowParser + + +class Transformer(object): + def __init__(self, meta_file, ckpt_file, out_nodes, in_shape, in_nodes): + self.parser = TensorflowParser(meta_file, ckpt_file, out_nodes, + in_shape, in_nodes) + self.emitter = PaddleEmitter(self.parser.tf_graph) + + def transform_code(self, out_py_file): + filew = open(out_py_file, 'w') + codes = self.emitter.gen_code() + filew.write(codes) + filew.close() + + def transform_weight(self, out_dir): + self.emitter.gen_weight(self.parser.weights, out_dir) + + def run(self, dst_dir): + import os + if os.path.isdir(dst_dir) or os.path.isfile(dst_dir): + print("{} already exists, set a new directory") + return + if not os.path.isdir(dst_dir): + os.mkdir(dst_dir) + self.transform_code(dst_dir + "/mymodel.py") + if (len(self.parser.weights) == 0): + print("There is no tensorflow model weight translate to paddle") + else: + self.transform_weight(dst_dir)