提交 b483d12e 编写于 作者: M Macrobull

add onnx2paddle

上级 7b3054cb
# Virtualenv
/.venv/
/venv/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
# C extensions
*.so
# Distribution / packaging
/bin/
/build/
/develop-eggs/
/dist/
/eggs/
/lib/
/lib64/
/output/
/parts/
/sdist/
/var/
/*.egg-info/
/.installed.cfg
/*.egg
/.eggs
# AUTHORS and ChangeLog will be generated while packaging
/AUTHORS
/ChangeLog
# BCloud / BuildSubmitter
/build_submitter.*
/logger_client_log
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
.tox/
.coverage
.cache
.pytest_cache
nosetests.xml
coverage.xml
# Translations
*.mo
# Sphinx documentation
/docs/_build/
/examples/*/
/examples/*.gz
/examples/*.aria2
/examples/*.onnx
/examples/*.np?
Onnx2paddle
===
Inference model conversion from ONNX/PyTorch to Paddle
快速开始
---
如何构建、安装、运行
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 11:50:03 2019

@author: Macrobull

Re-pack a flat sample-data .npz (plain 'inputs'/'outputs' arrays) in place so
that each array is wrapped in a name -> array OrderedDict, which is the layout
the onnx2paddle validation step consumes.
"""

import sys
import numpy as np
from collections import OrderedDict as Dict


def main(fn, input_names, output_names):
    """Rewrite the .npz at `fn` with named input/output dicts (overwrites it).

    Note: only the first name of each list is paired with the single stored
    array -- extra names are silently dropped by zip().
    """
    data = np.load(fn)
    input_data = data['inputs']
    output_data = data['outputs']
    inputs = Dict(zip(input_names, [input_data]))
    outputs = Dict(zip(output_names, [output_data]))
    np.savez(fn, inputs=inputs, outputs=outputs)  # overwrite


if __name__ == '__main__':
    # CLI: convert_data_npz_0.py data.npz input_name[:name...] output_name[:name...]
    # guarded so importing this module has no side effects
    main(sys.argv[1], sys.argv[2].split(':'), sys.argv[3].split(':'))
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 11:50:03 2019

@author: Macrobull

Convert an ONNX model-zoo sample-data directory (input_*.pb / output_*.pb
TensorProto files) into a single .npz holding name -> array OrderedDicts,
the layout consumed by the onnx2paddle validation step.
"""

# import os, sys
import os
import sys
import numpy as np
import onnx
import onnx.numpy_helper as numpy_helper
from collections import OrderedDict as Dict
from glob import glob


def _load_tensors(pattern):
    """Load TensorProto files matching `pattern` as numpy arrays.

    BUG FIX: glob() returns files in arbitrary order; sort so input_0.pb,
    input_1.pb, ... line up deterministically with the given name lists.
    """
    arrays = []
    for fn in sorted(glob(pattern)):
        tensor = onnx.TensorProto()
        with open(fn, 'rb') as f:
            tensor.ParseFromString(f.read())
        arrays.append(numpy_helper.to_array(tensor))
    return arrays


def main(data_dir, input_names, output_names):
    """Pack the .pb tensors under `data_dir` into '<data_dir>.npz'."""
    inputs = _load_tensors(os.path.join(data_dir, 'input_*.pb'))
    outputs = _load_tensors(os.path.join(data_dir, 'output_*.pb'))
    # np.savez appends '.npz' to the directory path
    np.savez(data_dir,
             inputs=Dict(zip(input_names, inputs)),
             outputs=Dict(zip(output_names, outputs)))


if __name__ == '__main__':
    # CLI: convert_data_pb_0.py <dir>/x input_name[:name...] output_name[:name...]
    # guarded so importing this module has no side effects
    main(os.path.dirname(sys.argv[1]),
         sys.argv[2].split(':'), sys.argv[3].split(':'))
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 22 11:19:45 2019
@author: Macrobull
Not all ops in this file are supported by both Pytorch and ONNX
This only demonstrates the conversion/validation workflow from Pytorch to ONNX to Paddle
"""
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from onnx2paddle.torch_export_helper import export_onnx_with_validation
idx = 0  # sequential example index; each export below writes 't<idx>'
######### example: RNN ########
#
#class Model(nn.Module):
# def __init__(self):
# super(Model, self).__init__()
# self.rnn = nn.RNN(4, 6, 2)
#
# def forward(self, x):
# y = x
# y, h = self.rnn(y)
# return y
#
#
#model = Model()
#xb = torch.rand((2, 3, 4))
#yp = model(xb)
#idx += 1
#print('index: ', idx)
#export_onnx_with_validation(model, (xb, ), 't' + str(idx),
# ['x'], ['y'],
# verbose=True, training=False)
######### example: random ########
#
#class Model(nn.Module):
# def __init__(self):
# super(Model, self).__init__()
#
# def forward(self, x):
# y = torch.rand((2, 3)) # + torch.rand_like(xb)
# y = y + torch.randn((2, 3)) # + torch.randn_like(xb)
# return y
#
#
#model = Model()
#xb = torch.rand((2, 3))
#yp = model(xb)
#idx += 1
#print('index: ', idx)
#export_onnx_with_validation(model, (xb, ), 't' + str(idx),
# ['x'], ['y'],
# verbose=True, training=False)
######## example: fc ########


class Model(nn.Module):
    """A single fully-connected (Linear 3->8) layer."""

    def __init__(self):
        super(Model, self).__init__()
        self.fc = nn.Linear(3, 8)

    def forward(self, x):
        return self.fc(x)


model = Model()
xb = torch.rand((2, 3))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(
        model, (xb, ), 't' + str(idx),
        ['x'], ['y'],
        verbose=True, training=False)
######## example: compare ########


class Model(nn.Module):
    """Element-wise comparison ops on a clamped pair of tensors."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x0, x1):
        clamped = x0.clamp(-1, 1)
        eq_max = torch.max(clamped, x1) == x1
        lt = clamped < x1
        gt = clamped > x1
        return eq_max, lt, gt


model = Model()
xb0 = torch.rand((2, 3))
xb1 = torch.rand((2, 3))
ya, yb, yc = model(xb0, xb1)
idx += 1
print('index: ', idx)
export_onnx_with_validation(
        model, (xb0, xb1), 't' + str(idx),
        ['x0', 'x1'], ['ya', 'yb', 'yc'],
        verbose=True, training=False)
######## example: affine_grid ########


class Model(nn.Module):
    """Generate an affine sampling grid from batched 2x3 theta matrices."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, theta):
        return F.affine_grid(theta, (2, 2, 8, 8))


model = Model()
theta = torch.rand((2, 2, 3))
grid = model(theta)
idx += 1
print('index: ', idx)
export_onnx_with_validation(
        model, (theta, ), 't' + str(idx),
        ['theta'], ['grid'],
        verbose=True, training=False)
######## example: conv2d_transpose ########


class Model(nn.Module):
    """Transposed 2D convolution followed by spatial dropout."""

    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.ConvTranspose2d(3, 8, 3)
        self.dropout = nn.Dropout2d()

    def forward(self, x):
        return self.dropout(self.conv(x))


model = Model()
xb = torch.rand((2, 3, 4, 5))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(
        model, (xb, ), 't' + str(idx),
        ['x'], ['y'],
        verbose=True, training=False)
######## example: conv2d ########


class Model(nn.Module):
    """Conv2d -> BatchNorm2d -> AdaptiveAvgPool2d(2) pipeline."""

    def __init__(self):
        super(Model, self).__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.batch_norm = nn.BatchNorm2d(8)
        self.pool = nn.AdaptiveAvgPool2d(2)

    def forward(self, x):
        return self.pool(self.batch_norm(self.conv(x)))


model = Model()
xb = torch.rand((2, 3, 4, 5))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(
        model, (xb, ), 't' + str(idx),
        ['x'], ['y'],
        verbose=True, training=False)
######### example: conv1d ########
#
#class Model(nn.Module):
# def __init__(self):
# super(Model, self).__init__()
# self.batch_norm = nn.BatchNorm2d(3)
#
# def forward(self, x):
# y = x
# y = self.batch_norm(y)
# return y
#
#
#model = Model()
#xb = torch.rand((2, 3, 4, 5))
#yp = model(xb)
#idx += 1
#print('index: ', idx)
#export_onnx_with_validation(model, (xb, ), 't' + str(idx),
# ['x'], ['y'],
# verbose=True, training=False)
######## example: empty ########


class Model(nn.Module):
    """Identity model: forward returns its input unchanged."""

    def __init__(self):
        super(Model, self).__init__()

    def forward(self, x):
        return x


model = Model()
xb = torch.rand((2, 3))
yp = model(xb)
idx += 1
print('index: ', idx)
# NOTE(review): the input name list is ['y'] while every other example uses
# ['x']; possibly intentional for an identity graph (input == output) -- confirm
export_onnx_with_validation(model, (xb, ), 't' + str(idx),
                            ['y'], ['y'],
                            verbose=True, training=False)
#! /usr/bin/env sh

# download command: aria2 (resumable, 8 connections) routed through proxychains4
get_url="proxychains4 aria2c -c -s8 -x8"
# ONNX model zoo base URL (opset 9 builds)
base_url="https://s3.amazonaws.com/download.onnx/models/opset_9/"
# onnx2paddle options: -d (debug) -e (embed params), output under /tmp/export/
flags="-de -o /tmp/export/"
# bvlc_alexnet: download the AlexNet ONNX zoo tarball, convert both sample
# data layouts (.npz files and test_data_set_*/ protobuf dirs) and validate.
bvlc_alexnet()
{
    bn_tar="bvlc_alexnet"
    fn_tar="$bn_tar.tar.gz"
    fn_model="$bn_tar/model.onnx"

    $get_url "$base_url$fn_tar"
    echo "extracting ..."
    tar xf "$fn_tar"

    for npz in $bn_tar/*.npz
    do
        echo "converting $npz ..."
        python convert_data_npz_0.py "$npz" "data_0" "prob_1"
        python -m onnx2paddle $flags "$fn_model" -t $npz
    done
    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir ..."
        python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
        # BUG FIX: a stray "echo" word made the literal string "echo" the -t
        # argument; pass the .npz path directly, as the sibling functions do
        python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}
# bvlc_googlenet: download the GoogLeNet ONNX zoo tarball, convert each
# test_data_set_*/ protobuf dir to .npz and validate the conversion.
bvlc_googlenet()
{
    bn_tar="bvlc_googlenet"
    fn_tar="$bn_tar.tar.gz"
    fn_model="$bn_tar/model.onnx"

    $get_url "$base_url$fn_tar"
    echo "extracting ..."
    tar xf "$fn_tar"

    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
        # "$pb_dir/x" trick: dirname strips the trailing slash to name the .npz
        python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}
# bvlc_reference_caffenet: download + convert + validate (data_0 -> prob_1)
bvlc_reference_caffenet()
{
    bn_tar="bvlc_reference_caffenet"
    fn_tar="$bn_tar.tar.gz"
    fn_model="$bn_tar/model.onnx"

    $get_url "$base_url$fn_tar"
    echo "extracting ..."
    tar xf "$fn_tar"

    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
        python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}
# bvlc_reference_rcnn_ilsvrc13: download + convert + validate
# (output tensor here is softmaxout_1, not prob_1)
bvlc_reference_rcnn_ilsvrc13()
{
    bn_tar="bvlc_reference_rcnn_ilsvrc13"
    fn_tar="$bn_tar.tar.gz"
    fn_model="$bn_tar/model.onnx"

    $get_url "$base_url$fn_tar"
    echo "extracting ..."
    tar xf "$fn_tar"

    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "softmaxout_1"
        python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}
# inception_v1: download + convert both sample-data layouts + validate
inception_v1()
{
    bn_tar="inception_v1"
    fn_tar="$bn_tar.tar.gz"
    fn_model="$bn_tar/model.onnx"

    $get_url "$base_url$fn_tar"
    echo "extracting ..."
    tar xf "$fn_tar"

    for npz in $bn_tar/*.npz
    do
        echo "converting $npz ..."
        python convert_data_npz_0.py "$npz" "data_0" "prob_1"
        python -m onnx2paddle $flags "$fn_model" -t $npz
    done
    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
        python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}
# inception_v2: download + convert both sample-data layouts + validate
inception_v2()
{
    bn_tar="inception_v2"
    fn_tar="$bn_tar.tar.gz"
    fn_model="$bn_tar/model.onnx"

    $get_url "$base_url$fn_tar"
    echo "extracting ..."
    tar xf "$fn_tar"

    for npz in $bn_tar/*.npz
    do
        echo "converting $npz ..."
        python convert_data_npz_0.py "$npz" "data_0" "prob_1"
        python -m onnx2paddle $flags "$fn_model" -t $npz
    done
    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
        python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}
# resnet50: download + convert + validate
# (tensor names carry a "gpu_0/" prefix in this model)
resnet50()
{
    bn_tar="resnet50"
    fn_tar="$bn_tar.tar.gz"
    fn_model="$bn_tar/model.onnx"

    $get_url "$base_url$fn_tar"
    echo "extracting ..."
    tar xf "$fn_tar"

    for npz in $bn_tar/*.npz
    do
        echo "converting $npz ..."
        python convert_data_npz_0.py "$npz" "gpu_0/data_0" "gpu_0/softmaxout_1"
        python -m onnx2paddle $flags "$fn_model" -t $npz
    done
    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "gpu_0/data_0" "gpu_0/softmaxout_1"
        python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}
# shufflenet: download + convert + validate ("gpu_0/"-prefixed tensor names)
shufflenet()
{
    bn_tar="shufflenet"
    fn_tar="$bn_tar.tar.gz"
    fn_model="$bn_tar/model.onnx"

    $get_url "$base_url$fn_tar"
    echo "extracting ..."
    tar xf "$fn_tar"

    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "gpu_0/data_0" "gpu_0/softmaxout_1"
        python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}
# squeezenet: download + convert + validate (data_0 -> softmaxout_1)
squeezenet()
{
    bn_tar="squeezenet"
    fn_tar="$bn_tar.tar.gz"
    fn_model="$bn_tar/model.onnx"

    $get_url "$base_url$fn_tar"
    echo "extracting ..."
    tar xf "$fn_tar"

    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "softmaxout_1"
        python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}
# tiny_yolov2: hosted separately (opset 8, Azure blob); converted with -x
# (no-pedantic) since the model uses non-standard ONNX ops
tiny_yolov2()
{
    bn_tar="tiny_yolov2"
    fn_tar="$bn_tar.tar.gz"
    fn_model="$bn_tar/model.onnx"

    $get_url "https://onnxzoo.blob.core.windows.net/models/opset_8/tiny_yolov2/$fn_tar"
    echo "extracting ..."
    tar xf "$fn_tar"

    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "image" "grid"
        python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz -x
    done
}
# vgg19: download + convert + validate (data_0 -> prob_1)
vgg19()
{
    bn_tar="vgg19"
    fn_tar="$bn_tar.tar.gz"
    fn_model="$bn_tar/model.onnx"

    $get_url "$base_url$fn_tar"
    echo "extracting ..."
    tar xf "$fn_tar"

    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
        python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}
# zfnet512: download + convert + validate
# ("gpu_0/"-prefixed names; output here is softmax_1, not softmaxout_1)
zfnet512()
{
    bn_tar="zfnet512"
    fn_tar="$bn_tar.tar.gz"
    fn_model="$bn_tar/model.onnx"

    $get_url "$base_url$fn_tar"
    echo "extracting ..."
    tar xf "$fn_tar"

    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "gpu_0/data_0" "gpu_0/softmax_1"
        python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}
# run every model validation; the trailing comments record known outcomes
bvlc_alexnet # data error
bvlc_googlenet # desc error
bvlc_reference_caffenet
bvlc_reference_rcnn_ilsvrc13
inception_v1 ###
inception_v2 ###
resnet50 # data error
shufflenet ###
squeezenet
tiny_yolov2 # not supported
vgg19
zfnet512 # data error
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
本文件允许模块包以python -m onnx2paddle方式直接执行。
Authors: Macrobull
Date: 2019/02/22 10:25:46
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
# import argparse, logging, sys
import argparse
import logging
import sys
# build the command-line interface for `python -m onnx2paddle`
parser = argparse.ArgumentParser(description='onnx2paddle',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                 )
parser.add_argument('model', nargs=1,
                    help='path to model.onnx',
                    )
parser.add_argument('--debug', '-d', action='store_true',
                    help='enable debug logging and checking',
                    )
parser.add_argument('--output-dir', '-o', type=str, default='',
                    help='output directory',
                    )
parser.add_argument('--test_data', '-t', type=str, default='',
                    help='I/O golden data for validation, e.g. test.npy, test.npz',
                    )
parser.add_argument('--embed_params', '-e', action='store_true',
                    help='try to embed parameters for trainable Paddle layers',
                    )
# --pedantic defaults to True; it can only be turned off via --no-pedantic/-x,
# which shares dest='pedantic' below
parser.add_argument('--pedantic', action='store_true', default=True,
                    help='accept and convert only standard ONNX opset',
                    )
parser.add_argument('--no-pedantic', '-x', action='store_false',
                    dest='pedantic',
                    help='process non-standard ONNX ops, this may lead to fails',
                    )
parser.add_argument('--precision', '-p', type=int, default=4,
                    help='assertion decimal for validation',
                    )
args = parser.parse_args()

# configure logging before importing the conversion machinery
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(format=logging_format, level=logging_level)

# package-relative import when run as `python -m onnx2paddle`,
# plain import when run as a loose script
try:
    from . import cmdline
except ImportError:
    import cmdline
# imports
main = cmdline.main

# forward all parsed options as keyword args; process exits with main()'s return
sys.exit(main(**args.__dict__))
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
本文件提供了命令行工具的入口逻辑。
Authors: Macrobull
Date: 2019/02/22 10:25:46
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
# import logging, shutil, zipfile
import logging
import shutil
import zipfile
__all__ = [
    'main',
]

# target ONNX opset for the version converter
DEFAULT_ONNX_OPSET_VERSION = 9
# module basename and factory-function name of the generated Python model
DEFAULT_MODEL_MODULE = 'model'
DEFAULT_MODEL_FUNC = 'inference'
def main(**kwargs):
    """Command-line entry: convert an ONNX model, optionally validate and zip.

    Expected kwargs mirror the argparse options in __main__: model (1-element
    list of path), output_dir, embed_params, pedantic, test_data, precision,
    debug. Returns None; on validation failure it returns early without
    creating the zip archive.
    """
    try:
        from . import conversion
    except ImportError:
        import conversion
    # imports
    convert = conversion.convert

    logger = logging.getLogger('onnx2paddle')
    debug = kwargs.get('debug', False)

    # prepare arguments
    filename = kwargs['model'][0]
    basepath, _ = shutil.os.path.splitext(filename)
    save_dir = kwargs.get('output_dir', '')
    # model.onnx -> model/
    save_dir = shutil.os.path.dirname(save_dir) if save_dir else basepath
    model_basename = DEFAULT_MODEL_MODULE + '.py'
    model_func_name = DEFAULT_MODEL_FUNC
    embed_params = kwargs.get('embed_params', False)
    onnx_opset_version = DEFAULT_ONNX_OPSET_VERSION
    onnx_opset_pedantic = kwargs.get('pedantic', True)

    # convert
    convert(filename, save_dir,
            model_basename=model_basename,
            model_func_name=model_func_name,
            embed_params=embed_params,
            onnx_opset_version=onnx_opset_version,
            onnx_opset_pedantic=onnx_opset_pedantic,
            debug=debug)

    # validate
    passed = True
    golden_data_filename = kwargs.get('test_data', '')
    if golden_data_filename:
        try:
            from . import validation
        except ImportError:
            import validation
        # imports
        validate = validation.validate

        # in fact fluid can not fully clear the context
        # continuous validation may be inaccurate
        precision = 10 ** -kwargs.get('precision', 4)

        logger.info('starting validation on desc ...')
        passed &= validate(shutil.os.path.join(save_dir, '__model__'),
                           golden_data_filename,
                           precision=precision,
                           )
        logger.info('starting validation on code ...')
        passed &= validate(shutil.os.path.join(save_dir, model_basename),
                           golden_data_filename,
                           model_func_name=model_func_name,
                           precision=precision,
                           save_inference_model=debug,  # this overwrite desc file for test
                           )
    if not passed:
        logger.error('validation failed, exit')
        return

    # create zip file
    fn_zip = save_dir.rstrip('/') + '.zip'
    logger.info('compressing file to %s ...', fn_zip)
    # BUG FIX: use a context manager so the archive is closed (and its central
    # directory flushed) even if a write() raises midway
    with zipfile.ZipFile(fn_zip, 'w', compression=zipfile.ZIP_LZMA) as fz:
        for fn in shutil.os.listdir(save_dir):
            fz.write(shutil.os.path.join(save_dir, fn), arcname=fn)
    logger.info('compressing done')
if __name__ == '__main__':
    # standalone debugging entry with verbose logging
    logging.basicConfig(
        format='[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
        level=logging.DEBUG,
    )

#    main(model=['../examples/t5.onnx'],
#         output_dir='/tmp/export/',
#         embed_params=False,
#         pedantic=False,
#         test_data='../examples/t5.npz',
#         debug=True)

    main(model=['../examples/shufflenet/model.onnx'],
         output_dir='/tmp/export/',
         embed_params=True,
         pedantic=False,
         test_data='../examples/shufflenet/test_data_set_0.npz',
         debug=True)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Feb 25 09:50:35 2019
@author: Macrobull
"""
from __future__ import division
# import logging, shutil
import logging
import shutil
__all__ = [
'convert',
]
def convert(onnx_model_filename, save_dir,
            model_basename='model.py', model_func_name='inference',
            embed_params=False,
            onnx_opset_version=9, onnx_opset_pedantic=True,
            debug=False):
    """
    convert an ONNX model to Paddle Python code and desc pb

    Args:
        onnx_model_filename: path of the ONNX model to load
        save_dir: output directory; it is cleared and re-created here
        model_basename: filename of the generated Python code module
        model_func_name: name of the generated model factory function
        embed_params: embed weights into trainable Paddle layers when True
        onnx_opset_version: opset passed to the ONNX version converter
        onnx_opset_pedantic: when True, enforce checking/opset conversion and
            re-raise validation errors; when False, suppress them with warnings
        debug: when True, also save the optimized + shape-inferred model next
            to the input model
    """
    import onnx

    from onnx.checker import ValidationError
    from onnx.checker import check_model
    from onnx.utils import polish_model
    from onnx.version_converter import convert_version

    try:
        from . import onnx_utils, writer
    except ImportError:
        import onnx_utils, writer
    # imports
    DEFAULT_OP_DOMAIN = onnx_utils.DEFAULT_OP_DOMAIN
    graph_ops, graph_weights = onnx_utils.graph_ops, onnx_utils.graph_weights
    inferred_model_value_info = onnx_utils.inferred_model_value_info
    optimize_model_skip_op_for_inference = onnx_utils.optimize_model_skip_op_for_inference
    optimize_model_strip_initializer = onnx_utils.optimize_model_strip_initializer
    optimize_model_cast = onnx_utils.optimize_model_cast
    optimize_model_slice = onnx_utils.optimize_model_slice
    Program, Writer = writer.Program, writer.Writer
    make_var_name = writer.make_var_name

    logger = logging.getLogger('convert')

    # prepare onnx model
    logger.info('loading model: %s ...', onnx_model_filename)
    onnx_model = onnx.load(onnx_model_filename)
    try:
        logger.info('checking model ...')
        check_model(onnx_model)
        logger.debug('using opset version: %d', onnx_opset_version)
        if onnx_opset_pedantic:  # WORKAROUND: RuntimeError: No Adapter For OP
            onnx_model = convert_version(onnx_model, onnx_opset_version)
        else:  # TODO: add new argument for this option
            logger.warning('opset conversion skipped for onnx_opset_pedantic is OFF')
        onnx_model = polish_model(onnx_model)
    except ValidationError as e:
        if onnx_opset_pedantic:
            raise e
        else:
            logger.warning('due to onnx_opset_pedantic is OFF')
            logger.warning('the ONNX model sanity checking error is suppressed')
            logger.warning('value_info inferring may be uncompleted')
    # onnx model optimization
    logger.info('optimizing model ...')
    onnx_model = optimize_model_skip_op_for_inference(onnx_model)
    onnx_model = optimize_model_strip_initializer(onnx_model)
    onnx_model = optimize_model_cast(onnx_model)
    onnx_model = optimize_model_slice(onnx_model)

    # prepare filesystem
    shutil.rmtree(save_dir, ignore_errors=True)
    shutil.os.makedirs(save_dir, exist_ok=True)
    logger.info('folder %s cleared', save_dir)

    # DEBUG:
    if debug:
        model = onnx.shape_inference.infer_shapes(onnx_model)
        debug_model_filename, _ = shutil.os.path.splitext(onnx_model_filename)
        onnx.save(model, debug_model_filename + '.optimized_and_inffered.onnx')
#        onnx.save(model, '/tmp/export/optimized_and_inffered.onnx')

    # I/O instances
    onnx_graph = onnx_model.graph
    paddle_program = Program()
    paddle_writer = Writer()

    # model components
#    graph_name = onnx_graph.name
    graph_inputs = [value.name for value in onnx_graph.input]
    graph_outputs = [value.name for value in onnx_graph.output]
    graph_params = []
    graph_value_infos = inferred_model_value_info(onnx_model)

    # prepare additional value_info
    # for weights
    for name, weight in graph_weights(onnx_graph):
        value_info = graph_value_infos[name]
        value_info['embeded_as'] = []
        # BUG FIX: bind the loop variable as a default argument; a plain
        # `lambda: weight.tolist()` late-binds `weight`, so every lazy getter
        # would return the *last* weight of the loop
        value_info['get_weight'] = lambda weight=weight: weight.tolist()  # lazy getter

    logger.info('conversion started')
    # op set conversion
#    topo = 'backward' if embed_params else 'forward'
    topo = 'forward'
    for name, domain, op_type, inputs, outputs, attrs in graph_ops(onnx_graph, topo=topo):
        logger.debug('translating op %s %s::%s ...', name, domain, op_type)
        if domain == DEFAULT_OP_DOMAIN:
            domain = ''
        try:
            paddle_writer.emit_op(paddle_program, name, domain, op_type,
                                  inputs, outputs, attrs,
                                  graph_value_infos,
                                  embed_params=embed_params,
                                  )
        except BaseException as e:
            logger.fatal('conversion failed for:\n\t%s -> %s::%s -> %s',
                         inputs, domain, op_type, outputs)
            raise e
    op_codes = paddle_program.codes
    paddle_program.codes = []
    logger.info('%d ops converted', len(paddle_program.op_descs))

    # weight writer
    for name, weight in graph_weights(onnx_graph):
        graph_params.append(name)
        value_info = graph_value_infos[name]
        var_names = value_info.get('embeded_as', [])
        if var_names:
            if len(var_names) > 1:
                logger.info('weight %s is shared between ops, more disk space will be consumed', name)
            logger.debug('saving weight %s with size of %d, in %d bytes, as %s ...',
                         name, weight.size, weight.nbytes, var_names)
            for var_name in var_names:  # multiple references
                paddle_writer.write_weight(weight, shutil.os.path.join(save_dir, var_name))
        else:
            logger.debug('saving weight %s with size of %d, in %d bytes, to %s ...',
                         name, weight.size, weight.nbytes, make_var_name(name))
            paddle_writer.write_weight(weight, shutil.os.path.join(save_dir, make_var_name(name)))
        paddle_writer.emit_param(paddle_program, name, value_info)
    param_codes = paddle_program.codes
    paddle_program.codes = []
    logger.info('%d weights converted', len(graph_params))

    # input writer
    external_inputs = []
    for name in graph_inputs:
        if name not in graph_params:
            value_info = graph_value_infos[name]
            assert value_info['external']
            external_inputs.append(name)
    paddle_writer.emit_inputs(paddle_program, external_inputs, graph_value_infos, remove_batch=False)  # TODO:
    input_codes = paddle_program.codes
    paddle_program.codes = []
    logger.info('%d inputs converted', len(external_inputs))

    # output writer
    external_outputs = []
    for name in graph_outputs:
        if name not in graph_params:
            value_info = graph_value_infos[name]
            assert value_info['external']
            external_outputs.append(name)
    paddle_writer.emit_outputs(paddle_program, external_outputs)
    output_codes = [''] + paddle_program.codes  # add an empty line
    paddle_program.codes = []
    logger.info('%d outputs converted', len(external_outputs))

    # code generation
    code_filename = shutil.os.path.join(save_dir, model_basename)
    paddle_writer.write_code_file(code_filename, paddle_writer.header_code(model_func_name),
                                  input_codes, param_codes, op_codes, output_codes)
    logger.info('code saved to %s, factory function: %s', code_filename, model_func_name)

    # desc generation
    desc_filename = shutil.os.path.join(save_dir, '__model__')
    paddle_writer.write_desc_file(desc_filename,
                                  op_descs=paddle_program.op_descs,
                                  var_descs=paddle_program.var_descs,
                                  )
    logger.info('program saved to %s', desc_filename)

    logger.info('conversion finished')
#    globals().update(locals())
if __name__ == '__main__':
    # standalone debugging entry: convert the bundled example models
    logging.basicConfig(
        format='[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
        level=logging.DEBUG,
    )

    model_list = [
        '../examples/t1.onnx',
        '../examples/t2.onnx',
        '../examples/t3.onnx',
        '../examples/t4.onnx',
        '../examples/t5.onnx',
        '../examples/t6.onnx',
#        '../examples/t7.onnx',
#        '../examples/t8.onnx',
    ]

    for model in model_list:
        pathname, _ = shutil.os.path.splitext(model)
        # convert each model twice: with external weight files, and with
        # parameters embedded into the generated layers
        convert(model, pathname,
                onnx_opset_pedantic=False, debug=True)
        convert(model, pathname + '.embeded',
                embed_params=True, onnx_opset_pedantic=False, debug=True)
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: framework.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='framework.proto',
package='paddle.framework.proto',
syntax='proto2',
serialized_pb=_b('\n\x0f\x66ramework.proto\x12\x16paddle.framework.proto\"\x1d\n\x07Version\x12\x12\n\x07version\x18\x01 \x01(\x03:\x01\x30\"\xec\x03\n\x06OpDesc\x12\x0c\n\x04type\x18\x03 \x02(\t\x12\x32\n\x06inputs\x18\x01 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x33\n\x07outputs\x18\x02 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x32\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32#.paddle.framework.proto.OpDesc.Attr\x12\x18\n\tis_target\x18\x05 \x01(\x08:\x05\x66\x61lse\x1a\xef\x01\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\t\n\x01i\x18\x03 \x01(\x05\x12\t\n\x01\x66\x18\x04 \x01(\x02\x12\t\n\x01s\x18\x05 \x01(\t\x12\x0c\n\x04ints\x18\x06 \x03(\x05\x12\x0e\n\x06\x66loats\x18\x07 \x03(\x02\x12\x0f\n\x07strings\x18\x08 \x03(\t\x12\t\n\x01\x62\x18\n \x01(\x08\x12\r\n\x05\x62ools\x18\x0b \x03(\x08\x12\x11\n\tblock_idx\x18\x0c \x01(\x05\x12\t\n\x01l\x18\r \x01(\x03\x12\x12\n\nblocks_idx\x18\x0e \x03(\x05\x12\r\n\x05longs\x18\x0f \x03(\x03\x1a+\n\x03Var\x12\x11\n\tparameter\x18\x01 \x02(\t\x12\x11\n\targuments\x18\x02 \x03(\t\"\xb3\x03\n\x07OpProto\x12\x0c\n\x04type\x18\x01 \x02(\t\x12\x33\n\x06inputs\x18\x02 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x34\n\x07outputs\x18\x03 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x33\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32$.paddle.framework.proto.OpProto.Attr\x12\x0f\n\x07\x63omment\x18\x05 \x02(\t\x1ax\n\x03Var\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0f\n\x07\x63omment\x18\x02 \x02(\t\x12\x19\n\nduplicable\x18\x03 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0cintermediate\x18\x04 \x01(\x08:\x05\x66\x61lse\x12\x1a\n\x0b\x64ispensable\x18\x05 \x01(\x08:\x05\x66\x61lse\x1ao\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\x0f\n\x07\x63omment\x18\x03 \x02(\t\x12\x18\n\tgenerated\x18\x04 \x01(\x08:\x05\x66\x61lse\"\xda\x08\n\x07VarType\x12\x32\n\x04type\x18\x01 
\x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x41\n\rselected_rows\x18\x02 \x01(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x41\n\nlod_tensor\x18\x03 \x01(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x12H\n\x0ctensor_array\x18\x04 \x01(\x0b\x32\x32.paddle.framework.proto.VarType.LoDTensorArrayDesc\x12:\n\x06reader\x18\x05 \x01(\x0b\x32*.paddle.framework.proto.VarType.ReaderDesc\x12\x34\n\x05tuple\x18\x07 \x01(\x0b\x32%.paddle.framework.proto.VarType.Tuple\x1aS\n\nTensorDesc\x12\x37\n\tdata_type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x0c\n\x04\x64ims\x18\x02 \x03(\x03\x1a\x61\n\rLoDTensorDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1a\x66\n\x12LoDTensorArrayDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1aO\n\nReaderDesc\x12\x41\n\nlod_tensor\x18\x01 \x03(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x1a\x43\n\x05Tuple\x12:\n\x0c\x65lement_type\x18\x01 \x03(\x0e\x32$.paddle.framework.proto.VarType.Type\"\xa2\x02\n\x04Type\x12\x08\n\x04\x42OOL\x10\x00\x12\t\n\x05INT16\x10\x01\x12\t\n\x05INT32\x10\x02\x12\t\n\x05INT64\x10\x03\x12\x08\n\x04\x46P16\x10\x04\x12\x08\n\x04\x46P32\x10\x05\x12\x08\n\x04\x46P64\x10\x06\x12\n\n\x06SIZE_T\x10\x13\x12\t\n\x05UINT8\x10\x14\x12\x08\n\x04INT8\x10\x15\x12\x0e\n\nLOD_TENSOR\x10\x07\x12\x11\n\rSELECTED_ROWS\x10\x08\x12\x12\n\x0e\x46\x45\x45\x44_MINIBATCH\x10\t\x12\x0e\n\nFETCH_LIST\x10\n\x12\x0f\n\x0bSTEP_SCOPES\x10\x0b\x12\x12\n\x0eLOD_RANK_TABLE\x10\x0c\x12\x14\n\x10LOD_TENSOR_ARRAY\x10\r\x12\x0e\n\nPLACE_LIST\x10\x0e\x12\n\n\x06READER\x10\x0f\x12\x07\n\x03RAW\x10\x11\x12\t\n\x05TUPLE\x10\x12\"b\n\x07VarDesc\x12\x0c\n\x04name\x18\x01 \x02(\t\x12-\n\x04type\x18\x02 \x02(\x0b\x32\x1f.paddle.framework.proto.VarType\x12\x1a\n\x0bpersistable\x18\x03 
\x01(\x08:\x05\x66\x61lse\"\xa7\x01\n\tBlockDesc\x12\x0b\n\x03idx\x18\x01 \x02(\x05\x12\x12\n\nparent_idx\x18\x02 \x02(\x05\x12-\n\x04vars\x18\x03 \x03(\x0b\x32\x1f.paddle.framework.proto.VarDesc\x12+\n\x03ops\x18\x04 \x03(\x0b\x32\x1e.paddle.framework.proto.OpDesc\x12\x1d\n\x11\x66orward_block_idx\x18\x05 \x01(\x05:\x02-1\"r\n\x0bProgramDesc\x12\x31\n\x06\x62locks\x18\x01 \x03(\x0b\x32!.paddle.framework.proto.BlockDesc\x12\x30\n\x07version\x18\x02 \x01(\x0b\x32\x1f.paddle.framework.proto.Version*\x94\x01\n\x08\x41ttrType\x12\x07\n\x03INT\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x08\n\x04INTS\x10\x03\x12\n\n\x06\x46LOATS\x10\x04\x12\x0b\n\x07STRINGS\x10\x05\x12\x0b\n\x07\x42OOLEAN\x10\x06\x12\x0c\n\x08\x42OOLEANS\x10\x07\x12\t\n\x05\x42LOCK\x10\x08\x12\x08\n\x04LONG\x10\t\x12\n\n\x06\x42LOCKS\x10\n\x12\t\n\x05LONGS\x10\x0b\x42\x02H\x03')
)
# NOTE(review): everything below is protoc-generated descriptor boilerplate
# for framework.proto -- do not hand-edit; regenerate with protoc instead.
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

# Descriptor for the file-level AttrType enum.
_ATTRTYPE = _descriptor.EnumDescriptor(
  name='AttrType',
  full_name='paddle.framework.proto.AttrType',
  filename=None,
  file=DESCRIPTOR,
  values=[
    _descriptor.EnumValueDescriptor(
      name='INT', index=0, number=0,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='FLOAT', index=1, number=1,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='STRING', index=2, number=2,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='INTS', index=3, number=3,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='FLOATS', index=4, number=4,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='STRINGS', index=5, number=5,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='BOOLEAN', index=6, number=6,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='BOOLEANS', index=7, number=7,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='BLOCK', index=8, number=8,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='LONG', index=9, number=9,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='BLOCKS', index=10, number=10,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='LONGS', index=11, number=11,
      options=None,
      type=None),
  ],
  containing_type=None,
  options=None,
  serialized_start=2511,
  serialized_end=2659,
)
_sym_db.RegisterEnumDescriptor(_ATTRTYPE)

# Module-level enum wrapper and value aliases mirroring the .proto names.
AttrType = enum_type_wrapper.EnumTypeWrapper(_ATTRTYPE)
INT = 0
FLOAT = 1
STRING = 2
INTS = 3
FLOATS = 4
STRINGS = 5
BOOLEAN = 6
BOOLEANS = 7
BLOCK = 8
LONG = 9
BLOCKS = 10
LONGS = 11
# Descriptor for the nested VarType.Type enum (generated; do not hand-edit).
_VARTYPE_TYPE = _descriptor.EnumDescriptor(
  name='Type',
  full_name='paddle.framework.proto.VarType.Type',
  filename=None,
  file=DESCRIPTOR,
  values=[
    _descriptor.EnumValueDescriptor(
      name='BOOL', index=0, number=0,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='INT16', index=1, number=1,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='INT32', index=2, number=2,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='INT64', index=3, number=3,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='FP16', index=4, number=4,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='FP32', index=5, number=5,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='FP64', index=6, number=6,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='SIZE_T', index=7, number=19,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='UINT8', index=8, number=20,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='INT8', index=9, number=21,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='LOD_TENSOR', index=10, number=7,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='SELECTED_ROWS', index=11, number=8,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='FEED_MINIBATCH', index=12, number=9,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='FETCH_LIST', index=13, number=10,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='STEP_SCOPES', index=14, number=11,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='LOD_RANK_TABLE', index=15, number=12,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='LOD_TENSOR_ARRAY', index=16, number=13,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='PLACE_LIST', index=17, number=14,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='READER', index=18, number=15,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='RAW', index=19, number=17,
      options=None,
      type=None),
    _descriptor.EnumValueDescriptor(
      name='TUPLE', index=20, number=18,
      options=None,
      type=None),
  ],
  containing_type=None,
  options=None,
  serialized_start=1832,
  serialized_end=2122,
)
_sym_db.RegisterEnumDescriptor(_VARTYPE_TYPE)
# Descriptor for the Version message (generated; do not hand-edit).
_VERSION = _descriptor.Descriptor(
  name='Version',
  full_name='paddle.framework.proto.Version',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='version', full_name='paddle.framework.proto.Version.version', index=0,
      number=1, type=3, cpp_type=2, label=1,
      has_default_value=True, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=43,
  serialized_end=72,
)
# Descriptor for the nested OpDesc.Attr message (generated; do not hand-edit).
_OPDESC_ATTR = _descriptor.Descriptor(
  name='Attr',
  full_name='paddle.framework.proto.OpDesc.Attr',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='name', full_name='paddle.framework.proto.OpDesc.Attr.name', index=0,
      number=1, type=9, cpp_type=9, label=2,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='type', full_name='paddle.framework.proto.OpDesc.Attr.type', index=1,
      number=2, type=14, cpp_type=8, label=2,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='i', full_name='paddle.framework.proto.OpDesc.Attr.i', index=2,
      number=3, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='f', full_name='paddle.framework.proto.OpDesc.Attr.f', index=3,
      number=4, type=2, cpp_type=6, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='s', full_name='paddle.framework.proto.OpDesc.Attr.s', index=4,
      number=5, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='ints', full_name='paddle.framework.proto.OpDesc.Attr.ints', index=5,
      number=6, type=5, cpp_type=1, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='floats', full_name='paddle.framework.proto.OpDesc.Attr.floats', index=6,
      number=7, type=2, cpp_type=6, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='strings', full_name='paddle.framework.proto.OpDesc.Attr.strings', index=7,
      number=8, type=9, cpp_type=9, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='b', full_name='paddle.framework.proto.OpDesc.Attr.b', index=8,
      number=10, type=8, cpp_type=7, label=1,
      has_default_value=False, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='bools', full_name='paddle.framework.proto.OpDesc.Attr.bools', index=9,
      number=11, type=8, cpp_type=7, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='block_idx', full_name='paddle.framework.proto.OpDesc.Attr.block_idx', index=10,
      number=12, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='l', full_name='paddle.framework.proto.OpDesc.Attr.l', index=11,
      number=13, type=3, cpp_type=2, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='blocks_idx', full_name='paddle.framework.proto.OpDesc.Attr.blocks_idx', index=12,
      number=14, type=5, cpp_type=1, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='longs', full_name='paddle.framework.proto.OpDesc.Attr.longs', index=13,
      number=15, type=3, cpp_type=2, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=283,
  serialized_end=522,
)
# Descriptor for the nested OpDesc.Var message (generated; do not hand-edit).
_OPDESC_VAR = _descriptor.Descriptor(
  name='Var',
  full_name='paddle.framework.proto.OpDesc.Var',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='parameter', full_name='paddle.framework.proto.OpDesc.Var.parameter', index=0,
      number=1, type=9, cpp_type=9, label=2,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='arguments', full_name='paddle.framework.proto.OpDesc.Var.arguments', index=1,
      number=2, type=9, cpp_type=9, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=524,
  serialized_end=567,
)
# Descriptor for the OpDesc message (generated; do not hand-edit).
_OPDESC = _descriptor.Descriptor(
  name='OpDesc',
  full_name='paddle.framework.proto.OpDesc',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='type', full_name='paddle.framework.proto.OpDesc.type', index=0,
      number=3, type=9, cpp_type=9, label=2,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='inputs', full_name='paddle.framework.proto.OpDesc.inputs', index=1,
      number=1, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='outputs', full_name='paddle.framework.proto.OpDesc.outputs', index=2,
      number=2, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='attrs', full_name='paddle.framework.proto.OpDesc.attrs', index=3,
      number=4, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='is_target', full_name='paddle.framework.proto.OpDesc.is_target', index=4,
      number=5, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[_OPDESC_ATTR, _OPDESC_VAR, ],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=75,
  serialized_end=567,
)
# Descriptor for the nested OpProto.Var message (generated; do not hand-edit).
_OPPROTO_VAR = _descriptor.Descriptor(
  name='Var',
  full_name='paddle.framework.proto.OpProto.Var',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='name', full_name='paddle.framework.proto.OpProto.Var.name', index=0,
      number=1, type=9, cpp_type=9, label=2,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='comment', full_name='paddle.framework.proto.OpProto.Var.comment', index=1,
      number=2, type=9, cpp_type=9, label=2,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='duplicable', full_name='paddle.framework.proto.OpProto.Var.duplicable', index=2,
      number=3, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='intermediate', full_name='paddle.framework.proto.OpProto.Var.intermediate', index=3,
      number=4, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='dispensable', full_name='paddle.framework.proto.OpProto.Var.dispensable', index=4,
      number=5, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=772,
  serialized_end=892,
)
# Descriptor for the nested OpProto.Attr message (generated; do not hand-edit).
_OPPROTO_ATTR = _descriptor.Descriptor(
  name='Attr',
  full_name='paddle.framework.proto.OpProto.Attr',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='name', full_name='paddle.framework.proto.OpProto.Attr.name', index=0,
      number=1, type=9, cpp_type=9, label=2,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='type', full_name='paddle.framework.proto.OpProto.Attr.type', index=1,
      number=2, type=14, cpp_type=8, label=2,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='comment', full_name='paddle.framework.proto.OpProto.Attr.comment', index=2,
      number=3, type=9, cpp_type=9, label=2,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='generated', full_name='paddle.framework.proto.OpProto.Attr.generated', index=3,
      number=4, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=894,
  serialized_end=1005,
)
# Descriptor for the OpProto message (generated; do not hand-edit).
_OPPROTO = _descriptor.Descriptor(
  name='OpProto',
  full_name='paddle.framework.proto.OpProto',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='type', full_name='paddle.framework.proto.OpProto.type', index=0,
      number=1, type=9, cpp_type=9, label=2,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='inputs', full_name='paddle.framework.proto.OpProto.inputs', index=1,
      number=2, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='outputs', full_name='paddle.framework.proto.OpProto.outputs', index=2,
      number=3, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='attrs', full_name='paddle.framework.proto.OpProto.attrs', index=3,
      number=4, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='comment', full_name='paddle.framework.proto.OpProto.comment', index=4,
      number=5, type=9, cpp_type=9, label=2,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[_OPPROTO_VAR, _OPPROTO_ATTR, ],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=570,
  serialized_end=1005,
)
# Descriptor for the nested VarType.TensorDesc message (generated; do not hand-edit).
_VARTYPE_TENSORDESC = _descriptor.Descriptor(
  name='TensorDesc',
  full_name='paddle.framework.proto.VarType.TensorDesc',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='data_type', full_name='paddle.framework.proto.VarType.TensorDesc.data_type', index=0,
      number=1, type=14, cpp_type=8, label=2,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='dims', full_name='paddle.framework.proto.VarType.TensorDesc.dims', index=1,
      number=2, type=3, cpp_type=2, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=1393,
  serialized_end=1476,
)
# Descriptor for the nested VarType.LoDTensorDesc message (generated; do not hand-edit).
_VARTYPE_LODTENSORDESC = _descriptor.Descriptor(
  name='LoDTensorDesc',
  full_name='paddle.framework.proto.VarType.LoDTensorDesc',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='tensor', full_name='paddle.framework.proto.VarType.LoDTensorDesc.tensor', index=0,
      number=1, type=11, cpp_type=10, label=2,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='lod_level', full_name='paddle.framework.proto.VarType.LoDTensorDesc.lod_level', index=1,
      number=2, type=5, cpp_type=1, label=1,
      has_default_value=True, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=1478,
  serialized_end=1575,
)
# Descriptor for the nested VarType.LoDTensorArrayDesc message (generated; do not hand-edit).
_VARTYPE_LODTENSORARRAYDESC = _descriptor.Descriptor(
  name='LoDTensorArrayDesc',
  full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='tensor', full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc.tensor', index=0,
      number=1, type=11, cpp_type=10, label=2,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='lod_level', full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc.lod_level', index=1,
      number=2, type=5, cpp_type=1, label=1,
      has_default_value=True, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=1577,
  serialized_end=1679,
)
# Descriptor for the nested VarType.ReaderDesc message (generated; do not hand-edit).
_VARTYPE_READERDESC = _descriptor.Descriptor(
  name='ReaderDesc',
  full_name='paddle.framework.proto.VarType.ReaderDesc',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='lod_tensor', full_name='paddle.framework.proto.VarType.ReaderDesc.lod_tensor', index=0,
      number=1, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=1681,
  serialized_end=1760,
)
# Descriptor for the nested VarType.Tuple message (generated; do not hand-edit).
_VARTYPE_TUPLE = _descriptor.Descriptor(
  name='Tuple',
  full_name='paddle.framework.proto.VarType.Tuple',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='element_type', full_name='paddle.framework.proto.VarType.Tuple.element_type', index=0,
      number=1, type=14, cpp_type=8, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=1762,
  serialized_end=1829,
)
# Descriptor for the VarType message (generated; do not hand-edit).
_VARTYPE = _descriptor.Descriptor(
  name='VarType',
  full_name='paddle.framework.proto.VarType',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='type', full_name='paddle.framework.proto.VarType.type', index=0,
      number=1, type=14, cpp_type=8, label=2,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='selected_rows', full_name='paddle.framework.proto.VarType.selected_rows', index=1,
      number=2, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='lod_tensor', full_name='paddle.framework.proto.VarType.lod_tensor', index=2,
      number=3, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='tensor_array', full_name='paddle.framework.proto.VarType.tensor_array', index=3,
      number=4, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='reader', full_name='paddle.framework.proto.VarType.reader', index=4,
      number=5, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='tuple', full_name='paddle.framework.proto.VarType.tuple', index=5,
      number=7, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[_VARTYPE_TENSORDESC, _VARTYPE_LODTENSORDESC, _VARTYPE_LODTENSORARRAYDESC, _VARTYPE_READERDESC, _VARTYPE_TUPLE, ],
  enum_types=[
    _VARTYPE_TYPE,
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=1008,
  serialized_end=2122,
)
# Descriptor for the VarDesc message (generated; do not hand-edit).
_VARDESC = _descriptor.Descriptor(
  name='VarDesc',
  full_name='paddle.framework.proto.VarDesc',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='name', full_name='paddle.framework.proto.VarDesc.name', index=0,
      number=1, type=9, cpp_type=9, label=2,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='type', full_name='paddle.framework.proto.VarDesc.type', index=1,
      number=2, type=11, cpp_type=10, label=2,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='persistable', full_name='paddle.framework.proto.VarDesc.persistable', index=2,
      number=3, type=8, cpp_type=7, label=1,
      has_default_value=True, default_value=False,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=2124,
  serialized_end=2222,
)
# Descriptor for the BlockDesc message (generated; do not hand-edit).
_BLOCKDESC = _descriptor.Descriptor(
  name='BlockDesc',
  full_name='paddle.framework.proto.BlockDesc',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='idx', full_name='paddle.framework.proto.BlockDesc.idx', index=0,
      number=1, type=5, cpp_type=1, label=2,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='parent_idx', full_name='paddle.framework.proto.BlockDesc.parent_idx', index=1,
      number=2, type=5, cpp_type=1, label=2,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='vars', full_name='paddle.framework.proto.BlockDesc.vars', index=2,
      number=3, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='ops', full_name='paddle.framework.proto.BlockDesc.ops', index=3,
      number=4, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='forward_block_idx', full_name='paddle.framework.proto.BlockDesc.forward_block_idx', index=4,
      number=5, type=5, cpp_type=1, label=1,
      has_default_value=True, default_value=-1,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=2225,
  serialized_end=2392,
)
# Descriptor for message paddle.framework.proto.ProgramDesc (blocks + version).
# NOTE: machine-generated by protoc from framework.proto; do not edit by hand.
_PROGRAMDESC = _descriptor.Descriptor(
  name='ProgramDesc',
  full_name='paddle.framework.proto.ProgramDesc',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    _descriptor.FieldDescriptor(
      name='blocks', full_name='paddle.framework.proto.ProgramDesc.blocks', index=0,
      number=1, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
    _descriptor.FieldDescriptor(
      name='version', full_name='paddle.framework.proto.ProgramDesc.version', index=1,
      number=2, type=11, cpp_type=10, label=1,
      has_default_value=False, default_value=None,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      options=None),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  serialized_start=2394,
  serialized_end=2508,
)
# --- resolve cross-references among the generated descriptors ---
# (machine-generated by protoc; do not edit by hand)
_OPDESC_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE
_OPDESC_ATTR.containing_type = _OPDESC
_OPDESC_VAR.containing_type = _OPDESC
_OPDESC.fields_by_name['inputs'].message_type = _OPDESC_VAR
_OPDESC.fields_by_name['outputs'].message_type = _OPDESC_VAR
_OPDESC.fields_by_name['attrs'].message_type = _OPDESC_ATTR
_OPPROTO_VAR.containing_type = _OPPROTO
_OPPROTO_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE
_OPPROTO_ATTR.containing_type = _OPPROTO
_OPPROTO.fields_by_name['inputs'].message_type = _OPPROTO_VAR
_OPPROTO.fields_by_name['outputs'].message_type = _OPPROTO_VAR
_OPPROTO.fields_by_name['attrs'].message_type = _OPPROTO_ATTR
_VARTYPE_TENSORDESC.fields_by_name['data_type'].enum_type = _VARTYPE_TYPE
_VARTYPE_TENSORDESC.containing_type = _VARTYPE
_VARTYPE_LODTENSORDESC.fields_by_name['tensor'].message_type = _VARTYPE_TENSORDESC
_VARTYPE_LODTENSORDESC.containing_type = _VARTYPE
_VARTYPE_LODTENSORARRAYDESC.fields_by_name['tensor'].message_type = _VARTYPE_TENSORDESC
_VARTYPE_LODTENSORARRAYDESC.containing_type = _VARTYPE
_VARTYPE_READERDESC.fields_by_name['lod_tensor'].message_type = _VARTYPE_LODTENSORDESC
_VARTYPE_READERDESC.containing_type = _VARTYPE
_VARTYPE_TUPLE.fields_by_name['element_type'].enum_type = _VARTYPE_TYPE
_VARTYPE_TUPLE.containing_type = _VARTYPE
_VARTYPE.fields_by_name['type'].enum_type = _VARTYPE_TYPE
_VARTYPE.fields_by_name['selected_rows'].message_type = _VARTYPE_TENSORDESC
_VARTYPE.fields_by_name['lod_tensor'].message_type = _VARTYPE_LODTENSORDESC
_VARTYPE.fields_by_name['tensor_array'].message_type = _VARTYPE_LODTENSORARRAYDESC
_VARTYPE.fields_by_name['reader'].message_type = _VARTYPE_READERDESC
_VARTYPE.fields_by_name['tuple'].message_type = _VARTYPE_TUPLE
_VARTYPE_TYPE.containing_type = _VARTYPE
_VARDESC.fields_by_name['type'].message_type = _VARTYPE
_BLOCKDESC.fields_by_name['vars'].message_type = _VARDESC
_BLOCKDESC.fields_by_name['ops'].message_type = _OPDESC
_PROGRAMDESC.fields_by_name['blocks'].message_type = _BLOCKDESC
_PROGRAMDESC.fields_by_name['version'].message_type = _VERSION
# register top-level message and enum descriptors in the file descriptor
DESCRIPTOR.message_types_by_name['Version'] = _VERSION
DESCRIPTOR.message_types_by_name['OpDesc'] = _OPDESC
DESCRIPTOR.message_types_by_name['OpProto'] = _OPPROTO
DESCRIPTOR.message_types_by_name['VarType'] = _VARTYPE
DESCRIPTOR.message_types_by_name['VarDesc'] = _VARDESC
DESCRIPTOR.message_types_by_name['BlockDesc'] = _BLOCKDESC
DESCRIPTOR.message_types_by_name['ProgramDesc'] = _PROGRAMDESC
DESCRIPTOR.enum_types_by_name['AttrType'] = _ATTRTYPE
# --- concrete message classes built from the descriptors via reflection ---
Version = _reflection.GeneratedProtocolMessageType('Version', (_message.Message,), dict(
  DESCRIPTOR = _VERSION,
  __module__ = 'framework_pb2'
  # @@protoc_insertion_point(class_scope:paddle.framework.proto.Version)
  ))
_sym_db.RegisterMessage(Version)

OpDesc = _reflection.GeneratedProtocolMessageType('OpDesc', (_message.Message,), dict(
  Attr = _reflection.GeneratedProtocolMessageType('Attr', (_message.Message,), dict(
    DESCRIPTOR = _OPDESC_ATTR,
    __module__ = 'framework_pb2'
    # @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Attr)
    ))
  ,
  Var = _reflection.GeneratedProtocolMessageType('Var', (_message.Message,), dict(
    DESCRIPTOR = _OPDESC_VAR,
    __module__ = 'framework_pb2'
    # @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Var)
    ))
  ,
  DESCRIPTOR = _OPDESC,
  __module__ = 'framework_pb2'
  # @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc)
  ))
_sym_db.RegisterMessage(OpDesc)
_sym_db.RegisterMessage(OpDesc.Attr)
_sym_db.RegisterMessage(OpDesc.Var)

OpProto = _reflection.GeneratedProtocolMessageType('OpProto', (_message.Message,), dict(
  Var = _reflection.GeneratedProtocolMessageType('Var', (_message.Message,), dict(
    DESCRIPTOR = _OPPROTO_VAR,
    __module__ = 'framework_pb2'
    # @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Var)
    ))
  ,
  Attr = _reflection.GeneratedProtocolMessageType('Attr', (_message.Message,), dict(
    DESCRIPTOR = _OPPROTO_ATTR,
    __module__ = 'framework_pb2'
    # @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Attr)
    ))
  ,
  DESCRIPTOR = _OPPROTO,
  __module__ = 'framework_pb2'
  # @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto)
  ))
_sym_db.RegisterMessage(OpProto)
_sym_db.RegisterMessage(OpProto.Var)
_sym_db.RegisterMessage(OpProto.Attr)

VarType = _reflection.GeneratedProtocolMessageType('VarType', (_message.Message,), dict(
  TensorDesc = _reflection.GeneratedProtocolMessageType('TensorDesc', (_message.Message,), dict(
    DESCRIPTOR = _VARTYPE_TENSORDESC,
    __module__ = 'framework_pb2'
    # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.TensorDesc)
    ))
  ,
  LoDTensorDesc = _reflection.GeneratedProtocolMessageType('LoDTensorDesc', (_message.Message,), dict(
    DESCRIPTOR = _VARTYPE_LODTENSORDESC,
    __module__ = 'framework_pb2'
    # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorDesc)
    ))
  ,
  LoDTensorArrayDesc = _reflection.GeneratedProtocolMessageType('LoDTensorArrayDesc', (_message.Message,), dict(
    DESCRIPTOR = _VARTYPE_LODTENSORARRAYDESC,
    __module__ = 'framework_pb2'
    # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorArrayDesc)
    ))
  ,
  ReaderDesc = _reflection.GeneratedProtocolMessageType('ReaderDesc', (_message.Message,), dict(
    DESCRIPTOR = _VARTYPE_READERDESC,
    __module__ = 'framework_pb2'
    # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.ReaderDesc)
    ))
  ,
  Tuple = _reflection.GeneratedProtocolMessageType('Tuple', (_message.Message,), dict(
    DESCRIPTOR = _VARTYPE_TUPLE,
    __module__ = 'framework_pb2'
    # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.Tuple)
    ))
  ,
  DESCRIPTOR = _VARTYPE,
  __module__ = 'framework_pb2'
  # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType)
  ))
_sym_db.RegisterMessage(VarType)
_sym_db.RegisterMessage(VarType.TensorDesc)
_sym_db.RegisterMessage(VarType.LoDTensorDesc)
_sym_db.RegisterMessage(VarType.LoDTensorArrayDesc)
_sym_db.RegisterMessage(VarType.ReaderDesc)
_sym_db.RegisterMessage(VarType.Tuple)

VarDesc = _reflection.GeneratedProtocolMessageType('VarDesc', (_message.Message,), dict(
  DESCRIPTOR = _VARDESC,
  __module__ = 'framework_pb2'
  # @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc)
  ))
_sym_db.RegisterMessage(VarDesc)

BlockDesc = _reflection.GeneratedProtocolMessageType('BlockDesc', (_message.Message,), dict(
  DESCRIPTOR = _BLOCKDESC,
  __module__ = 'framework_pb2'
  # @@protoc_insertion_point(class_scope:paddle.framework.proto.BlockDesc)
  ))
_sym_db.RegisterMessage(BlockDesc)

ProgramDesc = _reflection.GeneratedProtocolMessageType('ProgramDesc', (_message.Message,), dict(
  DESCRIPTOR = _PROGRAMDESC,
  __module__ = 'framework_pb2'
  # @@protoc_insertion_point(class_scope:paddle.framework.proto.ProgramDesc)
  ))
_sym_db.RegisterMessage(ProgramDesc)

DESCRIPTOR.has_options = True
# serialized FileOptions blob; presumably optimize_for = LITE_RUNTIME - confirm
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('H\003'))
# @@protoc_insertion_point(module_scope)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 24 17:23:09 2019
@author: Macrobull
"""
from __future__ import division
import logging
import numpy as np
import onnx
from collections import OrderedDict as Dict # as default dict
from onnx.helper import get_attribute_value, make_attribute
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from onnx.numpy_helper import to_array
from onnx.shape_inference import infer_shapes
logger = logging.getLogger(__name__)
__all__ = [
'print_pb_structure',
'build_value_refs',
'node_attrs', 'node_topo', 'node_iter',
'tensor_shape',
'graph_ops', 'graph_weights',
'inferred_model_value_info',
'optimize_model_skip_op_for_inference',
'optimize_model_strip_initializer',
'optimize_model_cast', 'optimize_model_slice',
]
ONNX_INT_MAX = 2 ** 63 - 1
DEFAULT_OP_DOMAIN = 'ai.onnx'
def print_pb_structure(message,
                       loop_iterative=False, depth=0):
    """
    Recursively print the field structure of a protobuf message.

    Args:
        message: a protobuf message (or field value) to inspect.
        loop_iterative: also descend into repeated-field containers.
        depth: current indentation level (internal, used by recursion).
    """
    indent = '\t' * depth
    descriptor = getattr(message, 'DESCRIPTOR', None)
    if descriptor is not None and hasattr(descriptor, 'fields'):
        for field in descriptor.fields:
            print(indent + '-', field.name)
            print_pb_structure(getattr(message, field.name),
                               loop_iterative=loop_iterative, depth=(depth + 1))
    is_repeated = hasattr(message, 'MergeFrom') and hasattr(message, '__len__')
    if loop_iterative and is_repeated:
        for item_idx, item in enumerate(message):
            print(indent + '-', item_idx)
            print_pb_structure(item,
                               loop_iterative=loop_iterative, depth=(depth + 1))
def build_value_refs(nodes):
    """
    Map each value name to the indices of the nodes referencing it.

    Returns:
        (input_refs, output_refs): two ordered dicts mapping a value name
        to the set of node indices that consume (resp. produce) it.
    """
    consumers = Dict()
    producers = Dict()
    for node_idx, node in enumerate(nodes):
        for value_name in node.input:
            consumers.setdefault(value_name, set()).add(node_idx)
        for value_name in node.output:
            producers.setdefault(value_name, set()).add(node_idx)
    return consumers, producers
def get_attribute_value2(attr):
    """
    Like onnx.helper.get_attribute_value, but decodes TENSOR attributes
    into numpy arrays instead of returning the raw TensorProto.
    """
    if attr.type != onnx.AttributeProto.TENSOR:
        return get_attribute_value(attr)
    # NOTE(review): only raw_data is decoded; tensors serialized via typed
    # fields (float_data, int32_data, ...) would come out empty - confirm
    dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type])
    raw = attr.t.raw_data
    return np.frombuffer(raw, dtype=dtype, count=(len(raw) // dtype.itemsize))
def node_attrs(node):
    """
    Collect the attributes of an ONNX node as a plain dict
    (attribute name -> converted Python/numpy value).
    """
    attrs = dict()
    for attr in node.attribute:
        attrs[attr.name] = get_attribute_value2(attr)
    return attrs
def tensor_shape(tensor):
    """
    Return the static shape of an ONNX value-info entry as a list of ints.
    """
    dims = tensor.type.tensor_type.shape.dim
    return list(map(lambda dim: dim.dim_value, dims))
def node_topo(nodes, topo='default'):
    """
    Build a list of node indices enumerating *nodes* in the given order.

    Args:
        nodes: sequence of ONNX NodeProto.
        topo: 'default' (original order), 'forward' (producers before
            consumers) or 'backward' (consumers before producers).

    Returns:
        list of indices into *nodes*.

    Raises:
        ValueError: if *topo* is not one of the supported keywords.
    """
    if topo == 'default':
        return list(range(len(nodes)))
    node_topo = []
    node_in_degrees = [len(node.input) for node in nodes]
    node_out_degrees = [len(node.output) for node in nodes]
    input_refs, output_refs = build_value_refs(nodes)
    if topo == 'forward':
        # seed: values nobody produces are graph inputs - they do not
        # contribute to a node's unresolved in-degree
        for val_name in input_refs:
            if val_name not in output_refs:
                for node_idx in input_refs[val_name]:
                    node_in_degrees[node_idx] -= 1
        queue = []
        for node_idx, degree in enumerate(node_in_degrees):
            if degree == 0:
                queue.append(node_idx)
        while len(queue) > 0:
            node_idx = queue.pop(0)
            node_topo.append(node_idx)
            for val_name in nodes[node_idx].output:
                output_refs[val_name].remove(node_idx)
                if len(output_refs[val_name]) > 0:
                    continue  # value still has pending producers
                output_refs.pop(val_name)
                if val_name not in input_refs:
                    continue  # value has no consumers to release
                for next_idx in input_refs[val_name]:
                    node_in_degrees[next_idx] -= 1
                    if node_in_degrees[next_idx] == 0:
                        queue.insert(0, next_idx)  # make it lazy
        return node_topo
    if topo == 'backward':
        # mirror of the forward case, walking producer links instead
        for val_name in output_refs:
            if val_name not in input_refs:
                for node_idx in output_refs[val_name]:
                    node_out_degrees[node_idx] -= 1
        queue = []
        for node_idx, degree in enumerate(node_out_degrees):
            if degree == 0:
                queue.append(node_idx)
        while len(queue) > 0:
            node_idx = queue.pop(0)
            node_topo.append(node_idx)
            for val_name in nodes[node_idx].input:
                input_refs[val_name].remove(node_idx)
                if len(input_refs[val_name]) > 0:
                    continue
                input_refs.pop(val_name)
                if val_name not in output_refs:
                    continue
                for next_idx in output_refs[val_name]:
                    node_out_degrees[next_idx] -= 1
                    if node_out_degrees[next_idx] == 0:
                        queue.insert(0, next_idx)  # make it lazy
        return node_topo
    raise ValueError('unkown given topo: {}'.format(topo))
def node_iter(nodes,
              indices=None):
    """
    Yield (name, domain, op_type, inputs, outputs, attrs) for each selected
    node; empty names/domains get defaults ('op_<index>' / DEFAULT_OP_DOMAIN).
    """
    selected = range(len(nodes)) if indices is None else indices
    for index in selected:
        node = nodes[index]
        op_name = node.name
        if op_name == '':
            op_name = 'op_' + str(index)
        op_domain = node.domain
        if op_domain == '':
            op_domain = DEFAULT_OP_DOMAIN
        yield (op_name, op_domain, node.op_type,
               list(node.input), list(node.output), node_attrs(node))
def graph_ops(graph,
              topo='default'):
    """
    Return an iterator over the ops of an ONNX graph in the given topology.

    Args:
        graph (onnx.GraphProto): the graph to iterate.
        topo: topology keyword forwarded to node_topo.

    Returns:
        iterator of (name, domain, op_type, inputs, outputs, attrs) tuples.
        On a wrong-typed *graph* an error is logged and an EMPTY iterator is
        returned (the original returned None, which made callers that
        iterate the result crash with TypeError).
    """
    if not isinstance(graph, onnx.GraphProto):
        logger.error('graph is not a GraphProto instance')
        return iter([])
    return node_iter(graph.node, node_topo(graph.node, topo))
def graph_weights(graph):
    """
    Yield (name, ndarray) for every initializer (weight) of an ONNX graph;
    yields nothing (after logging an error) if *graph* has the wrong type.
    """
    if not isinstance(graph, onnx.GraphProto):
        logger.error('graph is not a GraphProto instance')
        return
    for initializer in graph.initializer:
        yield initializer.name, to_array(initializer)
def inferred_model_value_info(model):
    """
    Run ONNX shape inference on *model* and collect, for every value,
    a record dict(dtype=numpy dtype, shape=list of dims, external=bool),
    where external marks graph inputs/outputs.
    """
    graph = infer_shapes(model).graph
    value_info = Dict()

    def _record(item, external):
        # one record per value-info entry
        return dict(
            dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
            shape=tensor_shape(item),
            external=external,
        )

    for item in graph.value_info:
        value_info[item.name] = _record(item, False)
    for item in graph.input:
        assert item.name not in value_info
        value_info[item.name] = _record(item, True)
    for item in graph.output:
        # assert item.name not in value_info, 'bypass-model not supported'
        value_info[item.name] = _record(item, True)
    return value_info
def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs):
    """
    Rewire every consumer of *src_output_name* to read *dst_input_name*
    instead, bypassing the node(s) in between.

    Returns:
        number of input slots rewritten.
    """
    rewired = 0
    for consumer_idx in input_refs[src_output_name]:
        consumer = nodes[consumer_idx]
        for slot, slot_name in enumerate(consumer.input):
            if slot_name != src_output_name:
                continue
            consumer.input[slot] = dst_input_name
            rewired += 1
    return rewired
def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
    """
    Rewire every producer of *src_input_name* to write *dst_output_name*
    instead, bypassing the node(s) in between.

    Returns:
        number of output slots rewritten.
    """
    rewired = 0
    for producer_idx in output_refs[src_input_name]:
        producer = nodes[producer_idx]
        for slot, slot_name in enumerate(producer.output):
            if slot_name != src_input_name:
                continue
            producer.output[slot] = dst_output_name
            rewired += 1
    return rewired
def optimize_model_skip_op_for_inference(
        model,
        op_list=None):
    """
    Return a copy of *model* with inference-bypassable ops removed and
    their neighbors re-wired across the gap.

    Args:
        model (onnx.ModelProto): source model, left untouched.
        op_list: op types to bypass; defaults to ['Dropout'].

    Returns:
        onnx.ModelProto: a new model with the listed ops skipped.
    """
    if op_list is None:
        op_list = ['Dropout']
    nodes = model.graph.node
    input_refs, output_refs = build_value_refs(nodes)
    ret = type(model)()
    ret.CopyFrom(model)
    ret.graph.ClearField('value_info')  # WORKAROUND: onnx do not drop old value_info
    ret_nodes = ret.graph.node
    nodes_to_remove = []
    for node_idx, node in enumerate(nodes):
        # only default-domain ops are considered
        if not(node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
            continue
        op_type = node.op_type
        if not(op_type in op_list):
            continue
        if op_type in ['Dropout']:
            # NOTE(review): currently identical to the generic 1-in-1-out
            # branch below - presumably a hook for Dropout's optional mask
            # output; confirm
            input_name = node.input[0]
            output_name = node.output[0]
        elif not(len(node.input) == 1 and len(node.output) == 1):
            logger.warning('currently only 1-input-1-output op supported, skip required %d: %s',
                           node_idx, node.op_type)
            continue
        else:
            input_name = node.input[0]
            output_name = node.output[0]
        # prefer re-wiring consumers (forward); fall back to producers
        if output_name in input_refs:
            processed = skip_node_forward(ret_nodes, output_name, input_name, input_refs)
        elif input_name in output_refs:
            processed = skip_node_backward(ret_nodes, input_name, output_name, output_refs)
        else:
            processed = -1
        if processed > 0:
            nodes_to_remove.append(node_idx)
            logger.debug('skip op %d: %s -> %s -> %s',
                         node_idx, input_name, node.op_type, output_name)
        elif processed == 0:
            logger.warning('weird, no node processed')
        else:
            logger.warning('standalone op %d: %s -> %s -> %s not skipped',
                           node_idx, input_name, node.op_type, output_name)
    # delete from the copy, highest index first so indices stay valid
    nodes_to_remove.sort(reverse=True)
    for node_idx in nodes_to_remove:
        ret_nodes.pop(node_idx)
    return ret
def optimize_model_strip_initializer(model,
                                     keep_input_only=True):
    """
    Drop initializers (weights) and graph inputs that no node references.

    Args:
        model (onnx.ModelProto): source model, left untouched.
        keep_input_only: if False, also keep initializers referenced only
            as node outputs.

    Returns:
        onnx.ModelProto: a new model with unreferenced entries removed.
    """
    nodes = model.graph.node
    input_refs, output_refs = build_value_refs(nodes)
    out_names = [val.name for val in model.graph.output]
    ret = type(model)()
    ret.CopyFrom(model)
    ret.graph.ClearField('value_info')  # WORKAROUND: onnx do not drop old value_info
    # rebuild the initializer list with referenced entries only
    ret.graph.ClearField('initializer')
    ret_initializers = ret.graph.initializer
    for initializer in model.graph.initializer:
        name = initializer.name
        referenced = (name in input_refs
                      or (not keep_input_only and name in output_refs))
        if referenced:
            ret_initializers.add().CopyFrom(initializer)
        else:
            logger.debug('initializer %s(%s[%d]) stripped',
                         name,
                         TENSOR_TYPE_TO_NP_TYPE[initializer.data_type],
                         len(initializer.raw_data))
    # rebuild the input list; values that double as graph outputs are kept
    ret.graph.ClearField('input')
    ret_inputs = ret.graph.input
    for item in model.graph.input:
        name = item.name
        if name in input_refs or name in out_names:
            ret_inputs.add().CopyFrom(item)
        else:
            logger.debug('input %s(%s%s) stripped',
                         name,
                         TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
                         tensor_shape(item))
    return ret
def optimize_model_cast(model):
    """
    Remove Cast ops whose input already has the target dtype.

    Args:
        model (onnx.ModelProto): source model, left untouched.

    Returns:
        onnx.ModelProto: a new model with no-op Casts bypassed.
    """
    nodes = model.graph.node
    input_refs, output_refs = build_value_refs(nodes)
    value_info = inferred_model_value_info(model)
    ret = type(model)()
    ret.CopyFrom(model)
    ret.graph.ClearField('value_info')  # WORKAROUND: onnx do not drop old value_info
    ret_nodes = ret.graph.node
    nodes_to_remove = []
    for node_idx, node in enumerate(nodes):
        if not(node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
            continue
        if not(node.op_type == 'Cast'):
            continue
        attrs = node_attrs(node)
        output_dtype = TENSOR_TYPE_TO_NP_TYPE[attrs['to']]
        input_name = node.input[0]
        # BUGFIX: was value_info.get('input_name', None) - a string literal
        # instead of the variable - so the lookup always missed and no Cast
        # was ever optimized away
        info = value_info.get(input_name, None)  # relax for un-inferrable
        if info is None:
            continue
        input_dtype = info.get('dtype', None)
        if input_dtype is None or input_dtype != output_dtype:
            continue
        output_name = node.output[0]
        # prefer re-wiring consumers (forward); fall back to producers
        if output_name in input_refs:
            processed = skip_node_forward(ret_nodes, output_name, input_name, input_refs)
        elif input_name in output_refs:
            processed = skip_node_backward(ret_nodes, input_name, output_name, output_refs)
        else:
            processed = -1
        if processed > 0:
            nodes_to_remove.append(node_idx)
            logger.debug('skip %s: %s -> %s Cast op',
                         node.name, input_dtype, output_dtype)
        elif processed == 0:
            logger.warning('weird, no node processed')
        else:
            logger.debug('keep standalone %s: %s -> %s Cast op',
                         node.name, input_dtype, output_dtype)
    # delete from the copy, highest index first so indices stay valid
    nodes_to_remove.sort(reverse=True)
    for node_idx in nodes_to_remove:
        ret_nodes.pop(node_idx)
    return ret
def optimize_model_slice(model):
    """
    Return a copy of *model* with chains of consecutive Slice ops merged
    into a single Slice (or dropped entirely when the chain is a no-op).
    """
    nodes = model.graph.node
    input_refs, output_refs = build_value_refs(nodes)

    def _build_slice_node_chain(node_idx):
        # collect the maximal chain of default-domain Slice nodes starting
        # at node_idx, following single-consumer links only
        chain = []
        while True:
            node = nodes[node_idx]
            if not(node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
                return chain
            if not node.op_type == 'Slice':
                return chain
            chain.append(node_idx)
            output_name = node.output[0]
            if output_name not in input_refs or len(input_refs[output_name]) != 1:
                return chain
            node_idx = list(input_refs[output_name])[0]

    # axis: (start, end)
    def _merge_slice(slice_chain):
        # fold the (start, end) windows of every Slice in the chain into one
        # window per axis; no-op windows (0, ONNX_INT_MAX) are dropped
        merged_slice = dict()
        for slice_node_idx in slice_chain:
            node = nodes[slice_node_idx]
            attrs = node_attrs(node)
            for axis, start, end in zip(attrs['axes'], attrs['starts'], attrs['ends']):
                if start == 0 and end == ONNX_INT_MAX:
                    continue
                if axis in merged_slice:
                    # shift the new window into the coordinate frame left by
                    # the previous slice on this axis
                    prev_start, prev_end = merged_slice[axis]
                    start += prev_start if start >= 0 else 0 if prev_end == ONNX_INT_MAX else prev_end
                    end += prev_start if end >= 0 else 0 if prev_end == ONNX_INT_MAX else prev_end
                merged_slice[axis] = (start, end)
        return merged_slice

    ret = type(model)()
    ret.CopyFrom(model)
    ret.graph.ClearField('value_info')  # WORKAROUND: onnx do not drop old value_info
    ret_nodes = ret.graph.node
    nodes_to_remove = []
    for node_idx in range(len(nodes)):
        slice_chain = _build_slice_node_chain(node_idx)
        if len(slice_chain) == 0:
            continue
        merged_slice = _merge_slice(slice_chain)
        if len(merged_slice) > 0 and len(slice_chain) == 1:  # no need to merge
            continue
        # rebuild the merged window in the axes/starts/ends attribute layout
        attrs = dict(axes=[], starts=[], ends=[])
        for axis, (start, end) in merged_slice.items():
            attrs['axes'].append(axis)
            attrs['starts'].append(start)
            attrs['ends'].append(end)
        first_node = nodes[slice_chain[0]]
        last_node = nodes[slice_chain[-1]]
        input_name = first_node.input[0]
        output_name = last_node.output[0]
        processed = -1
        if output_name in input_refs:  # 0, [1...]
            # keep the first Slice (re-attributed), bypass the rest forward
            new_input_name = first_node.output[0] if len(merged_slice) > 0 else input_name
            processed = skip_node_forward(ret_nodes, output_name, new_input_name, input_refs)
            if processed > 0:
                if len(merged_slice) > 0:
                    remain_idx = slice_chain[0]
                    remove_chain = slice_chain[1:]
                    slice_node = ret_nodes[remain_idx]
                    for attr in slice_node.attribute:
                        attr.CopyFrom(make_attribute(attr.name, attrs[attr.name]))
                    logger.debug('merged slice chain %s -> %s%s -> %s',
                                 input_name, remain_idx, remove_chain, output_name)
                else:
                    remove_chain = slice_chain
        if processed < 0 and input_name in output_refs:
            # keep the last Slice (re-attributed), bypass the rest backward
            new_output_name = last_node.input[0] if len(merged_slice) > 0 else output_name
            processed = skip_node_backward(ret_nodes, input_name, new_output_name, output_refs)
            if processed > 0:
                if len(merged_slice) > 0:
                    remain_idx = slice_chain[-1]
                    remove_chain = slice_chain[:-1]
                    slice_node = ret_nodes[remain_idx]
                    for attr in slice_node.attribute:
                        attr.CopyFrom(make_attribute(attr.name, attrs[attr.name]))
                    logger.debug('merged slice chain %s -> %s%s -> %s',
                                 input_name, remove_chain, remain_idx, output_name)
                else:
                    remove_chain = slice_chain
        if processed > 0:
            nodes_to_remove.extend(remove_chain)
            if len(merged_slice) == 0:
                logger.debug('skip slice chain %s -> %s -> %s',
                             input_name, slice_chain, output_name)
        elif processed < 0:  # NEVERFIX: not merge standalone slice chain
            logger.debug('keep standalone slice chain %s -> %s -> %s',
                         input_name, slice_chain, output_name)
    # delete from the copy, highest index first so indices stay valid
    nodes_to_remove.sort(reverse=True)
    for node_idx in nodes_to_remove:
        ret_nodes.pop(node_idx)
    return ret
if __name__ == '__main__':
    # smoke-test driver: load an example model, run every optimization pass,
    # then log the op/weight/input/output structure
    logging.basicConfig(
        format='[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
        level=logging.DEBUG,
    )
    from onnx.checker import check_model
    from onnx.utils import polish_model
    from onnx.version_converter import convert_version

    model = onnx.load('../examples/t1.onnx')
    print_pb_structure(model, loop_iterative=False)
    check_model(model)
    model = convert_version(model, 9)  # normalize to opset 9
    model = optimize_model_skip_op_for_inference(model)
    model = optimize_model_strip_initializer(model)
    model = optimize_model_cast(model)
    model = optimize_model_slice(model)
    model = polish_model(model)
    onnx.save(model, '/tmp/optimized.onnx')

    graph = model.graph
    value_info = inferred_model_value_info(model)
    name = graph.name
    inputs = [value.name for value in graph.input]
    outputs = [value.name for value in graph.output]
    weights = []
    logger.info('ops:')
    for name, domain, op_type, _, _, attrs in graph_ops(graph, topo='forward'):
        logger.info('%s %s::%s: %s', name, domain, op_type, attrs)
    logger.info('weights:')
    for name, array in graph_weights(graph):
        weights.append(name)
        logger.info('%s: %s', name, array.shape)
    logger.info('inputs:')
    external_inputs = []
    for name in inputs:
        if name not in weights:
            external_inputs.append(name)
            logger.info('%s: %s', name, value_info[name]['shape'])
    logger.info('outputs:')
    external_outputs = []
    for name in outputs:
        if name not in weights:
            external_outputs.append(name)
            logger.info('%s: %s', name, value_info[name]['shape'])
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ONNX to Paddle symbolic translation
Created on Mon Feb 25 09:33:43 2019
@author: Macrobull
"""
from __future__ import division
import logging as _logging
import numpy as np
from collections import OrderedDict as _dict
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
_logger = _logging.getLogger(__name__)
# largest ONNX int64, used as the open-ended 'ends' sentinel in Slice ops
ONNX_INT_MAX = 2 ** 63 - 1
DEFAULT_ONNX_OP_DOMAIN = ''
DEFAULT_PADDLE_OP_NAMESCOPE = '/'
# field schema (name -> default) for DEFAULT_OP_MAPPING entries; the order
# matters, since shorter entries are padded positionally from these defaults
DEFAULT_OP_MAPPING_FIELD_VALUES = _dict([
    ('PADDLE_OP', ''),
    ('PADDLE_INPUT_ARGS', None),
    ('PADDLE_OUTPUT_ARGS', None),
    ('ATTR_MAPPING', dict()),  # dict(onnx_attr_from=paddle_attr_to)
    ('DEFAULTS', dict()),  # dict(paddle_attr=default)
    ('INPUT_PERM', None),  # sampler: [idx_onnx_arg...]
    ('OUTPUT_PERM', None),  # sampler: [idx_onnx_arg...]
    ('FILL_NAME_FIELD', True),
])
# ONNX op -> Paddle op mapping table. Each entry follows the layout of
# DEFAULT_OP_MAPPING_FIELD_VALUES: [paddle_op, input_args, output_args,
# attr_mapping, defaults, input_perm, output_perm, fill_name_field];
# missing trailing fields are padded from the schema defaults.
DEFAULT_OP_MAPPING = {
        ## nil ops ##
        'RandomUniform':
            ['uniform_random', [], ['Out'], dict(high='max', low='min'),
             dict(), None, None, False],
        'RandomNormal':
            ['gaussian_random', [], ['Out'], dict(scale='std'),
             dict(), None, None, False],
        ## unary ops ##
        'Abs': ['abs', ['X'], ['Out']],
        'ArgMax': ['argmax', ['X'], ['Out'], dict(keepdims='')],
        'ArgMin': ['argmin', ['X'], ['Out'], dict(keepdims='')],
        'Ceil': ['ceil', ['X'], ['Out']],
        'Clip': ['clip', ['X'], ['Out']],  # attrs bypassed
        'Cos': ['cos', ['X'], ['Out']],
        'Elu': ['elu', ['X'], ['Out']],
        'Exp': ['exp', ['X'], ['Out']],
        'Flatten': ['flatten', ['X'], ['Out']],  # attrs bypassed, FIXME: emit flatten2
        'Floor': ['floor', ['X'], ['Out']],
        'Gather': ['gather', ['X'], ['Out'], dict(axis='')],
        'LeakyRelu': ['leaky_relu', ['X'], ['Out']],
        'Log': ['log', ['X'], ['Out']],
        'LRN': ['lrn', ['X'], ['Out', 'MidOut'], dict(size='n', bias='k')],  #
        'Reciprocal': ['reciprocal', ['X'], ['Out']],
        'Relu': ['relu', ['X'], ['Out']],
        'Selu': ['selu', ['X'], ['Out'], dict(gamma='scale')],
        'Shape': ['shape', ['X'], ['Out']],  # FIXME: out is int64 - int32
        # BUGFIX: was dict(bias='', labmd='') - the ONNX Shrink attribute is
        # spelled 'lambd', so the typo left it unmapped/undropped
        'Shrink': ['softshrink', ['X'], ['Out'], dict(bias='', lambd='')],
        'Sigmoid': ['sigmoid', ['X'], ['Out']],
        'Sin': ['sin', ['X'], ['Out']],
        'Squeeze': ['squeeze', ['X'], ['Out']],  # attrs bypassed, FIXME: emit squeeze2
        'Softplus': ['softplus', ['X'], ['Out']],
        'Softmax': ['softmax', ['X'], ['Out'], dict(axis='')],
        'Softsign': ['softsign', ['X'], ['Out']],
        'Sqrt': ['sqrt', ['X'], ['Out']],
        'Tanh': ['tanh', ['X'], ['Out']],
        'ThresholdedRelu': ['thresholded_relu', ['X'], ['Out'], dict(alpha='threshold')],
        'Transpose': ['transpose', ['X'], ['Out']],  # FIXME: emit transpose2
        'Unsqueeze': ['unsqueeze', ['X'], ['Out']],  # attrs bypassed, FIXME: emit unsqueeze2
        ## binary ops ##
        # FIXME: axis=-1 in Paddle is broken, refer it in specialization
        'Add': ['elementwise_add', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],
        # 'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')],
        'And': ['logical_and', ['X', 'Y'], ['Out']],
        'Div': ['elementwise_div', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],
        'Equal': ['equal', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False],
        # BUGFIX: Greater is emulated with less_than, which requires the
        # operands swapped (Greater(X, Y) == Less(Y, X)); was emitted with
        # the original order, i.e. it computed Less
        'Greater': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), [1, 0], None, False],
        'Less': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False],
        'MatMul': ['matmul', ['X', 'Y'], ['Out']],  # defaults excluded for transpose_x - transpose_X
        'Max': ['elementwise_max', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],
        'Min': ['elementwise_min', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],
        'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],
        # NOTE(review): ONNX Not is unary, yet 'Y' is declared here - confirm
        'Not': ['logical_not', ['X', 'Y'], ['Out']],
        'OneHot':  # assuming values=[0, 1], axis=-1 and drop them
            ['one_hot', ['Input', 'Depth'], ['Out'], dict(axis=''), dict(),
             [0, 1], None, False],
        'Or': ['logical_or', ['X', 'Y'], ['Out']],
        'Pow': ['elementwise_pow', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],  # TODO: pow for scalar exponent
        'Sub': ['elementwise_sub', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],
        'Xor': ['logical_xor', ['X', 'Y'], ['Out']],
        # reduce ops
        'ReduceMax': ['reduce_max', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')],
        'ReduceMean': ['reduce_mean', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')],
        'ReduceMin': ['reduce_min', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')],
        'ReduceProd': ['reduce_prod', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')],
        'ReduceSum': ['reduce_sum', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')],
        # other ops
        'Scatter': ['scatter', ['X', 'Index', 'Updates'], ['Out']],
        'TopK': ['topk', ['X', 'K'], ['Out', 'Indices']],
}
# Per-op sanity constraints checked before emitting: each entry is a list of
# (predicate(inputs, outputs, attrs), error message) pairs; the predicate
# must hold or conversion is aborted with the message.
DEFAULT_IOA_CONSTRAINT = {
    'ArgMax':
        [(lambda i, o, a: a.get('keepdims', 1) == 1, 'only keepdims = 0 is supported'),
         # NOTE(review): the predicate accepts keepdims == 1 while the message
         # claims only 0 is supported - one of the two looks wrong; confirm
         # against Paddle argmax semantics
         ],
    'ArgMin':
        [(lambda i, o, a: a.get('keepdims', 1) == 1, 'only keepdims = 0 is supported'),
         # NOTE(review): same predicate/message mismatch as ArgMax
         ],
    'Gather':
        [(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported'),
         ],
    'Shrink':
        [(lambda i, o, a: a.get('bias', 0) == a.get('lambd', 0.5), 'only SoftShrink with bias = lambd is supported'),
         ],
    # 'Softmax':
    #     [(lambda i, o, a: a.get('axis', 1) == -2, 'Paddle Softmax works on dim -2 only'),
    #      ],
    'OneHot':
        [(lambda i, o, a: a.get('axis', -1) == -1, 'only axis = -1 is supported'),
         ],
    'Scatter':
        [(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported'),
         ],
    'TopK':
        [(lambda i, o, a: a.get('axis', -1) == -1, 'only axis = -1 is supported'),
         ],
}
def _make_var_name(name):
"""
make a valid variable name in Python code
"""
if name == '':
return '_'
if name[0].isdigit():
return 'var_' + name
for s in ' *?\/-:':
name = name.replace(s, '_')
if name.startswith('_'):
name = 'var' + name
return name
#def _value_info_or_none(value_infos, val_name):
# return value_infos.get(val_name, None)
def _dtype(value_infos, val_name):
return np.dtype(value_infos[val_name]['dtype'])
def _dtype_or_none(value_infos, val_name):
if val_name not in value_infos:
return None
value_info = value_infos[val_name]
if 'dtype' not in value_info:
return None
return np.dtype(value_info['dtype'])
def _shape(value_infos, val_name):
return list(value_infos[val_name]['shape'])
def _shape_or_none(value_infos, val_name):
if val_name not in value_infos:
return None
value_info = value_infos[val_name]
if 'shape' not in value_info:
return None
return list(value_info['shape'])
#def _maybe_const_value(value_infos, val_name):
# var_name = _make_var_name(val_name)
# if val_name not in value_infos:
# return var_name
# value_info = value_infos[val_name]
# assert value_info.get('remove_batch', False) == False, 'const value should not have batch dim'
# return value_info.get('const_value', var_name)
def _default(prog, op_type, inputs, outputs, attrs,
             *args,
             name='',
             **kwargs):
    """
    Generic ONNX -> Paddle op emitter driven by DEFAULT_OP_MAPPING.

    Args:
        prog: program writer providing Code/VarDesc/OpDesc.
        op_type: ONNX op type, a key of DEFAULT_OP_MAPPING.
        inputs / outputs: ONNX value names of the node.
        attrs: ONNX attribute dict for the node.
        name: op name; used for the layers name= argument and to name
            dummy output variables.
    """
    info = DEFAULT_OP_MAPPING[op_type]
    # pad missing trailing fields from the schema defaults on a LOCAL copy;
    # BUGFIX: the original info.extend(...) mutated the shared
    # DEFAULT_OP_MAPPING entry in place
    field_defaults = list(DEFAULT_OP_MAPPING_FIELD_VALUES.values())
    info = list(info) + field_defaults[len(info):]
    (paddle_op,
     paddle_input_args, paddle_output_args,
     attr_mapping, default_attrs,
     input_perm, output_perm,
     fill_name_field,
     ) = info
    # BUGFIX: constraints are keyed by the ONNX op type (e.g. 'ArgMax'), not
    # by the mapped Paddle op name, so the lookup must use op_type - with
    # paddle_op the constraints never fired
    if op_type in DEFAULT_IOA_CONSTRAINT:
        for predicate, message in DEFAULT_IOA_CONSTRAINT[op_type]:
            assert predicate(inputs, outputs, attrs), message
    # bypass if key absent, drop if mapped key is '' or '_'
    mapped_attrs = {attr_mapping.get(key, key): value for key, value in attrs.items()}
    mapped_attrs.pop('', None)
    mapped_attrs.pop('_', None)
    paddle_attrs = default_attrs.copy()
    paddle_attrs.update(mapped_attrs)  # as new attrs
    # permute I/O value names when the mapping requires a different order
    val_inps = inputs if input_perm is None else [inputs[i] for i in input_perm]
    val_outs = outputs if output_perm is None else [outputs[i] for i in output_perm]
    var_inps = [_make_var_name(val) for val in val_inps]
    var_outs = [_make_var_name(val) for val in val_outs]
    arg_name = ', name={}'.format(repr(name)) if fill_name_field and name else ''
    # NOTE(review): attr values are interpolated with plain format(), so a
    # string-valued attr would be emitted unquoted - confirm callers only
    # pass numeric/list attrs through this path
    arg_attrs = [', {}={}'.format(key, value) for key, value in paddle_attrs.items()]
    prog.Code('{} = layers.{}({}{}{})'
              .format(', '.join(var_outs),
                      paddle_op,
                      ', '.join(var_inps),
                      ''.join(arg_attrs),
                      arg_name,
                      ))
    for var_out in var_outs:
        prog.VarDesc(var_out)
    # dummy var_out: declare placeholder variables for Paddle output slots
    # the ONNX node does not map (e.g. LRN's MidOut)
    num_vars = len(var_outs)
    num_args = len(paddle_output_args)
    if num_vars < num_args:
        assert fill_name_field, 'name required to naming dummy output variable'
        for idx_out in range(num_vars, num_args):
            var_out = _make_var_name(name + '.' + paddle_output_args[idx_out].lower())
            var_outs.append(var_out)
            prog.VarDesc(var_out)
    prog.OpDesc(paddle_op,
                (var_inps, *paddle_input_args),
                (var_outs, *paddle_output_args),
                paddle_attrs)
def _assign(prog, attrs):
    """
    Emit plain Python alias assignments (dst = src) plus a Paddle 'assign'
    op desc for every pair in attrs['mapping'].
    """

    paddle_op = 'assign'
    for val_dst, val_src in attrs['mapping'].items():  # additional
        var_dst = _make_var_name(val_dst)
        var_src = _make_var_name(val_src)
        prog.Code('{} = {}'.format(var_dst, var_src))
        prog.VarDesc(var_dst)
        prog.OpDesc(paddle_op,
                    ([var_src], 'X'),
                    ([var_dst], 'Out'),
                    dict(),
                    )
def _pad_if_asymmetric(prog, pads, val_name, value_infos):
    """
    Split SSEE-ordered ONNX pads for Paddle ops that only take symmetric
    padding.

    Returns (half_pads, None) when start and end pads agree on every dim;
    otherwise emits an explicit constant Pad op in front of val_name and
    returns (zero pads, padded value name).
    """

    assert len(pads) % 2 == 0
    ndims = len(pads) // 2
    if all(pads[i] == pads[ndims + i] for i in range(ndims)):
        return pads[:ndims], None

    val_padded = val_name + '_padded'
    prog.Op('', 'Pad',
            [val_name],
            [val_padded],  # val
            dict(mode='constant',
                 value=0.,
                 pads=pads,
                 ),
            value_infos=value_infos,
            name=val_padded,
            )
    return [0] * ndims, val_padded
def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
                   name=''):
    """
    Shared lowering for adaptive average/max pooling.

    Emits a layers.adaptive_pool{2,3}d code line, while the op desc targets
    the generic pool{n}d op with adaptive=True. A second output, if present,
    receives the pooling indices (max-pool only).
    """

    # I/O
    val_x, = inputs
    val_y, = outputs[:1]
    var_x = _make_var_name(val_x)
    var_y = _make_var_name(val_y)
    has_indices = len(outputs) > 1
    if has_indices:
        val_indices = outputs[1]
        var_indices = _make_var_name(val_indices)

    # interpretation
    pool_size = attrs['output_size']  # required
    output_shape = _shape_or_none(value_infos, val_y)
    if output_shape is not None:
        # spatial dims of Y must match the requested output size (layout NC...)
        assert pool_size == output_shape[2:], 'pool_size unmatches shape of Y'  # NC...
    poolnd = len(pool_size)
    assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported'
    paddle_op = 'adaptive_pool{}d'.format(poolnd)
    name_attr = ', name={}'.format(repr(name)) if name else ''

    # generation
    prog.Code('{}{} = layers.{}({}'
              ', require_index={}'
              ', pool_size={}'
              ', pool_type={}'
              '{})'
              .format(var_y, ', {}'.format(var_indices) if has_indices else '',
                      paddle_op,
                      var_x,
                      # attrs
                      has_indices,
                      pool_size,
                      repr(pool_type),
                      name_attr,
                      ))
    # the desc-level op is the plain pool{n}d with adaptive=True, not the
    # adaptive_pool{n}d Python wrapper used in the generated code
    paddle_op = 'pool{}d'.format(poolnd)
    prog.VarDesc(var_y)
    if has_indices:
        prog.VarDesc(var_indices)
    prog.OpDesc(paddle_op,
                ([var_x], 'X'),
                ([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'),
                dict(global_pooling=False,
                     adaptive=True,
                     exclusive=True,
                     require_index=has_indices,
                     pooling_type=pool_type,
                     ksize=pool_size,
                     ),
                )
def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
                 name=''):
    """
    Shared lowering for global average/max pooling: layers.pool{2,3}d with
    global_pooling=True. The spatial rank is inferred from whichever of the
    input/output shapes is known (layout NC...).
    """

    # I/O
    val_x, = inputs
    val_y, = outputs
    var_x = _make_var_name(val_x)
    var_y = _make_var_name(val_y)

    # interpretation
    input_shape = _shape_or_none(value_infos, val_x)
    output_shape = _shape_or_none(value_infos, val_y)
    assert input_shape is not None or output_shape is not None, 'poolnd not inferred'  # NC...
    # explicit None checks: a falsy (e.g. empty) shape must not leave poolnd
    # unbound before the assert below
    if input_shape is not None:
        poolnd = len(input_shape) - 2  # NC...
    else:
        poolnd = len(output_shape) - 2  # NC...
    assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported'
    paddle_op = 'pool{}d'.format(poolnd)
    name_attr = ', name={}'.format(repr(name)) if name else ''

    # generation
    prog.Code('{} = layers.{}({}, global_pooling=True'
              ', pool_type={}'
              '{})'
              .format(var_y,
                      paddle_op,
                      var_x,
                      # attrs
                      repr(pool_type),
                      name_attr,
                      ))
    prog.VarDesc(var_y)
    prog.OpDesc(paddle_op,
                ([var_x], 'X'),
                ([var_y], 'Out'),
                dict(global_pooling=True,
                     adaptive=False,
                     pooling_type=pool_type,
                     ),
                )
def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
          name=''):
    """
    Shared lowering for plain average/max pooling to layers.pool{2,3}d.

    A second output, if present, receives the pooling indices (max-pool).
    Asymmetric ONNX pads are lowered to an explicit Pad op in front of the
    pooling op by _pad_if_asymmetric.
    """

    # I/O
    val_x, = inputs
    val_y, = outputs[:1]
    var_y = _make_var_name(val_y)
    has_indices = len(outputs) > 1
    if has_indices:
        val_indices = outputs[1]
        var_indices = _make_var_name(val_indices)

    # interpretation
    assert attrs.get('auto_pad', 'NOTSET') == 'NOTSET', 'only auto_pad = NOTSET supported'  # optional
    pool_size = attrs['kernel_shape']  # required
    poolnd = len(pool_size)
    assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported'
    paddle_op = 'pool{}d'.format(poolnd)
    strides = attrs.get('strides', [1] * poolnd)  # optional
    # len(pool_size * 2) == len(pool_size) * 2: SSEE zero padding by default
    pads = attrs.get('pads', [0] * len(pool_size * 2))  # optional
    paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
    if val_x_padded:
        # pooling now reads the explicitly padded tensor
        val_x = val_x_padded
    ceil_mode = bool(attrs.get('ceil_mode', 0))  # optional
    var_x = _make_var_name(val_x)
    name_attr = ', name={}'.format(repr(name)) if name else ''

    # generation
    prog.Code('{}{} = layers.{}({}, exclusive=True'
              ', pool_size={}'
              ', pool_type={}'
              ', pool_stride={}'
              ', pool_padding={}'
              ', ceil_mode={}'
              '{})'
              .format(var_y, ', {}'.format(var_indices) if has_indices else '',
                      paddle_op,
                      var_x,
                      # attrs
                      pool_size,
                      repr(pool_type),
                      strides,
                      paddings,
                      ceil_mode,
                      name_attr,
                      ))
    prog.VarDesc(var_y)
    if has_indices:
        prog.VarDesc(var_indices)
    prog.OpDesc(paddle_op,
                ([var_x], 'X'),
                ([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'),
                dict(global_pooling=False,
                     adaptive=False,
                     exclusive=True,
                     require_index=has_indices,
                     pooling_type=pool_type,
                     ksize=pool_size,
                     strides=strides,
                     # NOTE(review): conv descs in this file use 'paddings' as the
                     # desc attr key — confirm 'pool_padding' is the correct desc
                     # attr name for pool{n}d
                     pool_padding=paddings,
                     ceil_mode=ceil_mode,
                     ),
                )
def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
    """
    Shared lowering for roi_pool / roi_align / psroi_pool.

    Emits the layer call plus an op desc; max ROI pooling ('roi_pool') also
    declares the implicit Argmax output variable.
    """

    # I/O
    val_x, val_rois = inputs
    val_y, = outputs
    var_x = _make_var_name(val_x)
    var_rois = _make_var_name(val_rois)
    var_y = _make_var_name(val_y)

    # interpretation
    spatial_scale = attrs['spatial_scale']  # required
    pooled_height, pooled_width = attrs['pooled_shape']  # required
    od_attrs = dict(
        spatial_scale=spatial_scale,
        pooled_height=pooled_height,
        pooled_width=pooled_width,
    )
    # optional op-specific attrs are forwarded when present
    feature_attr = ''
    is_max_pool = paddle_op == 'roi_pool'
    if 'sampling_ratio' in attrs:
        sampling_ratio = attrs['sampling_ratio']
        od_attrs['sampling_ratio'] = sampling_ratio
        feature_attr += ', sampling_ratio={}'.format(sampling_ratio)
    if 'output_channels' in attrs:
        output_channels = attrs['output_channels']
        od_attrs['output_channels'] = output_channels
        feature_attr += ', output_channels={}'.format(output_channels)

    # generation
    # fix: the two positional args were joined as '{} {}' (missing comma) and
    # used the raw value name val_x instead of the variable name var_x
    prog.Code('{} = layers.{}({}, {}'
              ', spatial_scale={}'
              ', pooled_height={}'
              ', pooled_width={}'
              '{})'
              .format(var_y,
                      paddle_op,
                      var_x, var_rois,
                      # attrs
                      spatial_scale,
                      pooled_height,
                      pooled_width,
                      feature_attr,
                      ))
    prog.VarDesc(var_y)
    if is_max_pool:
        var_argmax = _make_var_name(name + '.argmax')  # implicit variable
        prog.VarDesc(var_argmax)
    prog.OpDesc(paddle_op,
                ([var_x, var_rois], 'X', 'Rois'),
                ([var_y] + ([var_argmax] if is_max_pool else []), 'Out', 'Argmax'),
                od_attrs,
                )
def _zeros_like(prog, val_ref, val_out, value_infos):
    # zeros_like is emulated as (val_ref - val_ref)
    # NOTE(review): value_infos is passed positionally here, while every other
    # prog.Op call in this file uses the value_infos= keyword — confirm it
    # binds to the intended parameter
    prog.Op('', 'Sub',
            [val_ref, val_ref],
            [val_out],  # val
            dict(axis=0),
            value_infos,
            )
def AdaptiveAveragePool(
        prog, inputs, outputs, attrs, value_infos,
        name='',
        *args, **kwargs):
    """
    aten::adaptive_avg_poolnd

    Thin dispatcher: delegates to the shared _adaptive_pool lowering with
    the pool type fixed to 'avg'.
    """

    pool_type = 'avg'
    return _adaptive_pool(prog, pool_type, inputs, outputs, attrs,
                          value_infos, name=name)
def AdaptiveMaxPool(
        prog, inputs, outputs, attrs, value_infos,
        name='',
        *args, **kwargs):
    """
    aten::adaptive_max_poolnd

    Thin dispatcher: delegates to the shared _adaptive_pool lowering with
    the pool type fixed to 'max'.
    """

    pool_type = 'max'
    return _adaptive_pool(prog, pool_type, inputs, outputs, attrs,
                          value_infos, name=name)
def AveragePool(
        prog, inputs, outputs, attrs, value_infos,
        name='',
        *args, **kwargs):
    """
    onnx::AveragePool-10:

    Thin dispatcher: delegates to the shared _pool lowering with the pool
    type fixed to 'avg'.
    """

    pool_type = 'avg'
    return _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
                 name=name)
def AffineGrid(
        prog, inputs, outputs, attrs,
        *args,
        name='',
        **kwargs):
    """
    aten::affine_grid

    Generates a sampling grid from a batch of affine matrices (Theta) via
    layers.affine_grid.
    """

    # I/O
    val_theta, = inputs
    val_grid, = outputs
    var_theta = _make_var_name(val_theta)
    var_grid = _make_var_name(val_grid)

    # interpretation
    paddle_op = 'affine_grid'
    size = attrs['size']  # required
    name_attr = ', name={}'.format(repr(name)) if name else ''

    # generation
    prog.Code('{} = layers.{}({}'
              ', out_shape={}'
              '{})'
              .format(var_grid,
                      paddle_op,
                      var_theta,
                      # attrs
                      size,
                      name_attr,
                      ))
    prog.VarDesc(var_grid)
    prog.OpDesc(paddle_op,
                ([var_theta], 'Theta'),
                ([var_grid], 'Output'),
                dict(output_shape=size),  # desc attr is 'output_shape', unlike the layer arg 'out_shape'
                )
def BatchNormalization(
        prog, inputs, outputs, attrs, value_infos,
        name='', embed_params=False,
        *args, **kwargs):
    """
    onnx::BatchNormalization-9:

    Inference-mode batch_norm (is_test=True, NCHW). With embed_params the
    scale/bias/mean/variance tensors are renamed to '<name>.w_0'/'.b_0'/
    '.w_1'/'.w_2' and embedded as layer parameters; otherwise they are
    referenced by name via param_attr/bias_attr/moving_*_name.
    """

    # I/O
    val_x, val_scale, val_b, val_mean, val_var = inputs
    val_y, = outputs
    var_x = _make_var_name(val_x)
    var_y = _make_var_name(val_y)

    # interpretation
    paddle_op = 'batch_norm'
    momentum = attrs.get('momentum', .9)  # optional
    epsilon = attrs.get('epsilon', 1e-5)  # optional
    name_attr = ', name={}'.format(repr(name)) if name else ''
    if embed_params:
        assert name != ''
        var_scale = '{}.w_0'.format(name)
        var_b = '{}.b_0'.format(name)
        var_mean = '{}.w_1'.format(name)
        var_var = '{}.w_2'.format(name)
        # record the renames so the weight writer emits them under these names
        value_infos[val_scale].setdefault('embeded_as', []).append(var_scale)
        value_infos[val_b].setdefault('embeded_as', []).append(var_b)
        value_infos[val_mean].setdefault('embeded_as', []).append(var_mean)
        value_infos[val_var].setdefault('embeded_as', []).append(var_var)
        param_attr = ''
    else:
        var_scale = _make_var_name(val_scale)
        var_b = _make_var_name(val_b)
        var_mean = _make_var_name(val_mean)
        var_var = _make_var_name(val_var)
        param_attr = (', param_attr={}, bias_attr={}'
                      ', moving_mean_name={}, moving_variance_name={}'
                      ).format(repr(var_scale), repr(var_b), repr(var_mean), repr(var_var))
    var_saved_mean = '{}.saved_mean'.format(name)  # dropped var
    var_saved_variance = '{}.saved_variance'.format(name)  # dropped var

    # generation
    prog.Code('{} = layers.{}({}, is_test=True, data_layout="NCHW"'
              ', momentum={}'
              ', epsilon={}'
              '{}{})'
              .format(var_y,
                      paddle_op,
                      var_x,
                      # attrs
                      momentum,
                      epsilon,
                      param_attr, name_attr,
                      ))
    prog.VarDesc(var_y)
    prog.VarDesc(var_saved_mean)
    prog.VarDesc(var_saved_variance)
    prog.OpDesc(paddle_op,
                ([var_x, var_scale, var_b, var_mean, var_var],
                 'X', 'Scale', 'Bias', 'Mean', 'Variance'),
                ([var_y, var_mean, var_saved_mean, var_saved_variance, var_var],
                 'Y', 'MeanOut', 'SavedMean', 'SavedVariance', 'VarianceOut'),
                dict(is_test=1,
                     data_layout='NCHW',
                     use_global_stats=False,
                     momentum=momentum,
                     epsilon=epsilon),
                )
def Cast(
        prog, inputs, outputs, attrs, value_infos,
        *args, **kwargs):
    """
    onnx::Cast-9:

    Lowered to layers.cast. The 'to' attribute may arrive either as an ONNX
    TensorProto dtype enum or already as a numpy dtype.
    """

    # I/O
    val_input, = inputs
    val_output, = outputs
    var_input = _make_var_name(val_input)
    var_output = _make_var_name(val_output)

    # interpretation
    dtype = attrs['to']
    if not isinstance(dtype, np.dtype):
        dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]  # required
    output_dtype = _dtype_or_none(value_infos, val_output)
    if output_dtype:
        assert dtype == output_dtype, 'dtype of to unmatches output'
    paddle_op = 'cast'

    # generation
    prog.Code('{} = layers.{}({}'
              ', dtype={}'
              ')'
              .format(var_output,
                      paddle_op,
                      var_input,
                      # attrs
                      repr(dtype.name),
                      ))
    prog.VarDesc(var_output)
    prog.OpDesc(paddle_op,
                ([var_input], 'X'),
                ([var_output], 'Out'),
                dict(in_dtype=prog.Dtype(_dtype(value_infos, val_input)),  # desc requires both dtypes
                     out_dtype=prog.Dtype(dtype),
                     )
                )
def Concat(
        prog, inputs, outputs, attrs,
        *args,
        name='',
        **kwargs):
    """
    onnx::Concat-4:

    Concatenates all inputs along 'axis' via layers.concat.
    """

    # I/O
    val_concat_result, = outputs
    in_vars = [_make_var_name(val) for val in inputs]
    out_var = _make_var_name(val_concat_result)

    # interpretation
    axis = attrs['axis']  # required
    name_suffix = ', name={}'.format(repr(name)) if name else ''

    # generation
    prog.Code('{} = layers.concat([{}], axis={}{})'
              .format(out_var,
                      ', '.join(in_vars),
                      axis,
                      name_suffix,
                      ))
    prog.VarDesc(out_var)
    prog.OpDesc('concat',
                (in_vars, *(['X'] * len(in_vars))),
                ([out_var], 'Out'),
                dict(axis=axis),
                )
def Constant(
        prog, inputs, outputs, attrs, value_infos,
        *args, **kwargs):
    """
    onnx::Constant-9:

    Scalars become layers.fill_constant; non-scalar tensors are folded into
    a plain Python list assignment, and the value is recorded under
    'const_value' in value_infos so later ops can consume it at conversion
    time.
    """

    # I/O
    assert len(inputs) == 0
    val_output, = outputs
    var_output = _make_var_name(val_output)

    # interpretation
    value = attrs['value']  # required
    dtype = np.dtype(value.dtype)
    output_dtype = _dtype_or_none(value_infos, val_output)
    if output_dtype:
        assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'
    # dtype = np.dtype('float32') # force to float32
    shape = attrs.get('shape', None)  # additional, maybe var_name
    if shape is None:
        shape = _shape_or_none(value_infos, val_output)
    if shape is None:
        shape = list(value.shape)
        _logger.warning('shape of %s not inferred, using value as 1-D tensor may lead to fails', val_output)

    # generation
    if value.size == 1:  # scalar
        paddle_op = 'fill_constant'
        # NOTE(review): value[0] assumes the tensor is at least 1-D; a 0-D
        # ndarray would raise IndexError here — confirm upstream always
        # delivers >= 1-D values
        prog.Code('{} = layers.{}(shape={}, dtype={}, value={})'
                  .format(var_output,
                          paddle_op,
                          # attrs
                          shape, repr(dtype.name), value[0],  # shape can be list or var_name
                          ))
        value_infos[val_output]['const_value'] = value[0]
        prog.VarDesc(var_output)
    else:  # list parameter -> const_value
        prog.Code('{} = {}'
                  .format(var_output,
                          value.tolist(),
                          ))
        value_infos[val_output]['const_value'] = value.tolist()
def ConstantOfShape(
        prog, inputs, outputs, attrs, value_infos,
        *args, **kwargs):
    """
    onnx::ConstantOfShape-9:

    Delegates to Constant after resolving the target shape from the shape
    input: either the name of a constant-folded variable or the backing
    weight tensor.
    """

    # I/O
    val_input, = inputs
    is_const_shape = 'const_value' in value_infos[val_input]
    if is_const_shape:
        # a folded shape was emitted as a Python list assignment, so the
        # generated fill_constant can reference it by variable name
        shape = _make_var_name(val_input)
    else:
        # NOTE(review): falls back to the weight backing the shape input —
        # confirm 'get_weight' is always populated for non-const shapes
        shape = value_infos[val_input]['get_weight']()
    dtype = attrs['value'].dtype
    attrs = attrs.copy()
    attrs.update(dict(shape=shape, dtype=dtype))  # pass var_name
    Constant(prog, [], outputs, attrs, value_infos)
def Conv(
        prog, inputs, outputs, attrs, value_infos,
        name='', embed_params=False,
        *args, **kwargs):
    """
    onnx::Conv-1:

    Lowered to layers.conv{2,3}d. With embed_params the weight/bias are
    renamed to '<name>.w_0'/'<name>.b_0' and embedded as layer parameters.
    A bias input is lowered as a separate elementwise Add after the conv,
    with the raw conv result held in a hidden '<name>.conv' variable.
    """

    # I/O
    val_x, val_w = inputs[:2]
    val_y, = outputs
    var_y = _make_var_name(val_y)
    has_bias = len(inputs) == 3
    if has_bias:
        val_b, = inputs[2:]

    # interpretation
    assert attrs.get('auto_pad', 'NOTSET') == 'NOTSET', 'only auto_pad == NOTSET supported'  # optional
    kernel_shape = _shape(value_infos, val_w)[2:]  # OI...
    assert kernel_shape == attrs['kernel_shape'], 'kernel_shape in attr unmatches value_info'  # HW
    convnd = len(kernel_shape)
    assert 2 <= convnd <= 3, 'only conv2d and conv3d supported'
    num_out_channels = _shape(value_infos, val_w)[0]  # OI...
    paddle_op = 'conv{}d'.format(convnd)
    strides = attrs.get('strides', [1] * convnd)  # optional
    pads = attrs.get('pads', [0] * convnd * 2)  # optional
    # asymmetric ONNX pads become an explicit Pad op in front of the conv
    paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
    if val_x_padded:
        val_x = val_x_padded
    dilations = attrs.get('dilations', [1] * convnd)  # optional
    num_groups = attrs.get('group', 1)  # optional
    var_x = _make_var_name(val_x)
    name_attr = ', name={}'.format(repr(name)) if name else ''
    if embed_params:
        assert name != ''
        var_w = '{}.w_0'.format(name)
        value_infos[val_w].setdefault('embeded_as', []).append(var_w)
        if has_bias:
            var_b = '{}.b_0'.format(name)
            value_infos[val_b].setdefault('embeded_as', []).append(var_b)
            param_attr = ''
        else:
            param_attr = ', bias_attr=False'
    else:
        var_w = _make_var_name(val_w)
        var_b = _make_var_name(val_b) if has_bias else False
        param_attr = ', param_attr={}, bias_attr={}'.format(
            repr(var_w), repr(var_b) if var_b else False)

    # generation
    prog.Code('{} = layers.{}({}'
              ', num_filters={}'
              ', filter_size={}'
              ', stride={}'
              ', padding={}'
              ', dilation={}'
              ', groups={}'
              '{}{})'
              .format(var_y,
                      paddle_op,
                      var_x,
                      # attrs
                      num_out_channels,
                      kernel_shape,
                      strides,
                      paddings,
                      dilations,
                      num_groups,
                      param_attr, name_attr,
                      ))
    var_conv = _make_var_name(name + '.conv')  # hidden variable
    # the desc writes the raw conv output; the bias Add (if any) follows
    prog.OpDesc(paddle_op,
                ([var_x, var_w], 'Input', 'Filter'),  # , 'Bias', 'ResidualData'
                ([var_conv if has_bias else var_y], 'Output'),
                dict(strides=strides,
                     paddings=paddings,
                     dilations=dilations,
                     groups=num_groups,
                     ))
    if has_bias:
        prog.VarDesc(var_conv)
        prog.IntermediateOp(
            '', 'Add',
            [var_conv, var_b],
            [var_y],  # var
            dict(axis=1),
            value_infos=value_infos,
            name=(name + '.bias'),
        )
    else:
        prog.VarDesc(var_y)
def ConvTranspose(
        prog, inputs, outputs, attrs, value_infos,
        name='', embed_params=False,
        *args, **kwargs):
    """
    onnx::ConvTranspose-1:

    Lowered to layers.conv{2,3}d_transpose (weight layout IO...). Mirrors
    Conv: embed_params renames the weight/bias to '<name>.w_0'/'<name>.b_0',
    and a bias input becomes a separate elementwise Add after the transpose
    conv via the hidden '<name>.conv' variable.
    """

    # I/O
    val_x, val_w = inputs[:2]
    val_y, = outputs
    var_y = _make_var_name(val_y)
    has_bias = len(inputs) == 3
    if has_bias:
        val_b, = inputs[2:]

    # interpretation
    assert attrs.get('auto_pad', 'NOTSET') == 'NOTSET', 'only auto_pad == NOTSET supported'  # optional
    assert sum(attrs.get('output_padding', [])) == 0, 'only zero output_padding supported'  # optional ?
    kernel_shape = _shape(value_infos, val_w)[2:]  # IO...
    assert kernel_shape == attrs['kernel_shape'], 'kernel_shape in attr unmatches value_info'  # HW
    convnd = len(kernel_shape)
    assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
    num_out_channels = _shape(value_infos, val_w)[1]  # IO...
    paddle_op = 'conv{}d_transpose'.format(convnd)
    strides = attrs.get('strides', [1] * convnd)  # optional
    pads = attrs.get('pads', [0] * convnd * 2)  # optional
    # asymmetric ONNX pads become an explicit Pad op in front of the conv
    paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
    if val_x_padded:
        val_x = val_x_padded
    dilations = attrs.get('dilations', [1] * convnd)  # optional
    num_groups = attrs.get('group', 1)  # optional
    var_x = _make_var_name(val_x)
    name_attr = ', name={}'.format(repr(name)) if name else ''
    if embed_params:
        assert name != ''
        var_w = '{}.w_0'.format(name)
        value_infos[val_w].setdefault('embeded_as', []).append(var_w)
        if has_bias:
            var_b = '{}.b_0'.format(name)
            value_infos[val_b].setdefault('embeded_as', []).append(var_b)
            param_attr = ''
        else:
            param_attr = ', bias_attr=False'
    else:
        var_w = _make_var_name(val_w)
        var_b = _make_var_name(val_b) if has_bias else False
        param_attr = ', param_attr={}, bias_attr={}'.format(
            repr(var_w), repr(var_b) if var_b else False)

    # generation
    prog.Code('{} = layers.{}({}'
              ', num_filters={}'
              # ', output_size={}'
              ', filter_size={}'
              ', padding={}'
              ', stride={}'
              ', dilation={}'
              ', groups={}'
              '{}{})'
              .format(var_y,
                      paddle_op,
                      var_x,
                      # attrs
                      num_out_channels,
                      kernel_shape,
                      paddings,
                      strides,
                      dilations,
                      num_groups,
                      param_attr, name_attr,
                      ))
    var_conv = _make_var_name(name + '.conv')  # hidden variable
    # the desc writes the raw conv output; the bias Add (if any) follows
    prog.OpDesc(paddle_op,
                ([var_x, var_w], 'Input', 'Filter'),  # , 'Bias', 'ResidualData'
                ([var_conv if has_bias else var_y], 'Output'),
                dict(strides=strides,
                     paddings=paddings,
                     dilations=dilations,
                     # output_size=output_size,
                     groups=num_groups,
                     ))
    if has_bias:
        prog.VarDesc(var_conv)
        prog.IntermediateOp(
            '', 'Add',
            [var_conv, var_b],
            [var_y],  # var
            dict(axis=1),
            value_infos=value_infos,
            name=(name + '.bias'),
        )
    else:
        prog.VarDesc(var_y)
# should not appears
#def Dropout(
# prog, inputs, outputs, value_infos,
# *args, **kwargs):
# """
# onnx::Dropout-7:9
# """
#
# val_data, = inputs
# val_output, = outputs[:1]
#
# _assign(prog,
# dict(mapping=dict([(val_output, val_data)])),
# value_infos,
# )
def Gemm(
        prog, inputs, outputs, attrs, value_infos, name,
        *args, **kwargs):
    """
    onnx::Gemm-9:

    Y = alpha * op(A) @ op(B) + beta * C. Paddle fc does not support a
    transposed weight, so this is lowered to MatMul followed by an
    elementwise Add (plus a Constant/Mul scaling chain when beta != 1).
    """

    val_a, val_b, val_c = inputs
    val_y, = outputs
    alpha = attrs.get('alpha', 1.)  # optional
    beta = attrs.get('beta', 1.)  # optional
    trans_a = bool(attrs.get('transA', 0))  # optional
    trans_b = bool(attrs.get('transB', 0))  # optional
    val_mm = name + '_mm'  # explicit variable
    prog.Op('', 'MatMul',
            [val_a, val_b],
            [val_mm],  # val
            dict(transpose_x=trans_a,
                 transpose_y=trans_b,
                 alpha=alpha,
                 ),
            value_infos=value_infos,
            name=val_mm,
            )
    # the matmul op desc expects capitalized transpose_X/transpose_Y,
    # unlike the Python layer arguments
    prog.op_descs[-1].attrs.extend(prog.OpDescAttrs(dict(
        transpose_X=trans_a,
        transpose_Y=trans_b,
    )))
    if beta != 0:
        if beta == 1.:  # exactly
            prog.Op('', 'Add',
                    [val_mm, val_c],
                    [val_y],  # val
                    dict(axis=1),
                    value_infos=value_infos,
                    name=(name + '_beta'),
                    )
        else:
            val_beta = name + '_beta'  # explicit variable
            val_vm = name + '_vm'  # explicit variable
            vm_dtype = _dtype_or_none(value_infos, val_c)
            if vm_dtype is None:
                vm_dtype = np.dtype('float32')
            # keep beta in the dtype of C to avoid mixed-dtype multiply
            beta = np.dtype(vm_dtype).type(beta)
            prog.Op('', 'Constant',
                    [],
                    [val_beta],  # val
                    dict(value=beta),
                    value_infos=value_infos,
                    name=val_beta,
                    )
            prog.Op('', 'Mul',
                    [val_c, val_beta],
                    [val_vm],  # val
                    dict(),
                    value_infos=value_infos,
                    name=(name + '_scale'),
                    )
            prog.Op('', 'Add',
                    [val_mm, val_vm],
                    [val_y],  # val
                    dict(axis=1),
                    value_infos=value_infos,  # fix: was omitted, unlike the sibling Add above
                    name=(name + '_bias'),
                    )
def GlobalAveragePool(
        prog, inputs, outputs, attrs, value_infos,
        name='',
        *args, **kwargs):
    """
    onnx::GlobalAveragePool-1:

    Thin dispatcher: delegates to the shared _global_pool lowering with the
    pool type fixed to 'avg'.
    """

    pool_type = 'avg'
    return _global_pool(prog, pool_type, inputs, outputs, attrs,
                        value_infos, name=name)
def GlobalMaxPool(
        prog, inputs, outputs, attrs, value_infos,
        name='',
        *args, **kwargs):
    """
    onnx::GlobalMaxPool-1:

    Thin dispatcher: delegates to the shared _global_pool lowering with the
    pool type fixed to 'max'.
    """

    pool_type = 'max'
    return _global_pool(prog, pool_type, inputs, outputs, attrs,
                        value_infos, name=name)
#def LRN(
# prog, inputs, outputs, attrs, value_infos, name, # name required
# *args, **kwargs):
# """
# onnx::LRN-1:
# """
#
# # I/O
# val_x, = inputs
# val_y, = outputs
# var_x = _make_var_name(val_x)
# var_y = _make_var_name(val_y)
#
# # interpretation
# paddle_op = 'lrn'
# size = attrs['size'] # required
# alpha = attrs.get('alpha', 0.0001) # optional
# beta = attrs.get('beta', 0.75) # optional
# bias = attrs.get('bias', 1.0) # optional
# name_attr = ', name={}'.format(repr(name)) if name else ''
#
# # generation
# prog.Code('{} = layers.{}({}'
# ', n={}'
# ', k={}'
# ', alpha={}'
# ', beta={}'
# '{})'
# .format(var_y,
# paddle_op,
# var_x,
# # attrs
# size,
# bias,
# alpha,
# beta,
# name_attr,
# ))
# var_mid = name + '.mid' # hidden variable
# prog.VarDesc(var_y)
# prog.VarDesc(var_mid)
# prog.OpDesc(paddle_op,
# ([var_x], 'X'),
# ([var_y, var_mid], 'Out', 'MidOut'),
# dict(n=size,
# k=bias,
# alpha=alpha,
# beta=beta,
# ),
# )
def MaxPool(
        prog, inputs, outputs, attrs, value_infos,
        name='',
        *args, **kwargs):
    """
    onnx::MaxPool-10:

    Thin dispatcher: delegates to the shared _pool lowering with the pool
    type fixed to 'max'.
    """

    pool_type = 'max'
    return _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
                 name=name)
def MaxRoiPool(
        prog, inputs, outputs, attrs, value_infos, name,
        *args, **kwargs):
    """
    onnx::MaxRoiPool-1:

    Thin dispatcher: delegates to the shared _roi_pool lowering as
    'roi_pool'.
    """

    paddle_op = 'roi_pool'
    _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name)
def RoiAlign(
        prog, inputs, outputs, attrs, value_infos, name,
        *args, **kwargs):
    """
    caffe2::RoiAlign

    Thin dispatcher: delegates to the shared _roi_pool lowering as
    'roi_align'.
    """

    paddle_op = 'roi_align'
    _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name)
def Pad(
        prog, inputs, outputs, attrs, value_infos,
        name='',
        *args, **kwargs):
    """
    onnx::Pad-2:

    A 4-element pads on a 4-D (NCHW) tensor — or any non-constant mode —
    is lowered to pad2d; everything else goes through the generic pad op
    (constant mode only). ONNX SSEE pad order is converted to Paddle's
    SESE order.
    """

    # I/O
    val_data, = inputs
    val_output, = outputs
    var_data = _make_var_name(val_data)
    var_output = _make_var_name(val_output)

    # interpretation
    pads = attrs['pads']  # required
    mode = attrs.get('mode', 'constant')  # optional
    value = attrs.get('value', 0.)  # optional
    data_shape = _shape_or_none(value_infos, val_data)
    output_shape = _shape_or_none(value_infos, val_output)
    # heuristics: pad2d only applies to a 2-D spatial pad of an NCHW tensor
    assume_pad2d = False
    if len(pads) == 4:
        assume_pad2d |= mode != 'constant'
        if data_shape:
            assume_pad2d |= data_shape and len(data_shape) == 4  # NCHW
        if output_shape:
            assume_pad2d |= output_shape and len(output_shape) == 4  # NCHW
    od_attrs = dict(pad_value=value)
    if assume_pad2d:
        paddle_op = 'pad2d'
        pad2d_attr = ', mode={}, data_format="NCHW"'.format(repr(mode))
        od_attrs['mode'] = mode
    else:
        assert mode == 'constant', 'mode {} is supported only in pad2d'.format(mode)
        paddle_op = 'pad'
        pad2d_attr = ''
    paddings = np.array(pads).reshape((-1, 2)).transpose().flatten().tolist()  # SSEE -> SESE
    od_attrs['paddings'] = paddings
    name_attr = ', name={}'.format(repr(name)) if name else ''

    # generation
    prog.Code('{} = layers.{}({}'
              ', paddings={}'
              ', pad_value={}'
              '{}{})'
              .format(var_output,
                      paddle_op,
                      var_data,
                      # attrs
                      paddings,
                      value,
                      pad2d_attr, name_attr,
                      ))
    prog.VarDesc(var_output)
    prog.OpDesc(paddle_op,
                ([var_data], 'X'),
                ([var_output], 'Out'),
                od_attrs,
                )
def PRelu(
        prog, inputs, outputs, attrs, value_infos,
        name='', embed_params=False,
        *args, **kwargs):
    """
    onnx::PRelu-9:

    Lowered to layers.prelu with mode='all'. With embed_params the slope
    tensor is renamed to '<name>.w_0' and embedded as a layer parameter;
    otherwise it is referenced via param_attr.
    """

    # I/O
    val_x, val_slope = inputs
    val_y, = outputs
    var_x = _make_var_name(val_x)
    var_y = _make_var_name(val_y)

    # interpretation
    paddle_op = 'prelu'
    name_attr = ', name={}'.format(repr(name)) if name else ''
    if embed_params:
        assert name != ''
        # fix: embedded parameter names derive from the layer name
        # (consistent with Conv/BatchNormalization), not the ONNX value name
        var_slope = '{}.w_0'.format(name)
        value_infos[val_slope].setdefault('embeded_as', []).append(var_slope)
        param_attr = ''
    else:
        var_slope = _make_var_name(val_slope)
        param_attr = ', param_attr={}'.format(repr(var_slope))

    # generation
    prog.Code('{} = layers.{}({}, mode="all"'
              '{}{})'
              .format(var_y,
                      paddle_op,
                      var_x,
                      # attrs
                      param_attr, name_attr,
                      ))
    prog.VarDesc(var_y)
    prog.OpDesc(paddle_op,
                ([var_x], 'X'),
                ([var_y], 'Out'),
                dict(mode='all'),
                )
def PsRoiPool(
        prog, inputs, outputs, attrs, value_infos, name,
        *args, **kwargs):
    """
    caffe2::PsRoiPool

    Thin dispatcher: delegates to the shared _roi_pool lowering as
    'psroi_pool'.
    """

    paddle_op = 'psroi_pool'
    _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name)
def Reshape(
        prog, inputs, outputs, attrs, value_infos, name,
        *args, **kwargs):
    """
    onnx::Reshape-5:

    With a constant-folded shape input the shape variable is inlined into
    the generated code; otherwise the runtime shape tensor is cast to int32
    and passed via actual_shape. The op desc always targets reshape2, which
    emits an extra XShape output.
    """

    # I/O
    val_data, val_shape = inputs
    val_reshaped, = outputs
    var_data = _make_var_name(val_data)
    var_reshaped = _make_var_name(val_reshaped)

    # interpretation
    paddle_op = 'reshape'
    is_const_shape = 'const_value' in value_infos[val_shape]
    var_shape = _make_var_name(val_shape)  # for code
    if is_const_shape:
        shape = value_infos[val_shape]['const_value']  # for desc
    else:
        shape = value_infos[val_shape]['get_weight']()  # for desc
    name_attr = ', name={}'.format(repr(name)) if name else ''

    # generation
    if is_const_shape:
        prog.Code('{} = layers.{}({}'
                  ', shape={}'
                  '{})'
                  .format(var_reshaped,
                          paddle_op,
                          var_data,
                          # attrs
                          var_shape,
                          name_attr,
                          ))
    else:
        var_shape_int32 = var_shape + '_int32'
        # NOTE(review): this prog.Op call receives var names where other
        # calls in the file pass value names — confirm this is intended
        prog.Op('', 'Cast',
                [var_shape],
                [var_shape_int32],  # var
                dict(to=np.dtype('int32')),
                value_infos=value_infos,
                name=(name + '_cast'),
                )
        prog.Code('{} = layers.{}({}'
                  ', shape={}'
                  ', actual_shape={}'
                  '{})'
                  .format(var_reshaped,
                          paddle_op,
                          var_data,
                          # attrs
                          shape,
                          var_shape_int32,
                          name_attr,
                          ))
    # the desc-level op is reshape2 (extra XShape output), not reshape
    paddle_op = 'reshape2'
    var_xshape = _make_var_name(name + '.xshape')
    prog.VarDesc(var_reshaped)
    prog.VarDesc(var_xshape)
    if is_const_shape:
        prog.OpDesc(paddle_op,
                    ([var_data], 'X'),
                    ([var_reshaped, var_xshape], 'Out', 'XShape'),
                    dict(shape=shape),
                    )
    else:
        prog.OpDesc(paddle_op,
                    ([var_data, var_shape_int32], 'X', 'Shape'),
                    ([var_reshaped, var_xshape], 'Out', 'XShape'),
                    dict(shape=shape),
                    )
def Slice(
        prog, inputs, outputs, attrs, value_infos,
        *args, **kwargs):
    """
    onnx::Slice-1:9

    Normalizes INT_MAX-encoded negative axes/starts/ends against the
    statically known input shape, then lowers to layers.slice.
    NOTE(review): the axes/starts/ends lists from attrs are modified in
    place — the caller's attrs dict sees the rewritten values.
    """

    # I/O
    val_data, = inputs
    val_output, = outputs
    var_data = _make_var_name(val_data)
    var_output = _make_var_name(val_output)

    # interpretation
    paddle_op = 'slice'
    axes = attrs['axes']  # required
    starts = attrs['starts']  # required
    ends = attrs['ends']  # required
    shape = _shape_or_none(value_infos, val_data)
    if shape:
        ndims = len(shape)
        # values above INT_MAX/2 encode negative indices
        for idx, value in enumerate(axes):
            if value > ONNX_INT_MAX // 2:
                axes[idx] = ndims + value - ONNX_INT_MAX - 1
        # HINT: Paddle 1.3 docs suggest passing INT_MAX to slice to the end of
        # a dim of unknown size — does not seem to work?
        for idx, value in enumerate(starts):
            if value > ONNX_INT_MAX // 2:
                value = value - ONNX_INT_MAX - 1
                starts[idx] = shape[axes[idx]] + value
        for idx, value in enumerate(ends):
            if value > ONNX_INT_MAX // 2:
                value = value - ONNX_INT_MAX - 1
                ends[idx] = shape[axes[idx]] + value

    # generation
    prog.Code('{} = layers.{}({}'
              ', axes={}'
              ', starts={}'
              ', ends={}'
              ')'
              .format(var_output,
                      paddle_op,
                      var_data,
                      # attrs
                      axes,
                      starts,
                      ends,
                      ))
    prog.VarDesc(var_output)
    prog.OpDesc(paddle_op,
                ([var_data], 'X'),
                ([var_output], 'Out'),
                dict(axes=axes,
                     starts=starts,
                     ends=ends,
                     ),
                )
def Sum(
        prog, inputs, outputs,
        *args, **kwargs):
    """
    onnx::Sum-8:

    Variadic elementwise sum, lowered to layers.sums.
    """

    # I/O
    val_sum, = outputs
    in_vars = [_make_var_name(val) for val in inputs]
    out_var = _make_var_name(val_sum)

    # generation
    prog.Code('{} = layers.sums([{}])'
              .format(out_var,
                      ', '.join(in_vars),
                      ))
    prog.VarDesc(out_var)
    prog.OpDesc('sums',
                (in_vars, *(['X'] * len(in_vars))),
                ([out_var], 'Out'),
                dict(),
                )
def Tile(
        prog, inputs, outputs, attrs, value_infos,
        name='',
        *args, **kwargs):
    """
    onnx::Tile-6:

    Lowered to layers.expand. The repeat counts come either from a
    constant-folded value of the repeats input or from its backing weight.
    """

    # I/O
    val_input, val_repeats = inputs
    val_output, = outputs
    var_input = _make_var_name(val_input)
    var_output = _make_var_name(val_output)

    # interpretation
    paddle_op = 'expand'
    is_const_repeats = 'const_value' in value_infos[val_repeats]
    if is_const_repeats:
        code_repeats = _make_var_name(val_repeats)  # for code
        repeats = value_infos[val_repeats]['const_value']  # for desc
    else:
        # fix: read the weight backing the repeats input, not the data input
        repeats = value_infos[val_repeats]['get_weight']()  # for desc
        code_repeats = repeats  # for code
    name_attr = ', name={}'.format(repr(name)) if name else ''

    # generation
    prog.Code('{} = layers.{}({}'
              ', expand_times={}'
              '{})'
              .format(var_output,
                      paddle_op,
                      var_input,
                      # attrs
                      code_repeats,
                      name_attr,
                      ))
    prog.VarDesc(var_output)
    prog.OpDesc(paddle_op,
                ([var_input], 'X'),
                ([var_output], 'Out'),
                dict(expand_times=repeats),
                )
#def Shape(
# prog, inputs, outputs, attrs, value_infos,
# *args, **kwargs):
# """
# onnx::ConstantOfShape-1:
# """
#
# # I/O
# val_data, = inputs
# val_shape, = outputs
# var_data = _make_var_name(val_data)
# var_shape = _make_var_name(val_shape)
#
# # interpretation
# paddle_op = 'shape'
## value_infos[val_shape]['remove_batch'] = False
#
# # generation
# prog.Code('{} = layers.{}({})'
# .format(var_shape,
# paddle_op,
# var_data,
# # attrs
# ))
# prog.VarDesc(var_shape) # , _value_info_or_none(value_infos, val_shape))
# prog.OpDesc(paddle_op,
# ([var_data], 'X'),
# ([var_shape], 'Out'),
# dict(),
# )
def Split(
        prog, inputs, outputs, attrs,
        *args,
        name='',
        **kwargs):
    """
    onnx::Split-2:

    Lowered to layers.split with explicit section sizes along 'axis'.
    """

    # I/O
    val_input, = inputs
    var_outs = [_make_var_name(val) for val in outputs]
    var_input = _make_var_name(val_input)

    # interpretation
    paddle_op = 'split'
    split = attrs['split']  # required
    axis = attrs.get('axis', 0)  # optional
    name_attr = ', name={}'.format(repr(name)) if name else ''

    # generation
    prog.Code('{} = layers.{}({}, {}'
              ', dim={}'
              '{})'
              .format(', '.join(var_outs),
                      paddle_op,
                      var_input,
                      split,
                      # attrs
                      axis,
                      name_attr,
                      ))
    for var_out in var_outs:
        prog.VarDesc(var_out)
    # fix: the input must be a list of var names (was a bare string), and the
    # flat var_outs list pairs with one 'Out' key per output (was wrapped in
    # an extra list, mismatching the repeated keys)
    prog.OpDesc(paddle_op,
                ([var_input], 'X'),
                (var_outs, *(['Out'] * len(var_outs))),
                dict(axis=axis,
                     sections=split,
                     ),
                )
if __name__ == '__main__':
_logging.basicConfig(
format='[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
level=_logging.DEBUG,
)
logger = _logging.getLogger('symbolic_test')
from writer import Program
prog = Program()
AdaptiveAveragePool(prog, ['X'], ['Y'],
dict(output_size=[3, 3]),
dict(Y=dict(shape=(2, 3, 3, 3), dtype=np.float32)),
name='AdaptiveAveragePool2d',
)
logger.info('AdaptiveAveragePool2d program:\n%s', prog)
prog = Program()
AdaptiveAveragePool(prog, ['X'], ['Y'],
dict(output_size=[3, 3, 3]),
dict(Y=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32)),
name='AdaptiveAveragePool3d',
)
logger.info('AdaptiveAveragePool3d program:\n%s', prog)
prog = Program()
AffineGrid(prog, ['Theta'], ['Grid'],
dict(size=[2, 2, 8, 8]),
dict(Grid=dict(shape=(2, 8, 8, 2), dtype=np.float32)),
)
logger.info('AffineGrid program:\n%s', prog)
prog = Program()
BatchNormalization(prog, ['X', 'scale', 'B', 'mean', 'var'], ['Y'],
dict(epsilon=1e-5,
momentum=.9,
),
dict(scale=dict(shape=(3, ), dtype=np.float32),
B=dict(shape=(3, ), dtype=np.float32),
mean=dict(shape=(3, ), dtype=np.float32),
var=dict(shape=(3, ), dtype=np.float32),
Y=dict(shape=(2, 3), dtype=np.float32),
),
name='BatchNormalization',
embed_params=True,
)
logger.info('BatchNormalization program:\n%s', prog)
prog = Program()
Cast(prog, ['input'], ['output'],
dict(to=2), # TensorProto.UINT8
dict(input=dict(shape=(2, 3), dtype=np.float32),
output=dict(shape=(2, 3), dtype=np.uint8)),
)
logger.info('Cast program:\n%s', prog)
prog = Program()
_default(prog, 'Clip', ['input'], ['output'],
dict(min=-1., max=1.),
dict(output=dict(shape=(2, 3), dtype=np.float32)),
)
logger.info('Clip program:\n%s', prog)
prog = Program()
Conv(prog, ['X', 'W'], ['Y'],
dict(auto_pad='NOTSET',
dilations=[1, 1],
group=1,
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
dict(W=dict(shape=(2, 3, 3, 3), dtype=np.float32),
Y=dict(shape=(2, 2, 4, 6), dtype=np.float32),
),
name='ConvNoBias2d',
embed_params=True,
)
logger.info('ConvNoBias2d program:\n%s', prog)
prog = Program()
Conv(prog, ['X', 'W', 'B'], ['Y'],
dict(auto_pad='NOTSET',
dilations=[1, 1],
group=1,
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
dict(W=dict(shape=(2, 3, 3, 3), dtype=np.float32),
B=dict(shape=(2), dtype=np.float32),
Y=dict(shape=(2, 2, 4, 6), dtype=np.float32),
),
name='Conv2d',
embed_params=True,
)
logger.info('Conv2d program:\n%s', prog)
prog = Program()
ConvTranspose(prog, ['X', 'W', 'B'], ['Y'],
dict(auto_pad='NOTSET',
dilations=[1, 1],
group=1,
kernel_shape=[3, 3],
# output_padding=[1, 1, 1, 1],
# output_shape=[6, 8],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
dict(W=dict(shape=(2, 3, 3, 3), dtype=np.float32),
B=dict(shape=(2), dtype=np.float32),
Y=dict(shape=(2, 2, 6, 8), dtype=np.float32),
),
name='ConvTransposed2d',
embed_params=True,
)
logger.info('ConvTransposed2d program:\n%s', prog)
prog = Program()
Conv(prog, ['X', 'W'], ['Y'],
dict(auto_pad='NOTSET',
dilations=[1, 1, 1],
group=1,
kernel_shape=[3, 3, 3],
pads=[1, 1, 1, 1, 1, 1],
strides=[1, 1, 1],
),
dict(W=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32),
Y=dict(shape=(2, 2, 4, 6, 8), dtype=np.float32),
),
name='ConvNoBias3d',
embed_params=True,
)
logger.info('ConvNoBias3d program:\n%s', prog)
prog = Program()
Conv(prog, ['X', 'W', 'B'], ['Y'],
dict(auto_pad='NOTSET',
dilations=[1, 1, 1],
group=1,
kernel_shape=[3, 3, 3],
pads=[1, 1, 1, 1, 1, 1],
strides=[1, 1, 1],
),
dict(W=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32),
B=dict(shape=(2), dtype=np.float32),
Y=dict(shape=(2, 2, 4, 6, 8), dtype=np.float32),
),
name='Conv3d',
embed_params=True,
)
logger.info('Conv3d program:\n%s', prog)
prog = Program()
ConvTranspose(prog, ['X', 'W', 'B'], ['Y'],
dict(auto_pad='NOTSET',
dilations=[1, 1, 1],
group=1,
kernel_shape=[3, 3, 3],
# output_padding=[1, 1, 1, 1],
# output_shape=[6, 8],
pads=[1, 1, 1, 1, 1, 1],
strides=[1, 1, 1],
),
dict(W=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32),
B=dict(shape=(2), dtype=np.float32),
Y=dict(shape=(2, 2, 6, 8, 9), dtype=np.float32),
),
name='ConvTransposed3d',
embed_params=True,
)
logger.info('ConvTransposed3d program:\n%s', prog)
prog = Program()
_default(prog, 'Equal', ['A', 'B'], ['C'],
dict(),
dict(C=dict(shape=(2, 3), dtype=np.bool)),
)
logger.info('Equal program:\n%s', prog)
prog = Program()
Gemm(prog, ['A', 'B', 'C'], ['Y'],
dict(alpha=1.,
beta=1.,
transA=0,
transB=1,
),
dict(B=dict(shape=(8, 3), dtype=np.float32),
Y=dict(shape=(2, 8), dtype=np.float32),
),
name='Gemm',
)
logger.info('Gemm program:\n%s', prog)
prog = Program()
_default(prog, 'Less', ['A', 'B'], ['C'],
dict(),
dict(C=dict(shape=(2, 3), dtype=np.bool)),
)
logger.info('Less program:\n%s', prog)
prog = Program()
_default(prog, 'MatMul', ['A', 'B'], ['Y'],
dict(),
dict(Y=dict(shape=(2, 8), dtype=np.float32)),
name='MatMul'
)
logger.info('MatMul program:\n%s', prog)
prog = Program()
_default(prog, 'OneHot', ['indices', 'depth', 'values'], ['output'],
dict(axis=-1),
dict(output=dict(shape=(2, 8), dtype=np.float32)),
)
logger.info('OneHot program:\n%s', prog)
prog = Program()
Pad(prog, ['data'], ['output'],
dict(mode='constant',
pads=[0, 1],
value=0.,
),
dict(data=dict(shape=(2, 7), dtype=np.float32),
output=dict(shape=(2, 8), dtype=np.float32),
),
name='Pad',
)
logger.info('Pad program:\n%s', prog)
prog = Program()
Pad(prog, ['data'], ['output'],
dict(mode='reflect',
pads=[0, 1, 2, 3],
value=0.,
),
dict(data=dict(shape=(2, 3, 3, 3), dtype=np.float32),
output=dict(shape=(2, 3, 5, 7), dtype=np.float32),
),
name='Pad2d',
)
logger.info('Pad2d program:\n%s', prog)
prog = Program()
PRelu(prog, ['X', 'slope'], ['Y'],
dict(),
dict(Y=dict(shape=(2, 3), dtype=np.float32)),
name='PRelu',
)
logger.info('PRelu program:\n%s', prog)
prog = Program()
Tile(prog, ['input', 'repeats'], ['output'],
dict(),
dict(repeats=dict(const_value=[1, 2]),
output=dict(shape=(2, 2, 4), dtype=np.float32)
),
name='Tile'
)
logger.info('Tile program:\n%s', prog)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 22 11:22:46 2019
@author: Macrobull
"""
import numpy as np
import torch
from collections import OrderedDict as Dict
def _ensure_list(obj):
if isinstance(obj, (list, set, tuple)):
return list(obj)
return [obj]
def _ensure_tuple(obj):
if isinstance(obj, (list, set, tuple)):
return tuple(obj)
return (obj, )
def _flatten_list(obj,
out=None):
assert isinstance(obj, list)
if out is None:
out = type(obj)()
for item in obj:
if isinstance(item, list):
_flatten_list(item, out)
else:
out.append(item)
return out
def export_data(state_dict,
prefix=''):
"""
export binary data with meta text for raw C++ inference engines
"""
def _str(obj):
if isinstance(obj, (tuple, list)):
return str(obj)[1:-1].replace(' ', '')
return str(obj)
prefix_ = prefix + ('_' if prefix else '')
fp = open('{}.txt'.format(prefix if prefix else 'meta'), 'w')
for key, value in state_dict.items():
data = None
if torch and torch.is_tensor(value):
data = value.data.cpu().numpy()
elif np and isinstance(value, np.ndarray):
data = value
if data is not None:
data.tofile('{}{}.bin'.format(prefix_, key))
fp.write('{}.dtype={}\n'.format(key, _str(data.dtype.name)))
fp.write('{}.shape={}\n'.format(key, _str(data.shape)))
else:
fp.write('{}={}\n'.format(key, _str(value)))
fp.close()
def export_onnx_with_validation(model, inputs, export_basepath,
input_names=None, output_names=None,
use_npz=True,
*args, **kwargs):
"""
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
"""
is_list_or_tuple = lambda x: isinstance(x, (list, tuple))
def _tensors_to_arrays(tensors):
if torch.is_tensor(tensors):
return tensors.data.cpu().numpy()
arrays = []
for tensor in tensors:
arrays.append(_tensors_to_arrays(tensor))
return arrays
def _zip_dict(keys, values):
ret = Dict()
for idx, (key, value) in enumerate(zip(keys, values)):
is_key_list = is_list_or_tuple(key)
is_value_list = is_list_or_tuple(value)
assert is_key_list == is_value_list, 'keys and values mismatch'
if is_value_list:
ret[str(idx)] = _zip_dict(key, value)
else:
ret[key] = value
return ret
torch_inputs = _ensure_tuple(inputs) # WORKAROUND: for torch.onnx
outputs = torch.onnx.export(model, torch_inputs, export_basepath + '.onnx',
input_names=_flatten_list(input_names),
output_names=_flatten_list(output_names),
*args, **kwargs)
if outputs is None: # WORKAROUND: for torch.onnx
outputs = model(*inputs)
torch_outputs = _ensure_tuple(outputs)
inputs = _zip_dict(input_names, _tensors_to_arrays(torch_inputs))
outputs = _zip_dict(output_names, _tensors_to_arrays(torch_outputs))
if use_npz:
np.savez(export_basepath + '.npz', inputs=inputs, outputs=outputs)
else:
np.save(export_basepath + '.npy',
np.array(Dict(inputs=inputs, outputs=outputs)))
return torch_outputs
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 22 12:17:19 2019
@author: Macrobull
"""
# import importlib, logging, os, sys
import importlib
import logging
import os
import sys
def _flatten_dict(obj,
out=None):
assert isinstance(obj, dict)
if out is None:
out = type(obj)()
for key, value in obj.items():
if isinstance(value, dict):
_flatten_dict(value, out)
else:
assert key not in out
out[key] = value
return out
def _ensure_list(obj):
for cls in [list, set, tuple]:
if isinstance(obj, cls):
return list(obj)
return [obj]
def validate(paddle_model_filename, golden_data_filename,
model_func_name='inference',
precision=1e-4,
save_inference_model=False):
"""
inferece the converted Paddle model, validate with given golden data
"""
import numpy as np
import paddle.fluid as fluid
logger = logging.getLogger('validate')
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
# load model
paddle_model_dir, basename = os.path.split(paddle_model_filename)
if basename == '__model__': # is desc model
logger.debug('using desc file %s', basename)
prog, in_names, var_outs = fluid.io.load_inference_model(paddle_model_dir, exe)
out_names = var_outs # HINT: pass var if fetch ops already created
logger.info('model load passed')
elif basename.endswith('.py'): # is python code
logger.debug('using python code file %s', basename)
module_name, _ = os.path.splitext(basename)
sys_path = sys.path.copy()
sys.path.append(paddle_model_dir)
try:
module = importlib.import_module(module_name)
func = getattr(module, model_func_name)
except AttributeError:
module_name = module_name + '.' + module_name
module = importlib.import_module(module_name)
func = getattr(module, model_func_name)
sys.path = sys_path
logger.debug('from %s imported %s: %s', module_name, model_func_name, func)
var_outs = func()
var_outs = _ensure_list(var_outs)
out_names = [var.name for var in var_outs] # HINT: pass string to create fetch ops
logger.info('import passed')
prog = fluid.default_main_program()
fluid.io.load_persistables(executor=exe, dirname=paddle_model_dir, main_program=prog)
logger.info('weight load passed')
else:
raise ValueError('unsupported Paddle model')
# load data
logger.info('using golden data %s', golden_data_filename)
if golden_data_filename.endswith('.npz'):
test_data = np.load(golden_data_filename)
input_data = test_data['inputs'].tolist()
output_data = test_data['outputs'].tolist()
else:
test_data = np.load(golden_data_filename).tolist()
input_data = input_data['inputs']
output_data = output_data['outputs']
input_data = _flatten_dict(input_data)
output_data = _flatten_dict(output_data)
logger.info('found %d I/O golden data, starting test ...', len(test_data))
# DEBUG: reload test for python code
if basename.endswith('.py') and save_inference_model:
fluid.io.save_inference_model(paddle_model_dir, input_data.keys(), var_outs, exe,
main_program=prog, export_for_deployment=True)
logger.info('model re-save passed')
fluid.io.load_inference_model(paddle_model_dir, exe)
logger.info('model re-load passed')
# execute
outputs = exe.run(prog, feed=input_data, fetch_list=out_names)
logger.info('execution passed')
# validate
passed = True
for (name, truth), output in zip(output_data.items(), outputs):
logger.info('testing output {} ...'.format(name))
try:
np.testing.assert_almost_equal(output, truth, decimal=precision)
except AssertionError as e:
passed = False
logger.error('failed: %s\n', e)
if passed:
logger.info('accuracy passed')
else:
logger.info('accuracy not passed')
# globals().update(locals())
return passed
if __name__ == '__main__':
logging.basicConfig(
format='[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
level=logging.DEBUG,
)
logger = logging.getLogger('validation_test')
model_rc_list = [
'../examples/t{}/model.py',
'../examples/t{}/__model__',
'../examples/t{}.embeded/model.py',
'../examples/t{}.embeded/__model__',
]
import numpy as np
idx_model = np.random.randint(1, 7)
model = np.random.choice(model_rc_list).format(idx_model)
precision = 10 ** (np.random.rand() * -4 - 2)
debug = False
model = '/tmp/export/model.py'
# model = '../examples/t1/__model__'
# model = '../examples/t1.embeded/model.py'
# model = '../examples/t1.embeded/__model__'
debug = True
logger.info('args: %s %.6f', model, precision)
data_dir, dir_name = os.path.split(os.path.split(model)[0])
data_pathname = os.path.splitext(dir_name)[0]
# proto debug test
from framework_pb2 import ProgramDesc
pd = ProgramDesc()
pd.ParseFromString(open(os.path.join(data_dir, dir_name, '__model__'), 'rb').read())
# validate
# validate(model, os.path.join(data_dir, data_pathname + '.npz'),
# precision=precision, save_inference_model=debug)
validate(model, '../examples/bvlc_alexnet/test_data_0.npz',
precision=precision, save_inference_model=debug)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Feb 24 20:44:43 2019
@author: Macrobull
"""
from __future__ import division
# import logging, os
import logging
import os
import numpy as np
logger = logging.getLogger(__name__)
try:
from . import symbolic
except ImportError:
import symbolic
# imports
make_var_name = symbolic._make_var_name
try:
import paddle.fluid.proto.framework_pb2 as framework_pb2
except ImportError:
try:
from . import framework_pb2
except ImportError:
import framework_pb2
logger.warning('importing paddle.fluid.proto.framework_pb2d failed,'
'using fallback framework_pb2')
__all__ = [
'Program',
'Writer',
]
def _irepr(obj,
to='_'):
"""inline repr"""
s = repr(obj)
for c in '\r\n':
s = s.replace(c, to)
if len(s) > 78:
s = s[:75] + '...'
return s
def _flatten_list(obj,
out=None):
if out is None:
out = type(obj)()
for item in obj:
if isinstance(item, list):
_flatten_list(item, out)
else:
out.append(item)
return out
def make_attr_name(name):
"""
make a valid code name for ParamAttr
"""
if name == '':
raise ValueError('name should not be empty')
for s in ' *?\/-:': #
name = name.replace(s, '_')
if not name.startswith('_'):
name = '_' + name
return 'attr' + name
class Program(object):
"""
fluid Python code and ProgramDesc wrapper
"""
DTYPE_TO_FRAMEWORK_DTYPE = {
'bool': framework_pb2.VarType.BOOL,
'int8': framework_pb2.VarType.INT8,
'uint8': framework_pb2.VarType.UINT8,
'int16': framework_pb2.VarType.INT16,
'int32': framework_pb2.VarType.INT32,
'int64': framework_pb2.VarType.INT64,
'float16': framework_pb2.VarType.FP16,
'float32': framework_pb2.VarType.FP32,
'float64': framework_pb2.VarType.FP64
}
@staticmethod
def Dtype(dtype):
"""
convert dtype to fulid framework dtype
"""
dtype = np.dtype(dtype).name
return Program.DTYPE_TO_FRAMEWORK_DTYPE[dtype]
@staticmethod
def OpDescVars(vals, *keys):
"""
make (OpDesc.Var)s
"""
od_vars = []
for idx, key in enumerate(keys):
od_var = framework_pb2.OpDesc.Var()
od_var.parameter = key
if idx < len(vals):
od_var.arguments.append(vals[idx]) #
od_vars.append(od_var)
return od_vars
@staticmethod
def OpDescAttrs(attrs):
"""
make (OpDesc.Attr)s
"""
od_attrs = []
for key, value in attrs.items():
od_attr = framework_pb2.OpDesc.Attr()
od_attr.name = key
if isinstance(value, bool): # bool.mro() = [bool, int, object]
od_attr.type = framework_pb2.BOOLEAN
od_attr.b = value
elif isinstance(value, int): # only cast to int32
od_attr.type = framework_pb2.INT
od_attr.i = value
elif isinstance(value, float):
od_attr.type = framework_pb2.FLOAT
od_attr.f = value
elif isinstance(value, str):
od_attr.type = framework_pb2.STRING
od_attr.s = value
elif isinstance(value, list) and len(value) > 0:
if isinstance(value, bool): # bool.mro() = [bool, int, object]
od_attr.type = framework_pb2.BOOLEANS
od_attr.bools.extend(value)
elif isinstance(value[0], int): # only cast to int32 list
od_attr.type = framework_pb2.INTS
od_attr.ints.extend(value)
elif isinstance(value[0], float):
od_attr.type = framework_pb2.FLOATS
od_attr.floats.extend(value)
elif isinstance(value[0], str):
od_attr.type = framework_pb2.STRINGS
od_attr.strings.extend(value)
od_attrs.append(od_attr)
return od_attrs
def __init__(self):
self.code_mutable = True
self.codes = []
self.op_descs = []
self.var_descs = []
def __str__(self):
return ('Program(code mutable: {}) with:\n'
'codes: {}\n'
'op_descs: {}\n'
'var_descs: {}\n').format(
self.code_mutable,
self.codes,
self.op_descs,
self.var_descs)
def __repr__(self):
return self.__str__()
def Code(self, code):
"""
add Python code
"""
if self.code_mutable:
self.codes.append(code)
def OpDesc(self, name,
input_val_keys=None, output_val_keys=None, attrs=None):
"""
add OpDesc
"""
desc = framework_pb2.OpDesc()
desc.type = name
if input_val_keys is not None:
desc.inputs.extend(self.OpDescVars(*input_val_keys))
if output_val_keys is not None:
desc.outputs.extend(self.OpDescVars(*output_val_keys))
if attrs is not None:
desc.attrs.extend(self.OpDescAttrs(attrs))
self.op_descs.append(desc)
return desc
def VarDesc(self, name,
persistable=False, value_info=None, remove_batch=None):
"""
add VarDesc
"""
var_desc = framework_pb2.VarDesc()
var_desc.name = name
var_desc.persistable = persistable
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
if value_info and 'dtype' in value_info:
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(value_info['dtype']) # required
if 'shape' in value_info:
tensor_desc.dims.extend(value_info['shape'])
if len(value_info['shape']) > 0: # skip scalars
if remove_batch is None:
remove_batch = value_info.get('remove_batch', not persistable)
if remove_batch:
tensor_desc.dims[0] = -1
self.var_descs.append(var_desc)
def Op(self, domain, op_type, *args, **kwargs):
"""
convert an ONNX op and add it to program
"""
if domain != '': # TODO: symbolic file routing by domain
raise ValueError('only default domain supported')
if op_type in symbolic.DEFAULT_OP_MAPPING:
symbolic._default(self, op_type, *args, **kwargs)
elif hasattr(symbolic, op_type):
fn = getattr(symbolic, op_type)
fn(self, *args, **kwargs)
else:
raise ValueError('conversion for {}::{} not supported'
.format(domain, op_type))
def IntermediateOp(self, domain, op_type, *args, **kwargs):
"""
convert an intermediate ONNX op declaring just desc only
"""
code_mutable = self.code_mutable
self.code_mutable = False
try:
self.Op(domain, op_type, *args, **kwargs)
except BaseException as e:
self.code_mutable = code_mutable
raise e
else:
self.code_mutable = code_mutable
class Writer(object):
"""
fluid code and desc writter
"""
CODE_INDENT = ' ' * 4
@staticmethod
def header_code(func_name):
"""
Python header codes
"""
codes = list()
codes.append('"""')
codes.append('This code is generated by onnx2paddle.')
codes.append('"""')
codes.append('')
codes.append('from __future__ import division')
codes.append('')
codes.append('from paddle.fluid import ParamAttr')
codes.append('from paddle.fluid import initializer, layers')
codes.append('')
codes.append('')
codes.append('def {}():'.format(func_name))
return codes
@staticmethod
def emit_op(prog, name, domain, op_type, inputs, outputs, attrs, value_infos, *args, **kwargs):
"""
emit an ONNX op into program
"""
prog.Code('# {}, {}::{}: {} -> {}, {}'
.format(name, domain, op_type, inputs, outputs, _irepr(attrs, to=', ')))
prog.Op(domain, op_type, inputs, outputs, attrs,
value_infos=value_infos, name=name,
*args, **kwargs)
@staticmethod
def emit_param(prog, name, value_info):
"""
emit an ONNX weight into program
"""
if value_info.get('embeded_as', []):
var_names = value_info['embeded_as']
prog.Code('# parameter {} embeded as {}'.format(name, var_names))
for var_name in var_names:
prog.VarDesc(var_name, persistable=True, value_info=value_info)
else:
var_name = make_var_name(name)
attr_name = make_attr_name(name)
prog.Code('# parameter: {}'.format(name))
prog.Code('{} = ParamAttr(name={})' # , trainable=True
.format(attr_name, repr(var_name)))
prog.Code('{} = layers.create_parameter(shape={}, dtype={}, name={}, attr={}'
', default_initializer=initializer.Constant(0))' #, is_bias={}
.format(var_name,
value_info['shape'], repr(value_info['dtype'].name),
repr(name), attr_name)) #, value_info.get('is_bias', False)))
prog.VarDesc(var_name, persistable=True, value_info=value_info)
@staticmethod
def emit_inputs(prog, names, value_infos,
remove_batch=None):
"""
emit ONNX inputs into program
"""
for idx, name in enumerate(names):
var_name = make_var_name(name)
value_info = value_infos[name]
shape = value_info['shape']
if remove_batch is None:
remove_batch = value_info.get('remove_batch', True) # HINT: True by default ?
if remove_batch:
shape = shape[1:]
prog.Code('# input: {}'.format(name))
prog.Code(('{} = layers.data(name={}, shape={}, dtype={}, '
'append_batch_size={})' # , stop_gradient=True
).format(var_name, repr(name),
shape,
repr(value_info['dtype'].name),
remove_batch,
))
prog.OpDesc('feed',
(['feed'], 'X'),
([var_name], 'Out'),
dict(col=idx),
)
prog.VarDesc(var_name, value_info=value_info, remove_batch=remove_batch)
@staticmethod
def emit_outputs(prog, names): #, value_infos
"""
emit ONNX outputs into program
"""
code = 'return '
for idx, name in enumerate(names):
var_name = make_var_name(name)
code += var_name + ', '
prog.OpDesc('fetch',
([var_name], 'X'),
(['fetch'], 'Out'),
dict(col=idx),
)
# var is emitted over ops
prog.Code(code)
@staticmethod
def add_codes(codes, others, indent):
"""
flatten codes in program
"""
for code in _flatten_list(others):
codes.append(Writer.CODE_INDENT * indent + code)
return codes
@staticmethod
def write_weight(weight, filename):
"""
write single weight in fluid desc
"""
if not isinstance(weight, np.ndarray):
raise TypeError('weight is not an ndarray')
tensor_desc = framework_pb2.VarType.TensorDesc()
tensor_desc.data_type = Program.Dtype(weight.dtype)
tensor_desc.dims.extend(weight.shape)
fp = open(filename, 'wb')
np.array([0], dtype=np.int32).tofile(fp) # version
np.array([0], dtype=np.int64).tofile(fp) # LOD level
np.array([0], dtype=np.int32).tofile(fp) # tensor version
np.array([tensor_desc.ByteSize()], dtype=np.int32).tofile(fp)
fp.write(tensor_desc.SerializeToString())
weight.tofile(fp)
fp.close()
@staticmethod
def write_weights(weights, save_dir):
"""
write multiple weights in each fluid desc
"""
for name, weight in weights.items():
if not isinstance(weights, dict):
raise TypeError('dict type weights required')
var_name = make_var_name(name)
filename = os.path.join(save_dir, var_name)
Writer.write_weight(weight, filename)
logger.debug('saved weight %s to %s', name, filename)
@staticmethod
def write_code_file(filename, header_code, *body_codes):
"""
write Python code to file
"""
codes = []
Writer.add_codes(codes, header_code, 0)
for body_code in body_codes:
Writer.add_codes(codes, body_code, 1)
fp = open(filename, 'w')
for code in _flatten_list(codes):
fp.write(code)
fp.write('\n')
fp.close()
logger.debug('saved codes to %s', filename)
@staticmethod
def write_desc_file(filename, op_descs, var_descs):
"""
write desc program to file
"""
prog_desc = framework_pb2.ProgramDesc()
block_desc = prog_desc.blocks.add()
block_desc.idx = 0
block_desc.parent_idx = -1
block_desc.ops.extend(op_descs)
block_desc.vars.extend(var_descs)
# add feed-fetch on vars
feed_var_desc = block_desc.vars.add()
feed_var_desc.name = 'feed'
feed_var_desc.type.type = framework_pb2.VarType.FEED_MINIBATCH
feed_var_desc.persistable = True
fetch_var_desc = block_desc.vars.add()
fetch_var_desc.name = 'fetch'
fetch_var_desc.type.type = framework_pb2.VarType.FETCH_LIST
fetch_var_desc.persistable = True
fp = open(filename, 'wb')
fp.write(prog_desc.SerializeToString())
fp.close()
logger.debug('saved descs to %s', filename)
\ No newline at end of file
-e .
onnx>=1.4.0
paddlepaddle
\ No newline at end of file
# setup.cfg相关文档可参考如下链接
# https://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
[metadata]
# 项目名称,发布、安装时以此作为包名
name = onnx2paddle
# 作者姓名和邮箱地址
author = Macrobull
# author_email = .Github@github.com
# 项目版本号,1.0以上版本才视为正式版
version = 0.1.0
# 项目概要描述信息,一句话让用户明白项目概要,不支持中文
description = Inference model conversion from ONNX/PyTorch to Paddle
# 项目的详细描述内容和格式,包括readme和changelog等,通常使用md或rst等格式
long_description = file: README.md, CHANGELOG.md
long_description_content_type = text/markdown
# 开源授权协议,非对外开源的项目无需关注
license = MIT
# 项目类别,非对外开源的项目无需关注
# 从PyPI官方给出的列表中选择符合的内容进行填写
# https://pypi.org/pypi?%3Aaction=list_classifiers
classifier =
Private :: Do Not Upload
Programming Language :: Python
Programming Language :: Python :: 3
Programming Language :: Python :: 3.5
# 关键字,用于检索,方便用户搜索到你的项目
keywords =
onnx paddle
[options]
# 包名称,find:表示自动寻找,可在options.packages.find中进行详细配置
packages = find:
# 依赖管理,包含项目运行时所需要的所有依赖库
# 每行一个依赖库,只写直接依赖,通常无需考虑间接依赖
# 在这里指定的版本限制应当尽量抽象,通常只要指定最低版本和大版本号即可
install_requires =
onnx >= 1.4
# 测试依赖,包含项目测试时所需要的额外的依赖库,格式与install_requires一致
# 可以使用内置的unittest,也可以使用更简单的pytest或nose等单测框架
# python3自带mock库,而python2没有,如果需要使用则必须写入测试依赖中
#tests_require =
# pytest
# mock
# 单测代码目录
#test_suite = onnx2paddle.tests
# 自动添加被版本控制的数据文件
include_package_data = True
# 项目是纯py项目,可以直接执行zip源码包
zip_safe = False
# 可以通过以下配置将指定的函数变成命令行工具,允许用户直接执行
#[options.entry_points]
#console_scripts =
# onnx2paddle = onnx2paddle.cmdline:main
# 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下
# 仅支持文件,不支持目录,但可以使用通配
#[options.package_data]
#onnx2paddle =
# conf/*
# data/*
[sdist]
dist_dir = output/dist
[bdist_wheel]
# 如果项目可以一份代码同时运行在python2和python3上,则设置universal为1
#universal=1
dist_dir = output/dist
[easy_install]
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
Setup script.
Authors: Macrobull
Date: 2019/02/22 10:25:46
"""
import setuptools
setuptools.setup()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册