Commit 3224f40b authored by J jiangjiajun

remove TensorFlow2Paddle

Parent 1bb181d3
### Warning
> **TensorFlow2Paddle is still under development and is not stable; many TensorFlow operations are not supported yet.**
> **It has only been tested on vgg_16/resnet_v1_50/inception_v3 with is_training=False, where the converted models pass a no-diff check against the TensorFlow results.**
### Dependency
> 1. Python 2.7
> 2. PaddlePaddle >= 1.2.0
> 3. TensorFlow >= 1.12.0
**Notice: Install PaddlePaddle and TensorFlow in separate virtual environments; in our tests their dependencies conflict when they are installed together.**
### Usage
The converter needs two inputs (see the sketch below):
> **1. Model file: a TensorFlow checkpoint directory containing the graph `.meta` file and the checkpoint data**
> **2. The names of the model's input tensors and output tensors**
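A minimal sketch of how these two inputs are passed to the converter, mirroring the `Transformer` call used in the demo script further below; the checkpoint paths and tensor names here come from the resnet_v1_50 demo and are placeholders for your own model:
```python
import sys
sys.path.append(".")
from transformer import Transformer

meta_file = "new_ckpt_model/resnet.meta"      # .meta graph file inside the checkpoint directory
ckpt_dir = "new_ckpt_model"                   # TensorFlow checkpoint directory
output_names = ['resnet_v1_50/pool5']         # output tensor name(s)
input_shape = (224, 224, 3)                   # HWC shape of one input image
input_names = ['inputs']                      # input tensor name(s)

transformer = Transformer(meta_file, ckpt_dir, output_names, input_shape, input_names)
transformer.run("fluid_model")                # writes mymodel.py plus param_* weight files
```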
### Demo: Convert the TensorFlow resnet_v1_50 pretrained model to a PaddlePaddle model for inference
#### 1. Get the pretrained model
```
git clone https://github.com/PaddlePaddle/X2Paddle.git
cd X2Paddle/TensorFlow2Paddle
wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
tar xzvf resnet_v1_50_2016_08_28.tar.gz
```
#### 2. Convert the model to a checkpoint with a meta graph file
```
python demo/save_resnet_ckpt_model.py resnet_v1_50.ckpt ./new_ckpt_model
```
#### 3. Export PaddlePaddle model
```
python demo/export_resnet_to_paddle_model.py new_ckpt_model/resnet.meta new_ckpt_model fluid_model
```
#### 4. Test the PaddlePaddle model
```python
from fluid_model.mymodel import KitModel
import paddle.fluid as fluid
import numpy
import os
out = KitModel()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# Load the converted parameters (files named "param_*") into the program.
var_list = list()
for f in os.listdir('./fluid_model'):
    if f.startswith("param_"):
        var_list.append(fluid.default_main_program().global_block().var(f))
fluid.io.load_vars(exe, './fluid_model', vars=var_list)

# Run one forward pass on random input; the converted model expects NCHW layout.
test_data = numpy.random.rand(1, 3, 224, 224).astype('float32')
result = exe.run(fluid.default_main_program(),
                 feed={'input_0': test_data},
                 fetch_list=[out])
print(result)
```
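Optionally, once the parameters are loaded you can persist the network as a standard Paddle inference model, so that deployment no longer needs the generated Python module. This is a sketch on top of the snippet above and not part of the original demo; the output directory name is arbitrary:
```python
# `exe` and `out` come from the test snippet above; 'input_0' is the feed name
# generated by the converter. The directory './inference_model' is arbitrary.
fluid.io.save_inference_model(dirname='./inference_model',
                              feeded_var_names=['input_0'],
                              target_vars=[out],
                              executor=exe)
```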
### Link
[MMdnn-Tensorflow](https://github.com/Microsoft/MMdnn/tree/master/mmdnn/conversion/tensorflow)
import sys
sys.path.append(".")
from transformer import Transformer
meta_file = sys.argv[1]
ckpt_dir = sys.argv[2]
export_dir = sys.argv[3]
transformer = Transformer(meta_file, ckpt_dir, ['resnet_v1_50/pool5'],
(224, 224, 3), ['inputs'])
transformer.run(export_dir)
open(export_dir + "/__init__.py", "w").close()
from tensorflow.contrib.slim.nets import resnet_v1 as resnet_v1
import tensorflow.contrib.slim as slim
import tensorflow as tf
import sys
def load_model(ckpt_file):
img_size = resnet_v1.resnet_v1.default_image_size
img = tf.placeholder(
tf.float32, shape=[None, img_size, img_size, 3], name='inputs')
with slim.arg_scope(resnet_v1.resnet_arg_scope()):
net, endpoint = resnet_v1.resnet_v1_50(
img, num_classes=None, is_training=False)
sess = tf.Session()
load_model = tf.contrib.slim.assign_from_checkpoint_fn(
ckpt_file, tf.contrib.slim.get_model_variables("resnet_v1_50"))
load_model(sess)
return sess
def save_checkpoint(sess, save_dir):
saver = tf.train.Saver()
saver.save(sess, save_dir + "/resnet")
if __name__ == "__main__":
ckpt_file = sys.argv[1]
save_dir = sys.argv[2]
sess = load_model(ckpt_file)
save_checkpoint(sess, save_dir)
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: framework.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='framework.proto',
package='paddle.framework.proto',
syntax='proto2',
serialized_pb=_b('\n\x0f\x66ramework.proto\x12\x16paddle.framework.proto\"\x1d\n\x07Version\x12\x12\n\x07version\x18\x01 \x01(\x03:\x01\x30\"\xec\x03\n\x06OpDesc\x12\x0c\n\x04type\x18\x03 \x02(\t\x12\x32\n\x06inputs\x18\x01 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x33\n\x07outputs\x18\x02 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x32\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32#.paddle.framework.proto.OpDesc.Attr\x12\x18\n\tis_target\x18\x05 \x01(\x08:\x05\x66\x61lse\x1a\xef\x01\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\t\n\x01i\x18\x03 \x01(\x05\x12\t\n\x01\x66\x18\x04 \x01(\x02\x12\t\n\x01s\x18\x05 \x01(\t\x12\x0c\n\x04ints\x18\x06 \x03(\x05\x12\x0e\n\x06\x66loats\x18\x07 \x03(\x02\x12\x0f\n\x07strings\x18\x08 \x03(\t\x12\t\n\x01\x62\x18\n \x01(\x08\x12\r\n\x05\x62ools\x18\x0b \x03(\x08\x12\x11\n\tblock_idx\x18\x0c \x01(\x05\x12\t\n\x01l\x18\r \x01(\x03\x12\x12\n\nblocks_idx\x18\x0e \x03(\x05\x12\r\n\x05longs\x18\x0f \x03(\x03\x1a+\n\x03Var\x12\x11\n\tparameter\x18\x01 \x02(\t\x12\x11\n\targuments\x18\x02 \x03(\t\"\xb3\x03\n\x07OpProto\x12\x0c\n\x04type\x18\x01 \x02(\t\x12\x33\n\x06inputs\x18\x02 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x34\n\x07outputs\x18\x03 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x33\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32$.paddle.framework.proto.OpProto.Attr\x12\x0f\n\x07\x63omment\x18\x05 \x02(\t\x1ax\n\x03Var\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0f\n\x07\x63omment\x18\x02 \x02(\t\x12\x19\n\nduplicable\x18\x03 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0cintermediate\x18\x04 \x01(\x08:\x05\x66\x61lse\x12\x1a\n\x0b\x64ispensable\x18\x05 \x01(\x08:\x05\x66\x61lse\x1ao\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\x0f\n\x07\x63omment\x18\x03 \x02(\t\x12\x18\n\tgenerated\x18\x04 \x01(\x08:\x05\x66\x61lse\"\xda\x08\n\x07VarType\x12\x32\n\x04type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x41\n\rselected_rows\x18\x02 \x01(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x41\n\nlod_tensor\x18\x03 \x01(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x12H\n\x0ctensor_array\x18\x04 \x01(\x0b\x32\x32.paddle.framework.proto.VarType.LoDTensorArrayDesc\x12:\n\x06reader\x18\x05 \x01(\x0b\x32*.paddle.framework.proto.VarType.ReaderDesc\x12\x34\n\x05tuple\x18\x07 \x01(\x0b\x32%.paddle.framework.proto.VarType.Tuple\x1aS\n\nTensorDesc\x12\x37\n\tdata_type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x0c\n\x04\x64ims\x18\x02 \x03(\x03\x1a\x61\n\rLoDTensorDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1a\x66\n\x12LoDTensorArrayDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1aO\n\nReaderDesc\x12\x41\n\nlod_tensor\x18\x01 \x03(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x1a\x43\n\x05Tuple\x12:\n\x0c\x65lement_type\x18\x01 
\x03(\x0e\x32$.paddle.framework.proto.VarType.Type\"\xa2\x02\n\x04Type\x12\x08\n\x04\x42OOL\x10\x00\x12\t\n\x05INT16\x10\x01\x12\t\n\x05INT32\x10\x02\x12\t\n\x05INT64\x10\x03\x12\x08\n\x04\x46P16\x10\x04\x12\x08\n\x04\x46P32\x10\x05\x12\x08\n\x04\x46P64\x10\x06\x12\n\n\x06SIZE_T\x10\x13\x12\t\n\x05UINT8\x10\x14\x12\x08\n\x04INT8\x10\x15\x12\x0e\n\nLOD_TENSOR\x10\x07\x12\x11\n\rSELECTED_ROWS\x10\x08\x12\x12\n\x0e\x46\x45\x45\x44_MINIBATCH\x10\t\x12\x0e\n\nFETCH_LIST\x10\n\x12\x0f\n\x0bSTEP_SCOPES\x10\x0b\x12\x12\n\x0eLOD_RANK_TABLE\x10\x0c\x12\x14\n\x10LOD_TENSOR_ARRAY\x10\r\x12\x0e\n\nPLACE_LIST\x10\x0e\x12\n\n\x06READER\x10\x0f\x12\x07\n\x03RAW\x10\x11\x12\t\n\x05TUPLE\x10\x12\"b\n\x07VarDesc\x12\x0c\n\x04name\x18\x01 \x02(\t\x12-\n\x04type\x18\x02 \x02(\x0b\x32\x1f.paddle.framework.proto.VarType\x12\x1a\n\x0bpersistable\x18\x03 \x01(\x08:\x05\x66\x61lse\"\xa7\x01\n\tBlockDesc\x12\x0b\n\x03idx\x18\x01 \x02(\x05\x12\x12\n\nparent_idx\x18\x02 \x02(\x05\x12-\n\x04vars\x18\x03 \x03(\x0b\x32\x1f.paddle.framework.proto.VarDesc\x12+\n\x03ops\x18\x04 \x03(\x0b\x32\x1e.paddle.framework.proto.OpDesc\x12\x1d\n\x11\x66orward_block_idx\x18\x05 \x01(\x05:\x02-1\"r\n\x0bProgramDesc\x12\x31\n\x06\x62locks\x18\x01 \x03(\x0b\x32!.paddle.framework.proto.BlockDesc\x12\x30\n\x07version\x18\x02 \x01(\x0b\x32\x1f.paddle.framework.proto.Version*\x94\x01\n\x08\x41ttrType\x12\x07\n\x03INT\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x08\n\x04INTS\x10\x03\x12\n\n\x06\x46LOATS\x10\x04\x12\x0b\n\x07STRINGS\x10\x05\x12\x0b\n\x07\x42OOLEAN\x10\x06\x12\x0c\n\x08\x42OOLEANS\x10\x07\x12\t\n\x05\x42LOCK\x10\x08\x12\x08\n\x04LONG\x10\t\x12\n\n\x06\x42LOCKS\x10\n\x12\t\n\x05LONGS\x10\x0b\x42\x02H\x03')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_ATTRTYPE = _descriptor.EnumDescriptor(
name='AttrType',
full_name='paddle.framework.proto.AttrType',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='INT', index=0, number=0,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FLOAT', index=1, number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='STRING', index=2, number=2,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INTS', index=3, number=3,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FLOATS', index=4, number=4,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='STRINGS', index=5, number=5,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='BOOLEAN', index=6, number=6,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='BOOLEANS', index=7, number=7,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='BLOCK', index=8, number=8,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LONG', index=9, number=9,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='BLOCKS', index=10, number=10,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LONGS', index=11, number=11,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=2511,
serialized_end=2659,
)
_sym_db.RegisterEnumDescriptor(_ATTRTYPE)
AttrType = enum_type_wrapper.EnumTypeWrapper(_ATTRTYPE)
INT = 0
FLOAT = 1
STRING = 2
INTS = 3
FLOATS = 4
STRINGS = 5
BOOLEAN = 6
BOOLEANS = 7
BLOCK = 8
LONG = 9
BLOCKS = 10
LONGS = 11
_VARTYPE_TYPE = _descriptor.EnumDescriptor(
name='Type',
full_name='paddle.framework.proto.VarType.Type',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='BOOL', index=0, number=0,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT16', index=1, number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT32', index=2, number=2,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT64', index=3, number=3,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP16', index=4, number=4,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP32', index=5, number=5,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP64', index=6, number=6,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='SIZE_T', index=7, number=19,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='UINT8', index=8, number=20,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT8', index=9, number=21,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LOD_TENSOR', index=10, number=7,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='SELECTED_ROWS', index=11, number=8,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FEED_MINIBATCH', index=12, number=9,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FETCH_LIST', index=13, number=10,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='STEP_SCOPES', index=14, number=11,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LOD_RANK_TABLE', index=15, number=12,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LOD_TENSOR_ARRAY', index=16, number=13,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PLACE_LIST', index=17, number=14,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='READER', index=18, number=15,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='RAW', index=19, number=17,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='TUPLE', index=20, number=18,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=1832,
serialized_end=2122,
)
_sym_db.RegisterEnumDescriptor(_VARTYPE_TYPE)
_VERSION = _descriptor.Descriptor(
name='Version',
full_name='paddle.framework.proto.Version',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='version', full_name='paddle.framework.proto.Version.version', index=0,
number=1, type=3, cpp_type=2, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=43,
serialized_end=72,
)
_OPDESC_ATTR = _descriptor.Descriptor(
name='Attr',
full_name='paddle.framework.proto.OpDesc.Attr',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='paddle.framework.proto.OpDesc.Attr.name', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.OpDesc.Attr.type', index=1,
number=2, type=14, cpp_type=8, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='i', full_name='paddle.framework.proto.OpDesc.Attr.i', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='f', full_name='paddle.framework.proto.OpDesc.Attr.f', index=3,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='s', full_name='paddle.framework.proto.OpDesc.Attr.s', index=4,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ints', full_name='paddle.framework.proto.OpDesc.Attr.ints', index=5,
number=6, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='floats', full_name='paddle.framework.proto.OpDesc.Attr.floats', index=6,
number=7, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='strings', full_name='paddle.framework.proto.OpDesc.Attr.strings', index=7,
number=8, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='b', full_name='paddle.framework.proto.OpDesc.Attr.b', index=8,
number=10, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='bools', full_name='paddle.framework.proto.OpDesc.Attr.bools', index=9,
number=11, type=8, cpp_type=7, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='block_idx', full_name='paddle.framework.proto.OpDesc.Attr.block_idx', index=10,
number=12, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='l', full_name='paddle.framework.proto.OpDesc.Attr.l', index=11,
number=13, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='blocks_idx', full_name='paddle.framework.proto.OpDesc.Attr.blocks_idx', index=12,
number=14, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='longs', full_name='paddle.framework.proto.OpDesc.Attr.longs', index=13,
number=15, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=283,
serialized_end=522,
)
_OPDESC_VAR = _descriptor.Descriptor(
name='Var',
full_name='paddle.framework.proto.OpDesc.Var',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='parameter', full_name='paddle.framework.proto.OpDesc.Var.parameter', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='arguments', full_name='paddle.framework.proto.OpDesc.Var.arguments', index=1,
number=2, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=524,
serialized_end=567,
)
_OPDESC = _descriptor.Descriptor(
name='OpDesc',
full_name='paddle.framework.proto.OpDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.OpDesc.type', index=0,
number=3, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='inputs', full_name='paddle.framework.proto.OpDesc.inputs', index=1,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='outputs', full_name='paddle.framework.proto.OpDesc.outputs', index=2,
number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='attrs', full_name='paddle.framework.proto.OpDesc.attrs', index=3,
number=4, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='is_target', full_name='paddle.framework.proto.OpDesc.is_target', index=4,
number=5, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[_OPDESC_ATTR, _OPDESC_VAR, ],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=75,
serialized_end=567,
)
_OPPROTO_VAR = _descriptor.Descriptor(
name='Var',
full_name='paddle.framework.proto.OpProto.Var',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='paddle.framework.proto.OpProto.Var.name', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment', full_name='paddle.framework.proto.OpProto.Var.comment', index=1,
number=2, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='duplicable', full_name='paddle.framework.proto.OpProto.Var.duplicable', index=2,
number=3, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='intermediate', full_name='paddle.framework.proto.OpProto.Var.intermediate', index=3,
number=4, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dispensable', full_name='paddle.framework.proto.OpProto.Var.dispensable', index=4,
number=5, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=772,
serialized_end=892,
)
_OPPROTO_ATTR = _descriptor.Descriptor(
name='Attr',
full_name='paddle.framework.proto.OpProto.Attr',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='paddle.framework.proto.OpProto.Attr.name', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.OpProto.Attr.type', index=1,
number=2, type=14, cpp_type=8, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment', full_name='paddle.framework.proto.OpProto.Attr.comment', index=2,
number=3, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='generated', full_name='paddle.framework.proto.OpProto.Attr.generated', index=3,
number=4, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=894,
serialized_end=1005,
)
_OPPROTO = _descriptor.Descriptor(
name='OpProto',
full_name='paddle.framework.proto.OpProto',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.OpProto.type', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='inputs', full_name='paddle.framework.proto.OpProto.inputs', index=1,
number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='outputs', full_name='paddle.framework.proto.OpProto.outputs', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='attrs', full_name='paddle.framework.proto.OpProto.attrs', index=3,
number=4, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment', full_name='paddle.framework.proto.OpProto.comment', index=4,
number=5, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[_OPPROTO_VAR, _OPPROTO_ATTR, ],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=570,
serialized_end=1005,
)
_VARTYPE_TENSORDESC = _descriptor.Descriptor(
name='TensorDesc',
full_name='paddle.framework.proto.VarType.TensorDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='data_type', full_name='paddle.framework.proto.VarType.TensorDesc.data_type', index=0,
number=1, type=14, cpp_type=8, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dims', full_name='paddle.framework.proto.VarType.TensorDesc.dims', index=1,
number=2, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1393,
serialized_end=1476,
)
_VARTYPE_LODTENSORDESC = _descriptor.Descriptor(
name='LoDTensorDesc',
full_name='paddle.framework.proto.VarType.LoDTensorDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor', full_name='paddle.framework.proto.VarType.LoDTensorDesc.tensor', index=0,
number=1, type=11, cpp_type=10, label=2,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_level', full_name='paddle.framework.proto.VarType.LoDTensorDesc.lod_level', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1478,
serialized_end=1575,
)
_VARTYPE_LODTENSORARRAYDESC = _descriptor.Descriptor(
name='LoDTensorArrayDesc',
full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor', full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc.tensor', index=0,
number=1, type=11, cpp_type=10, label=2,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_level', full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc.lod_level', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1577,
serialized_end=1679,
)
_VARTYPE_READERDESC = _descriptor.Descriptor(
name='ReaderDesc',
full_name='paddle.framework.proto.VarType.ReaderDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='lod_tensor', full_name='paddle.framework.proto.VarType.ReaderDesc.lod_tensor', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1681,
serialized_end=1760,
)
_VARTYPE_TUPLE = _descriptor.Descriptor(
name='Tuple',
full_name='paddle.framework.proto.VarType.Tuple',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='element_type', full_name='paddle.framework.proto.VarType.Tuple.element_type', index=0,
number=1, type=14, cpp_type=8, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1762,
serialized_end=1829,
)
_VARTYPE = _descriptor.Descriptor(
name='VarType',
full_name='paddle.framework.proto.VarType',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.VarType.type', index=0,
number=1, type=14, cpp_type=8, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='selected_rows', full_name='paddle.framework.proto.VarType.selected_rows', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_tensor', full_name='paddle.framework.proto.VarType.lod_tensor', index=2,
number=3, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='tensor_array', full_name='paddle.framework.proto.VarType.tensor_array', index=3,
number=4, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='reader', full_name='paddle.framework.proto.VarType.reader', index=4,
number=5, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='tuple', full_name='paddle.framework.proto.VarType.tuple', index=5,
number=7, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[_VARTYPE_TENSORDESC, _VARTYPE_LODTENSORDESC, _VARTYPE_LODTENSORARRAYDESC, _VARTYPE_READERDESC, _VARTYPE_TUPLE, ],
enum_types=[
_VARTYPE_TYPE,
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1008,
serialized_end=2122,
)
_VARDESC = _descriptor.Descriptor(
name='VarDesc',
full_name='paddle.framework.proto.VarDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='paddle.framework.proto.VarDesc.name', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.VarDesc.type', index=1,
number=2, type=11, cpp_type=10, label=2,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='persistable', full_name='paddle.framework.proto.VarDesc.persistable', index=2,
number=3, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=2124,
serialized_end=2222,
)
_BLOCKDESC = _descriptor.Descriptor(
name='BlockDesc',
full_name='paddle.framework.proto.BlockDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='idx', full_name='paddle.framework.proto.BlockDesc.idx', index=0,
number=1, type=5, cpp_type=1, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='parent_idx', full_name='paddle.framework.proto.BlockDesc.parent_idx', index=1,
number=2, type=5, cpp_type=1, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='vars', full_name='paddle.framework.proto.BlockDesc.vars', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ops', full_name='paddle.framework.proto.BlockDesc.ops', index=3,
number=4, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='forward_block_idx', full_name='paddle.framework.proto.BlockDesc.forward_block_idx', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=-1,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=2225,
serialized_end=2392,
)
_PROGRAMDESC = _descriptor.Descriptor(
name='ProgramDesc',
full_name='paddle.framework.proto.ProgramDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='blocks', full_name='paddle.framework.proto.ProgramDesc.blocks', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='version', full_name='paddle.framework.proto.ProgramDesc.version', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=2394,
serialized_end=2508,
)
_OPDESC_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE
_OPDESC_ATTR.containing_type = _OPDESC
_OPDESC_VAR.containing_type = _OPDESC
_OPDESC.fields_by_name['inputs'].message_type = _OPDESC_VAR
_OPDESC.fields_by_name['outputs'].message_type = _OPDESC_VAR
_OPDESC.fields_by_name['attrs'].message_type = _OPDESC_ATTR
_OPPROTO_VAR.containing_type = _OPPROTO
_OPPROTO_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE
_OPPROTO_ATTR.containing_type = _OPPROTO
_OPPROTO.fields_by_name['inputs'].message_type = _OPPROTO_VAR
_OPPROTO.fields_by_name['outputs'].message_type = _OPPROTO_VAR
_OPPROTO.fields_by_name['attrs'].message_type = _OPPROTO_ATTR
_VARTYPE_TENSORDESC.fields_by_name['data_type'].enum_type = _VARTYPE_TYPE
_VARTYPE_TENSORDESC.containing_type = _VARTYPE
_VARTYPE_LODTENSORDESC.fields_by_name['tensor'].message_type = _VARTYPE_TENSORDESC
_VARTYPE_LODTENSORDESC.containing_type = _VARTYPE
_VARTYPE_LODTENSORARRAYDESC.fields_by_name['tensor'].message_type = _VARTYPE_TENSORDESC
_VARTYPE_LODTENSORARRAYDESC.containing_type = _VARTYPE
_VARTYPE_READERDESC.fields_by_name['lod_tensor'].message_type = _VARTYPE_LODTENSORDESC
_VARTYPE_READERDESC.containing_type = _VARTYPE
_VARTYPE_TUPLE.fields_by_name['element_type'].enum_type = _VARTYPE_TYPE
_VARTYPE_TUPLE.containing_type = _VARTYPE
_VARTYPE.fields_by_name['type'].enum_type = _VARTYPE_TYPE
_VARTYPE.fields_by_name['selected_rows'].message_type = _VARTYPE_TENSORDESC
_VARTYPE.fields_by_name['lod_tensor'].message_type = _VARTYPE_LODTENSORDESC
_VARTYPE.fields_by_name['tensor_array'].message_type = _VARTYPE_LODTENSORARRAYDESC
_VARTYPE.fields_by_name['reader'].message_type = _VARTYPE_READERDESC
_VARTYPE.fields_by_name['tuple'].message_type = _VARTYPE_TUPLE
_VARTYPE_TYPE.containing_type = _VARTYPE
_VARDESC.fields_by_name['type'].message_type = _VARTYPE
_BLOCKDESC.fields_by_name['vars'].message_type = _VARDESC
_BLOCKDESC.fields_by_name['ops'].message_type = _OPDESC
_PROGRAMDESC.fields_by_name['blocks'].message_type = _BLOCKDESC
_PROGRAMDESC.fields_by_name['version'].message_type = _VERSION
DESCRIPTOR.message_types_by_name['Version'] = _VERSION
DESCRIPTOR.message_types_by_name['OpDesc'] = _OPDESC
DESCRIPTOR.message_types_by_name['OpProto'] = _OPPROTO
DESCRIPTOR.message_types_by_name['VarType'] = _VARTYPE
DESCRIPTOR.message_types_by_name['VarDesc'] = _VARDESC
DESCRIPTOR.message_types_by_name['BlockDesc'] = _BLOCKDESC
DESCRIPTOR.message_types_by_name['ProgramDesc'] = _PROGRAMDESC
DESCRIPTOR.enum_types_by_name['AttrType'] = _ATTRTYPE
Version = _reflection.GeneratedProtocolMessageType('Version', (_message.Message,), dict(
DESCRIPTOR = _VERSION,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.Version)
))
_sym_db.RegisterMessage(Version)
OpDesc = _reflection.GeneratedProtocolMessageType('OpDesc', (_message.Message,), dict(
Attr = _reflection.GeneratedProtocolMessageType('Attr', (_message.Message,), dict(
DESCRIPTOR = _OPDESC_ATTR,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Attr)
))
,
Var = _reflection.GeneratedProtocolMessageType('Var', (_message.Message,), dict(
DESCRIPTOR = _OPDESC_VAR,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Var)
))
,
DESCRIPTOR = _OPDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc)
))
_sym_db.RegisterMessage(OpDesc)
_sym_db.RegisterMessage(OpDesc.Attr)
_sym_db.RegisterMessage(OpDesc.Var)
OpProto = _reflection.GeneratedProtocolMessageType('OpProto', (_message.Message,), dict(
Var = _reflection.GeneratedProtocolMessageType('Var', (_message.Message,), dict(
DESCRIPTOR = _OPPROTO_VAR,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Var)
))
,
Attr = _reflection.GeneratedProtocolMessageType('Attr', (_message.Message,), dict(
DESCRIPTOR = _OPPROTO_ATTR,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Attr)
))
,
DESCRIPTOR = _OPPROTO,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto)
))
_sym_db.RegisterMessage(OpProto)
_sym_db.RegisterMessage(OpProto.Var)
_sym_db.RegisterMessage(OpProto.Attr)
VarType = _reflection.GeneratedProtocolMessageType('VarType', (_message.Message,), dict(
TensorDesc = _reflection.GeneratedProtocolMessageType('TensorDesc', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_TENSORDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.TensorDesc)
))
,
LoDTensorDesc = _reflection.GeneratedProtocolMessageType('LoDTensorDesc', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_LODTENSORDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorDesc)
))
,
LoDTensorArrayDesc = _reflection.GeneratedProtocolMessageType('LoDTensorArrayDesc', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_LODTENSORARRAYDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorArrayDesc)
))
,
ReaderDesc = _reflection.GeneratedProtocolMessageType('ReaderDesc', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_READERDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.ReaderDesc)
))
,
Tuple = _reflection.GeneratedProtocolMessageType('Tuple', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_TUPLE,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.Tuple)
))
,
DESCRIPTOR = _VARTYPE,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType)
))
_sym_db.RegisterMessage(VarType)
_sym_db.RegisterMessage(VarType.TensorDesc)
_sym_db.RegisterMessage(VarType.LoDTensorDesc)
_sym_db.RegisterMessage(VarType.LoDTensorArrayDesc)
_sym_db.RegisterMessage(VarType.ReaderDesc)
_sym_db.RegisterMessage(VarType.Tuple)
VarDesc = _reflection.GeneratedProtocolMessageType('VarDesc', (_message.Message,), dict(
DESCRIPTOR = _VARDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc)
))
_sym_db.RegisterMessage(VarDesc)
BlockDesc = _reflection.GeneratedProtocolMessageType('BlockDesc', (_message.Message,), dict(
DESCRIPTOR = _BLOCKDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.BlockDesc)
))
_sym_db.RegisterMessage(BlockDesc)
ProgramDesc = _reflection.GeneratedProtocolMessageType('ProgramDesc', (_message.Message,), dict(
DESCRIPTOR = _PROGRAMDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.ProgramDesc)
))
_sym_db.RegisterMessage(ProgramDesc)
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('H\003'))
# @@protoc_insertion_point(module_scope)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from name_generator import NameGenerator
class GraphNode(object):
def __init__(self, layer):
self.in_edges = list()
self.out_edges = list()
self.layer = layer
self.ref_name = None
class Graph(object):
def __init__(self, model):
self.node_map = dict()
self.input_nodes = list()
self.output_nodes = list()
self.topological_sort = list()
self.model = model
self.name_generator = NameGenerator()
def build(self):
self._make_input_nodes()
self._make_output_nodes()
self._get_topological_sort()
self._gen_newname_for_nodes()
def _make_input_nodes(self):
for name, node in self.node_map.items():
node.left_in_edges = len(node.in_edges)
if len(node.in_edges) == 0:
self.input_nodes.append(name)
def _make_output_nodes(self):
for name, node in self.node_map.items():
if len(node.out_edges) == 0:
self.output_nodes.append(name)
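# Kahn-style topological sort: seed the order with the zero-in-degree input nodes,
# then append a node once all of its incoming edges have been consumed.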
def _get_topological_sort(self):
self.topological_sort = self.input_nodes[:]
idx = 0
while idx < len(self.topological_sort):
current_node = self.node_map[self.topological_sort[idx]]
for next_node in current_node.out_edges:
next_node_info = self.node_map[next_node]
next_node_info.left_in_edges -= 1
if next_node_info.left_in_edges == 0:
self.topological_sort.append(next_node)
idx += 1
def _gen_newname_for_nodes(self):
for node_name in self.topological_sort:
node = self.node_map[node_name]
ref_name = self.name_generator.get_name(node)
self.node_map[node.layer.name].ref_name = ref_name
def get_node(self, name):
if name not in self.node_map:
raise Exception("Graph doesn't have node [%s]." % name)
else:
return self.node_map[name]
def _make_connection(self, src, dst):
if src == dst or src not in self.node_map or dst not in self.node_map:
raise Exception('Node does not exist or the edge is a self-loop')
if src not in self.node_map[dst].in_edges:
self.node_map[dst].in_edges.append(src)
if dst not in self.node_map[src].out_edges:
self.node_map[src].out_edges.append(dst)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class NameGenerator(object):
def __init__(self):
self.param_index = 0
self.input_index = 0
self.net_index = 0
self.const_index = 0
self.names = dict()
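# Map every TensorFlow node to a deterministic Paddle-side name: param_N for VariableV2,
# input_N for Placeholder, const_N for Const, net_N for everything else;
# Identity nodes simply reuse the name of their input.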
def get_name(self, node):
ref_name = None
op_name = node.type
if node.layer.name in self.names:
return self.names[node.layer.name]
if op_name == "variablev2":
ref_name = "param_" + str(self.param_index)
self.param_index += 1
elif op_name == "placeholder":
ref_name = "input_" + str(self.input_index)
self.input_index += 1
elif op_name == "const":
ref_name = "const_" + str(self.const_index)
self.const_index += 1
elif op_name.lower() == "identity":
ref_name = self.names[node.layer.input[0]]
else:
ref_name = "net_" + str(self.net_index)
self.net_index += 1
self.names[node.layer.name] = ref_name
return ref_name
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import tensor_util
from six import string_types as _string_types
class PaddleEmitter(object):
skip_op = set(['const', 'identity'])
dtype_map = {1: "float32", 3: "int32", 9: "int64"}
def __init__(self, graph):
self.graph = graph
self.body_code = ""
self.tab = " " * 4
@staticmethod
def tensor_shape_to_list(shapes):
if isinstance(shapes, attr_value_pb2.AttrValue):
return [dim.size for dim in shapes.shape.dim]
else:
ret = []
for shape in shapes:
this_one = [dim.size for dim in shape.dim]
ret.append(this_one)
return ret
@property
def header_code(self):
code = ["import paddle.fluid as fluid", "", "def KitModel():"]
return code
def add_codes(self, indent, codes):
if isinstance(codes, _string_types):
codes = codes.strip().split("\n")
for code in codes:
self.body_code += (self.tab * indent) + code + "\n"
def gen_code(self):
self.add_codes(0, self.header_code)
for node in self.graph.topological_sort:
current_node = self.graph.get_node(node)
op = current_node.type
if op in self.skip_op:
continue
if hasattr(self, "emit_" + op):
func = getattr(self, "emit_" + op)
codes = func(current_node)
if not isinstance(codes, list):
codes = [codes]
self.graph.get_node(node).codes = codes
else:
raise Exception("Unknow node op: {}".format(op))
for node in self.graph.topological_sort:
codes = self.graph.get_node(node).codes
self.add_codes(1, codes)
outs = []
for node in self.graph.output_nodes:
outs.append(self.graph.get_node(node).ref_name)
self.add_codes(1, "return {}".format(", ".join(outs)))
return self.body_code
def gen_weight(self, weight_dict, dirname):
import struct
import framework_pb2 as framework
import numpy
for var_name, var in weight_dict.items():
if var_name not in self.graph.node_map:
continue
shape = var.shape
paddle_var_name = self.graph.get_node(var_name).ref_name
paddle_var = var
dataformat = self.graph.get_node(var_name).dataformat
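# Write the header expected by Paddle's load machinery: a version word, LoD information,
# a second version word, the byte size of the serialized TensorDesc, the TensorDesc itself,
# and finally the raw float32 values.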
filew = open(dirname + '/' + paddle_var_name, 'wb')
filew.write(struct.pack('i', 0))
filew.write(struct.pack('L', 0))
filew.write(struct.pack('i', 0))
tensor_desc = framework.VarType.TensorDesc()
tensor_desc.data_type = framework.VarType.FP32
if len(shape) == 4 and dataformat == "NHWC":
paddle_var = numpy.transpose(var, (3, 2, 0, 1))
shape = paddle_var.shape
tensor_desc.dims.extend(shape)
desc_size = tensor_desc.ByteSize()
filew.write(struct.pack('i', desc_size))
filew.write(tensor_desc.SerializeToString())
tensor_size = reduce(lambda x, y: x * y, shape)
paddle_var = paddle_var.flatten()
for i in range(0, tensor_size):
filew.write(struct.pack('f', paddle_var[i]))
filew.close()
def emit_variablev2(self, node):
shape = self.tensor_shape_to_list(node.get_attr('_output_shapes'))[0]
if node.dataformat == 'NHWC' and len(shape) == 4:
shape = [shape[3], shape[2], shape[0], shape[1]]
dtype = node.get_attr('dtype')
if dtype in self.dtype_map:
dtype = self.dtype_map[dtype]
else:
raise Exception('Unknown dtype: {}'.format(dtype))
code = ["# variable[{}]:\t{}".format(node.ref_name, node.name)]
code.append(
'{} = fluid.layers.create_parameter(name=\'{}\', shape={}, dtype=\'{}\')'
.format(node.ref_name, node.ref_name, shape, dtype))
return code
def emit_placeholder(self, node):
shape = self.tensor_shape_to_list(node.get_attr('shape'))[0]
if node.dataformat == 'NHWC' and len(shape) == 4:
shape = [shape[0], shape[3], shape[1], shape[2]]
if shape[0] < 0:
shape = shape[1:]
dtype = node.get_attr('dtype')
if dtype in self.dtype_map:
dtype = self.dtype_map[dtype]
else:
raise Exception('Unknown dtype: {}'.format(dtype))
code = ["# placeholder[{}]:\t{}".format(node.ref_name, node.name)]
code.append(
'{} = fluid.layers.data(name=\'{}\', shape={}, dtype=\'{}\')'.
format(node.ref_name, node.ref_name, shape, dtype))
return code
def emit_conv2d(self, node):
inputs = self.graph.get_node(node.layer.input[0])
kernel = self.graph.get_node(node.layer.input[1])
dataformat = node.dataformat
padding_mode = node.get_attr('padding')
strides = node.get_attr('strides')[1:3]
k_shape = self.tensor_shape_to_list(
kernel.get_attr('_output_shapes'))[0]
kernel_num, channel, height, width = k_shape
if dataformat == "NHWC":
height, width, channel, kernel_num = k_shape
padding = [0, 0]
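# Approximate TensorFlow 'SAME' padding with symmetric padding of (k - 1) / 2 per side;
# this is exact for stride-1 convolutions with odd kernel sizes.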
if padding_mode == 'SAME':
padding = map(int, [(height - 1) / 2, (width - 1) / 2])
code = '{} = fluid.layers.conv2d({}, {}, {}, padding={}, stride={}, param_attr=\'{}\', bias_attr=False)'.format(
node.ref_name, inputs.ref_name, kernel_num, [height, width],
padding, strides, kernel.ref_name)
return code
def emit_biasadd(self, node):
inputs = self.graph.get_node(node.layer.input[0])
bias = self.graph.get_node(node.layer.input[1])
# TODO more validations
axis = 1
code = '{} = fluid.layers.elementwise_add({}, {}, axis={})'.format(
node.ref_name, inputs.ref_name, bias.ref_name, axis)
return code
def emit_relu(self, node):
inputs = self.graph.get_node(node.layer.input[0])
code = '{} = fluid.layers.relu({})'.format(node.ref_name,
inputs.ref_name)
return code
def emit_maxpool(self, node):
inputs = self.graph.get_node(node.layer.input[0])
padding_mode = node.get_attr('padding')
strides = node.get_attr('strides')[1:3]
pool_size = node.get_attr('ksize')[1:3]
padding = [0, 0]
if padding_mode == 'SAME':
pad_0 = (pool_size[0] - 1) / 2
pad_1 = (pool_size[1] - 1) / 2
padding = [0, pad_0 * 2, 0, pad_1 * 2]
code = [
'pad_net = fluid.layers.pad2d({}, paddings={})'.format(
inputs.ref_name, padding)
]
code.append(
'{} = fluid.layers.pool2d(pad_net, {}, \'max\', {})'.format(
node.ref_name, pool_size, strides))
else:
code = '{} = fluid.layers.pool2d({}, {}, \'max\', {})'.format(
node.ref_name, inputs.ref_name, pool_size, strides)
return code
def emit_pad(self, node):
inputs = self.graph.get_node(node.layer.input[0])
padding = self.graph.get_node(node.layer.input[1])
assert padding.type == 'const'
padding = padding.layer.attr['value'].tensor
padding = tensor_util.MakeNdarray(padding)
if node.dataformat == "NHWC" and padding.shape[0] == 4:
padding = padding[[0, 3, 1, 2]]
code = '{} = fluid.layers.pad({}, {})'.format(node.ref_name,
inputs.ref_name,
list(padding.flatten()))
return code
def emit_fusedbatchnorm(self, node):
inputs = self.graph.get_node(node.layer.input[0])
gamma = self.graph.get_node(node.layer.input[1])
beta = self.graph.get_node(node.layer.input[2])
mv_mean = self.graph.get_node(node.layer.input[3])
mv_variance = self.graph.get_node(node.layer.input[4])
is_training = node.get_attr("is_training")
if is_training:
raise Exception(
"FusedBatchNorm: is_training=True, not support yet, please set is_training=False in your tensorflow code, then dump model again"
)
epsilon = round(node.get_attr('epsilon'), 6)
if gamma.type == 'const':
value = gamma.get_attr('value')
shape = value.tensor_shape
assert len(shape.dim) == 1
shape = shape.dim[0].size
assert len(value.float_val) == 1
value = value.float_val[0]
code = "{} = fluid.layers.batch_norm({}, epsilon={}, param_attr=fluid.ParamAttr(\'{}\', fluid.initializer.Constant({})), bias_attr=\'{}\', moving_mean_name=\'{}\', moving_variance_name=\'{}\', is_test=True)".format(
node.ref_name, inputs.ref_name, epsilon, gamma.ref_name, value,
beta.ref_name, mv_mean.ref_name, mv_variance.ref_name)
else:
code = '{} = fluid.layers.batch_norm({}, epsilon={}, param_attr=\'{}\', bias_attr=\'{}\', moving_mean_name=\'{}\', moving_variance_name=\'{}\', is_test=True)'.format(
node.ref_name, inputs.ref_name, epsilon, gamma.ref_name,
beta.ref_name, mv_mean.ref_name, mv_variance.ref_name)
return code
def emit_assign(self, node):
ref = self.graph.get_node(node.layer.input[0])
value = self.graph.get_node(node.layer.input[1])
code = 'fluid.layers.assign(input={}, output={})'.format(
value.ref_name, ref.ref_name)
return code
def emit_add(self, node):
input1 = self.graph.get_node(node.layer.input[0])
input2 = self.graph.get_node(node.layer.input[1])
code = '{} = fluid.layers.elementwise_add({}, {})'.format(
node.ref_name, input1.ref_name, input2.ref_name)
return code
def emit_mean(self, node):
inputs = self.graph.get_node(node.layer.input[0])
reduce_idx = self.graph.get_node(node.layer.input[1])
idxs = reduce_idx.layer.attr['value'].tensor
idxs = tensor_util.MakeNdarray(idxs)
shape = idxs.shape
if len(shape) != 1:
raise Exception('Unexpected situation[mean_op]')
input_shape = self.tensor_shape_to_list(
inputs.get_attr('_output_shapes'))[0]
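        # For 4-D NHWC inputs the reduce axes are remapped to the NCHW layout
        # used by the emitted network: H 1->2, W 2->3, C 3->1.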
if node.dataformat == "NHWC" and len(input_shape) == 4:
for i in range(0, shape[0]):
if idxs[i] == 1:
idxs[i] = 2
elif idxs[i] == 2:
idxs[i] = 3
elif idxs[i] == 3:
idxs[i] = 1
code = '{} = fluid.layers.reduce_mean({}, {}, keep_dim=True)'.format(
node.ref_name, inputs.ref_name, list(idxs))
return code
def emit_squeeze(self, node):
inputs = self.graph.get_node(node.layer.input[0])
axis = node.get_attr('squeeze_dims')
input_shape = self.tensor_shape_to_list(
inputs.get_attr('_output_shapes'))[0]
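        # Remap squeeze axes from NHWC to the emitted NCHW layout: 1->2, 2->3, 3->1.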
if node.dataformat == "NHWC" and len(input_shape) == 4:
for i in range(0, len(axis)):
if axis[i] == 1:
axis[i] = 2
elif axis[i] == 2:
axis[i] = 3
elif axis[i] == 3:
axis[i] = 1
code = '{} = fluid.layers.squeeze({}, {})'.format(
node.ref_name, inputs.ref_name, axis)
return code
def emit_const(self, node):
shape = self.tensor_shape_to_list(node.get_attr('_output_shapes'))[0]
dtype = node.get_attr('dtype')
# TODO dtype need more validation
value = node.layer.attr['value'].tensor.int_val[0]
if dtype in self.dtype_map:
dtype = self.dtype_map[dtype]
else:
            raise Exception('Unknown dtype : {}'.format(dtype))
if node.dataformat == 'NHWC':
raise Exception("Const: NHWC format not support yet")
code = "{} = fluid.layers.fill_constant({}, \'{}\', {})".format(
node.ref_name, shape, dtype, value)
return code
def emit_concatv2(self, node):
inputs = node.layer.input
inputs_vars = []
code = []
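        # The last input of ConcatV2 is the concat axis; constant inputs are
        # emitted first so their ref_names exist before the concat itself.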
for i in range(0, len(inputs) - 1):
tmp = self.graph.get_node(inputs[i])
if tmp.type == 'const':
code.append(self.emit_const(tmp))
inputs_vars.append(tmp.ref_name)
axis = self.graph.get_node(
inputs[-1]).layer.attr['value'].tensor.int_val[0]
output_shape = self.tensor_shape_to_list(
node.get_attr('_output_shapes'))[0]
if node.dataformat == "NHWC" and len(output_shape) == 4:
if axis == 1:
axis = 2
elif axis == 2:
axis = 3
elif axis == 3:
axis = 1
code.append('{} = fluid.layers.concat([{}], {})'.format(
node.ref_name, ', '.join(inputs_vars), axis))
return code
def emit_avgpool(self, node):
inputs = self.graph.get_node(node.layer.input[0])
padding_mode = node.get_attr('padding')
# TODO need more validation in nlp
strides = node.get_attr('strides')[1:3]
pool_size = node.get_attr('ksize')[1:3]
padding = [0, 0]
if padding_mode == 'SAME':
pad_0 = (pool_size[0] - 1) / 2
pad_1 = (pool_size[1] - 1) / 2
padding = [pad_0, pad_1]
code = '{} = fluid.layers.pool2d({}, {}, \'avg\', {}, {})'.format(
node.ref_name, inputs.ref_name, pool_size, strides, padding)
return code
def emit_sub(self, node):
input1 = self.graph.get_node(node.layer.input[0])
input2 = self.graph.get_node(node.layer.input[1])
code = '{} = fluid.layers.elementwise_sub({}, {})'.format(
node.ref_name, input1.ref_name, input2.ref_name)
return code
def emit_mul(self, node):
input1 = self.graph.get_node(node.layer.input[0])
input2 = self.graph.get_node(node.layer.input[1])
code = '{} = fluid.layers.elementwise_mul({}, {})'.format(
node.ref_name, input1.ref_name, input2.ref_name)
return code
def emit_floor(self, node):
inputs = self.graph.get_node(node.layer.input[0])
code = '{} = fluid.layers.floor({})'.format(node.ref_name,
inputs.ref_name)
return code
def emit_realdiv(self, node):
input1 = self.graph.get_node(node.layer.input[0])
input2 = self.graph.get_node(node.layer.input[1])
        code = '{} = fluid.layers.elementwise_div({}, {})'.format(
            node.ref_name, input1.ref_name, input2.ref_name)
return code
def emit_shape(self, node):
inputs = self.graph.get_node(node.layer.input[0])
if "num_split" in inputs.layer.attr:
code = '{} = fluid.layers.shape({}[0])'.format(
node.ref_name, inputs.ref_name)
else:
code = '{} = fluid.layers.shape({})'.format(
node.ref_name, inputs.ref_name)
# TODO there's dtype problem of PaddlePaddle's OP[fluid.layers.shape]
# https://github.com/PaddlePaddle/Paddle/issues/15267
# tensorflow2paddle fix problem temporary
code = [code]
code.append("{} = fluid.layers.cast({}, dtype='int32')".format(
node.ref_name, node.ref_name))
return code
def emit_stridedslice(self, node):
inputs = self.graph.get_node(node.layer.input[0])
begin = self.graph.get_node(node.layer.input[1])
end = self.graph.get_node(node.layer.input[2])
strides = self.graph.get_node(node.layer.input[3])
begin = list(tensor_util.MakeNdarray(begin.layer.attr['value'].tensor))
end = list(tensor_util.MakeNdarray(end.layer.attr['value'].tensor))
strides = list(
tensor_util.MakeNdarray(strides.layer.attr['value'].tensor))
if len(begin) != len(strides) or len(end) != len(strides):
raise Exception("length of begin/end/strides must be equl")
for i in range(0, len(strides)):
if strides[i] != 1:
raise Exception(
"strides must be 1 for all axis, other situation not supported yet"
)
code = "{} = fluid.layers.slice({}, axes={}, starts={}, ends={})".format(
node.ref_name, inputs.ref_name, [i for i in range(0, len(begin))],
begin, end)
return code
def emit_gather(self, node):
embedding = self.graph.get_node(node.layer.input[0])
idxs = self.graph.get_node(node.layer.input[1])
idxs_shape = self.tensor_shape_to_list(
idxs.get_attr('_output_shapes'))[0]
embedding_shape = self.tensor_shape_to_list(
embedding.get_attr('_output_shapes'))[0]
if len(embedding_shape) != 2:
raise Exception("rank of input[0] must be equal to 2 in Gather OP")
code = []
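        # When the last index dimension is not 1, the indices are flattened to
        # [-1, 1] for fluid.layers.gather and the result is reshaped back to the
        # original index shape plus the embedding width.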
if idxs_shape[-1] != 1:
code.append(
"reshape = fluid.layers.reshape({}, shape=[-1, 1])".format(
idxs.ref_name))
code.append("gather = fluid.layers.gather({}, reshape)".format(
embedding.ref_name))
code.append("{} = fluid.layers.reshape(gather,{})".format(
node.ref_name, idxs_shape + [embedding_shape[-1]]))
else:
code = "{} = fluid.layers.gather({}, {})".format(
node.ref_name, embedding.ref_name, idxs.ref_name)
return code
def emit_transpose(self, node):
inputs = self.graph.get_node(node.layer.input[0])
perm = self.graph.get_node(node.layer.input[1])
assert perm.type == "const"
perm = perm.layer.attr['value'].tensor
perm = tensor_util.MakeNdarray(perm)
# TODO
if node.dataformat == "NHWC" and perm.shape[0] == 4:
raise Exception(
"Unsupported situation for op Transpose, NHWC not supported yet"
)
perm = list(perm)
code = "{} = fluid.layers.transpose({}, {})".format(
node.ref_name, inputs.ref_name, perm)
return code
def emit_reshape(self, node):
inputs = self.graph.get_node(node.layer.input[0])
shape = self.graph.get_node(node.layer.input[1])
assert shape.type == "const"
# TODO
if node.dataformat == "NHWC":
raise Exception(
"Unsupported situation for reshape, NHWC not supported yet")
shape = shape.layer.attr['value'].tensor
shape = list(tensor_util.MakeNdarray(shape))
code = "{} = fluid.layers.reshape({}, {})".format(
node.ref_name, inputs.ref_name, shape)
return code
def emit_split(self, node):
inputs = self.graph.get_node(node.layer.input[1])
split_dim = self.graph.get_node(node.layer.input[0])
inputs_shape = self.tensor_shape_to_list(
inputs.get_attr('_output_shapes'))[0]
assert split_dim.type == 'const' and len(inputs_shape) > 1
axis = split_dim.layer.attr['value'].tensor.int_val[0]
num_split = node.get_attr('num_split')
code = list()
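        # If the size of the split axis is unknown (-1), reshape so that
        # num_split becomes a fixed axis at position 1, split along it, and
        # squeeze that axis out of every slice; otherwise split directly.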
if inputs_shape[axis] < 0:
tmp_shape = [-1, num_split
] + inputs_shape[:axis] + inputs_shape[axis + 1:]
code.append("reshape = fluid.layers.reshape({}, {})".format(
inputs.ref_name, tmp_shape))
code.append(
"split = fluid.layers.split(reshape, {}, 1)".format(num_split))
code.append(
"{} = [fluid.layers.squeeze(s, [1]) for s in split]".format(
node.ref_name))
else:
code = "{} = fluid.layers.split({}, {}, {})".format(
inputs.ref_name, num_split, axis)
return code
def emit_expanddims(self, node):
inputs = self.graph.get_node(node.layer.input[0])
dim = self.graph.get_node(node.layer.input[1])
dim = tensor_util.MakeNdarray(dim.layer.attr['value'].tensor)
inputs_shape = self.tensor_shape_to_list(
inputs.get_attr('_output_shapes'))[0]
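        # ExpandDims is lowered to a static reshape: a dimension of size 1 is
        # inserted into the known input shape at the requested position.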
inputs_shape.insert(dim, 1)
code = "{} = fluid.layers.reshape({}, {})".format(
node.ref_name, inputs.ref_name, inputs_shape)
return code
def emit_fill(self, node):
value = self.graph.get_node(node.layer.input[1])
value = value.layer.attr['value'].tensor.float_val[0]
dtype = node.layer.attr['T'].type
if dtype in self.dtype_map:
dtype = self.dtype_map[dtype]
else:
            raise Exception('Unknown dtype : {}'.format(dtype))
output_shape = self.tensor_shape_to_list(
node.get_attr('_output_shapes'))[0]
code = "{} = fluid.layers.create_parameter({}, {}, default_initializer=fluid.initializer.Constant({}))".format(
node.ref_name, output_shape, dtype, value)
return code
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from graph import GraphNode, Graph
from tensorflow.core.framework import attr_value_pb2
class TensorflowGraphNode(GraphNode):
def __init__(self, layer):
super(TensorflowGraphNode, self).__init__(layer)
self.codes = list()
self.dataformat = 'NCHW'
@property
def type(self):
return self.layer.op.lower()
@property
def name(self):
return self.layer.name
# TODO
def get_attr(self, name, default_value=None):
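        # Unwrap the AttrValue oneof: list values come back as Python lists,
        # bytes are decoded to str, and missing attributes fall back to
        # default_value.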
if name in self.layer.attr:
attr = self.layer.attr[name]
field = attr.WhichOneof('value')
val = getattr(attr, field) if field else default_value
if isinstance(val, attr_value_pb2.AttrValue.ListValue):
return list(val.ListFields()[0][1])
else:
return val.decode('utf-8') if isinstance(val, bytes) else val
else:
return default_value
class TensorflowGraph(Graph):
def __init__(self, model):
super(TensorflowGraph, self).__init__(model)
self.model = model
def build(self):
for i, layer in enumerate(self.model.node):
self.node_map[layer.name] = TensorflowGraphNode(layer)
for pred in layer.input:
if pred not in self.node_map:
raise Exception('input: {} not in node_map'.format(pred))
self._make_connection(pred, layer.name)
super(TensorflowGraph, self).build()
self._check_dataformat()
# check the dataformat of network
def _check_dataformat(self):
ss = list()
for i in range(0, len(self.topological_sort)):
current_node = self.node_map[self.topological_sort[i]]
if current_node.type == 'conv2d':
s = current_node.layer.attr['data_format'].s
if s != 'NHWC' and s != 'NCHW':
                    raise Exception('Unknown dataformat {}'.format(s))
ss.append(s)
if len(set(ss)) > 1:
raise Exception("Two type of dataformat exist in this model")
if len(set(ss)) == 0:
return
for i in range(0, len(self.topological_sort)):
current_node = self.node_map[self.topological_sort[i]]
current_node.dataformat = ss[0]
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tensorflow
from tensorflow_graph import TensorflowGraph
class TensorflowParser(object):
def __init__(self,
meta_file,
checkpoint_file,
dest_nodes,
input_shape=None,
in_nodes=None):
graph_def = None
self.weights = dict()
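        # Restore the checkpoint into the imported meta graph, snapshot every
        # global variable into self.weights, and extract a GraphDef annotated
        # with output shapes.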
with tensorflow.Session() as sess:
if meta_file is None:
raise Exception("meta_file must be provided")
new_saver = tensorflow.train.import_meta_graph(meta_file)
if checkpoint_file is not None:
new_saver.restore(
sess, tensorflow.train.latest_checkpoint(checkpoint_file))
for var in tensorflow.global_variables():
value = var.eval(sess)
self.weights[var.name.split(':')[0]] = value
graph_def, ver = tensorflow.get_default_graph()._as_graph_def(
add_shapes=True)
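            # strip_unused prunes the GraphDef to the sub-graph between in_nodes
            # and dest_nodes, then the input placeholder's shape attributes are
            # overwritten with the user-supplied input_shape (batch left as None).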
if in_nodes is not None and input_shape is not None:
from tensorflow.python.tools import strip_unused_lib
from tensorflow.python.framework import dtypes
graph_def = strip_unused_lib.strip_unused(
input_graph_def=graph_def,
input_node_names=in_nodes,
output_node_names=dest_nodes,
placeholder_type_enum=dtypes.float32.as_datatype_enum)
input_list = [None]
for i in range(len(input_shape)):
input_list.append(tensorflow.Dimension(input_shape[i]))
tensor_input = tensorflow.TensorShape(input_list)
self.tf_graph = TensorflowGraph(graph_def)
for node in self.tf_graph.model.node:
if node.name in in_nodes:
node.attr['shape'].list.shape.extend(
[tensor_input.as_proto()])
node.attr['_output_shapes'].list.shape.pop()
node.attr['_output_shapes'].list.shape.extend(
[tensor_input.as_proto()])
else:
                raise Exception('in_nodes and input_shape must be provided')
self.tf_graph.build()
from paddle_emitter import PaddleEmitter
from tensorflow_parser import TensorflowParser
class Transformer(object):
def __init__(self, meta_file, ckpt_file, out_nodes, in_shape, in_nodes):
self.parser = TensorflowParser(meta_file, ckpt_file, out_nodes,
in_shape, in_nodes)
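        # The parser restores the checkpoint and builds the graph; the emitter
        # turns that graph into fluid model code and weight files.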
self.emitter = PaddleEmitter(self.parser.tf_graph)
def transform_code(self, out_py_file):
filew = open(out_py_file, 'w')
codes = self.emitter.gen_code()
filew.write(codes)
filew.close()
def transform_weight(self, out_dir):
self.emitter.gen_weight(self.parser.weights, out_dir)
def run(self, dst_dir):
import os
if os.path.isdir(dst_dir) or os.path.isfile(dst_dir):
print("{} already exists, set a new directory")
return
if not os.path.isdir(dst_dir):
os.mkdir(dst_dir)
self.transform_code(dst_dir + "/mymodel.py")
if (len(self.parser.weights) == 0):
print("There is no tensorflow model weight translate to paddle")
else:
self.transform_weight(dst_dir)