Unverified commit 2228423e authored by Jason, committed by GitHub

Merge pull request #16 from MacroBull/master

Rename onnx2paddle to onnx2fluid
@@ -57,3 +57,4 @@ coverage.xml
/examples/*.aria2
/examples/*.onnx
/examples/*.np?
**/.*

Onnx2Fluid
===
Inference model conversion from ONNX/PyTorch to Paddle fluid

Quick start
---
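A minimal quick-start sketch (an illustration, not from the original README; the flags are those defined in onnx2fluid/__main__.py, and the model/test-data paths are placeholders):

    python -m onnx2fluid -e -o /tmp/export/ model.onnx -t test_data_set_0.npz

Here -e embeds weights into the generated fluid layers, -o sets the output directory, and -t supplies golden I/O data for validating the converted model.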
@@ -6,7 +6,7 @@ Created on Fri Mar 22 11:19:45 2019
@author: Macrobull

Not all ops in this file are supported by both PyTorch and ONNX.
This only demonstrates the conversion/validation workflow from PyTorch to ONNX to Paddle fluid.
"""

@@ -16,12 +16,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from onnx2fluid.torch_export_helper import export_onnx_with_validation
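# export_onnx_with_validation(model, inputs, prefix, input_names, output_names,
#                             verbose=..., training=...) is assumed to wrap
# torch.onnx.export and to also dump the sample inputs/outputs as golden data
# for validating the converted fluid program later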
idx = 0

######### example: RNN ########
#
#class Model(nn.Module):
@@ -44,7 +42,6 @@ idx = 0
# ['x'], ['y'],
# verbose=True, training=False)

######### example: random ########
#
#class Model(nn.Module):
@@ -66,9 +63,9 @@ idx = 0
# ['x'], ['y'],
# verbose=True, training=False)

######## example: fc ########
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
@@ -85,13 +82,12 @@ xb = torch.rand((2, 3))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(
    model, (xb, ), 't' + str(idx), ['x'], ['y'], verbose=True, training=False)

######## example: compare ########
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
@@ -110,12 +106,15 @@ xb1 = torch.rand((2, 3))
ya, yb, yc = model(xb0, xb1)
idx += 1
print('index: ', idx)
export_onnx_with_validation(
    model, (xb0, xb1),
    't' + str(idx), ['x0', 'x1'], ['ya', 'yb', 'yc'],
    verbose=True,
    training=False)

######## example: affine_grid ########
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
@@ -130,13 +129,15 @@ theta = torch.rand((2, 2, 3))
grid = model(theta)
idx += 1
print('index: ', idx)
export_onnx_with_validation(
    model, (theta, ),
    't' + str(idx), ['theta'], ['grid'],
    verbose=True,
    training=False)

######## example: conv2d_transpose ########
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
@@ -155,12 +156,12 @@ xb = torch.rand((2, 3, 4, 5))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(
    model, (xb, ), 't' + str(idx), ['x'], ['y'], verbose=True, training=False)

######## example: conv2d ########
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
@@ -181,10 +182,8 @@ xb = torch.rand((2, 3, 4, 5))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(
    model, (xb, ), 't' + str(idx), ['x'], ['y'], verbose=True, training=False)

######### example: conv1d ########
#
@@ -210,6 +209,7 @@ export_onnx_with_validation(model, (xb, ), 't' + str(idx),
######## example: empty ########
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
@@ -223,6 +223,5 @@ xb = torch.rand((2, 3))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(
    model, (xb, ), 't' + str(idx), ['y'], ['y'], verbose=True, training=False)
#! /usr/bin/env sh

get_url="aria2c -c -s8 -x8"
base_url="https://s3.amazonaws.com/download.onnx/models/opset_9/"
flags="-e -o /tmp/export/"
bvlc_alexnet()
{
@@ -18,13 +18,13 @@ bvlc_alexnet()
    do
        echo "converting $npz ..."
        python convert_data_npz_0.py "$npz" "data_0" "prob_1"
        python -m onnx2fluid $flags "$fn_model" -t $npz
    done
    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir ..."
        python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
        python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}

@@ -42,7 +42,7 @@ bvlc_googlenet()
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
        python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}

@@ -60,7 +60,7 @@ bvlc_reference_caffenet()
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
        python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}

@@ -77,8 +77,8 @@ bvlc_reference_rcnn_ilsvrc13()
    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "fc_rcnn_1"
        python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}

@@ -96,14 +96,14 @@ inception_v1()
    do
        echo "converting $npz ..."
        python convert_data_npz_0.py "$npz" "data_0" "prob_1"
        python -m onnx2fluid $flags "$fn_model" -t $npz
    done
    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
        python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}

@@ -121,14 +121,14 @@ inception_v2()
    do
        echo "converting $npz ..."
        python convert_data_npz_0.py "$npz" "data_0" "prob_1"
        python -m onnx2fluid $flags "$fn_model" -t $npz
    done
    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
        python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}

@@ -146,14 +146,14 @@ resnet50()
    do
        echo "converting $npz ..."
        python convert_data_npz_0.py "$npz" "gpu_0/data_0" "gpu_0/softmaxout_1"
        python -m onnx2fluid $flags "$fn_model" -t $npz
    done
    for pb_dir in $bn_tar/*/
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "gpu_0/data_0" "gpu_0/softmaxout_1"
        python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}

@@ -171,7 +171,7 @@ shufflenet()
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "gpu_0/data_0" "gpu_0/softmaxout_1"
        python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}

@@ -189,7 +189,7 @@ squeezenet()
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "softmaxout_1"
        python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}

@@ -207,7 +207,7 @@ tiny_yolov2()
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "image" "grid"
        python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz -x
    done
}

@@ -225,7 +225,7 @@ vgg19()
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
        python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}

@@ -243,20 +243,20 @@ zfnet512()
    do
        echo "converting $pb_dir"
        python convert_data_pb_0.py "$pb_dir" "gpu_0/data_0" "gpu_0/softmax_1"
        python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
    done
}
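
# run all bundled model conversions below; tiny_yolov2 relies on non-standard
# ONNX ops, hence the extra -x (no-pedantic) switch passed in its loop above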
bvlc_alexnet
bvlc_googlenet
bvlc_reference_caffenet
bvlc_reference_rcnn_ilsvrc13
inception_v1
inception_v2
resnet50
shufflenet
squeezenet
tiny_yolov2 # not supported
vgg19
zfnet512
@@ -5,7 +5,7 @@
#
################################################################################
"""
This file allows the package to be executed directly via python -m onnx2fluid.

Authors: Macrobull
Date: 2019/02/22 10:25:46
@@ -21,43 +21,67 @@ import argparse
import logging
import sys

parser = argparse.ArgumentParser(
    description='onnx2fluid',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    'model',
    nargs=1,
    help='path to model.onnx',
)
parser.add_argument(
    '--debug',
    '-d',
    action='store_true',
    help='enable debug logging and checking',
)
parser.add_argument(
    '--output_dir',
    '-o',
    type=str,
    default='',
    help='output directory',
)
parser.add_argument(
    '--test_data',
    '-t',
    type=str,
    default='',
    help='I/O golden data for validation, e.g. test.npy, test.npz',
)
parser.add_argument(
    '--embed_params',
    '-e',
    action='store_true',
    help='try to embed parameters for trainable Paddle fluid layers',
)
parser.add_argument(
    '--pedantic',
    action='store_true',
    default=True,
    help='accept and convert only standard ONNX opset',
)
parser.add_argument(
    '--no-pedantic',
    '-x',
    action='store_false',
    dest='pedantic',
    help='process non-standard ONNX ops, this may lead to failures',
)
parser.add_argument(
    '--precision',
    '-p',
    type=int,
    default=4,
    help='assertion decimal for validation',
)
args = parser.parse_args()

logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(format=logging_format, level=logging_level)
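# typical invocation (a sketch; model and test-data paths are placeholders):
#   python -m onnx2fluid -e -o /tmp/export/ model.onnx -t test_data_set_0.npz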
try:
    from . import cmdline
except ImportError:
@@ -66,5 +90,4 @@ except ImportError:
# imports
main = cmdline.main

sys.exit(main(**args.__dict__))
@@ -21,7 +21,6 @@ import logging
import shutil
import zipfile

__all__ = [
    'main',
]

@@ -42,7 +41,7 @@ def main(**kwargs):
    # imports
    convert = conversion.convert

    logger = logging.getLogger('onnx2fluid')
    debug = kwargs.get('debug', False)

    # prepare arguments
@@ -58,7 +57,9 @@ def main(**kwargs):
    onnx_opset_pedantic = kwargs.get('pedantic', True)

    # convert
    convert(
        filename,
        save_dir,
        model_basename=model_basename,
        model_func_name=model_func_name,
        embed_params=embed_params,
@@ -80,16 +81,18 @@ def main(**kwargs):
    # in fact fluid cannot fully clear the context
    # continuous validation may be inaccurate
    precision = 10**-kwargs.get('precision', 4)
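    # `precision` is the assertion decimal for validation: e.g. the default
    # of 4 asks outputs to agree within roughly 1e-4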
    logger.info('starting validation on desc ...')
    passed &= validate(
        shutil.os.path.join(save_dir, '__model__'),
        golden_data_filename,
        precision=precision,
    )
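    # the same golden data validates both artifacts: the program desc
    # (__model__) above and the generated Python code below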
    logger.info('starting validation on code ...')
    passed &= validate(
        shutil.os.path.join(save_dir, model_basename),
        golden_data_filename,
        model_func_name=model_func_name,
        precision=precision,
@@ -112,20 +115,22 @@ def main(**kwargs):

if __name__ == '__main__':
    logging.basicConfig(
        format=
        '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
        level=logging.DEBUG,
    )

    # main(model=['../examples/t5.onnx'],
    #      output_dir='/tmp/export/',
    #      embed_params=False,
    #      pedantic=False,
    #      test_data='../examples/t5.npz',
    #      debug=True)

    main(
        model=['../examples/inception_v2/model.onnx'],
        output_dir='/tmp/export/',
        embed_params=True,
        pedantic=False,
        test_data='../examples/inception_v2/test_data_set_2.npz',
        debug=True)
@@ -12,19 +12,21 @@ from __future__ import division

import logging
import shutil

__all__ = [
    'convert',
]


def convert(onnx_model_filename,
            save_dir,
            model_basename='model.py',
            model_func_name='inference',
            embed_params=False,
            onnx_opset_version=9,
            onnx_opset_pedantic=True,
            debug=False):
    """
    convert an ONNX model to Paddle fluid Python code and desc pb
    """
    import onnx
@@ -62,7 +64,8 @@ def convert(onnx_model_filename, save_dir,
        if onnx_opset_pedantic:  # WORKAROUND: RuntimeError: No Adapter For OP
            onnx_model = convert_version(onnx_model, onnx_opset_version)
        else:  # TODO: add new argument for this option
            logger.warning(
                'opset conversion skipped since onnx_opset_pedantic is OFF')
        onnx_model = polish_model(onnx_model)
    except ValidationError as e:
        if onnx_opset_pedantic:
@@ -90,13 +93,13 @@ def convert(onnx_model_filename, save_dir,
        onnx.save(model, debug_model_filename + '.optimized_and_inferred.onnx')
        # onnx.save(model, '/tmp/export/optimized_and_inferred.onnx')

    # I/O instances
    onnx_graph = onnx_model.graph
    fluid_program = Program()
    fluid_writer = Writer()

    # model components
    # graph_name = onnx_graph.name
    graph_inputs = [value.name for value in onnx_graph.input]
    graph_outputs = [value.name for value in onnx_graph.output]
    graph_params = []
@@ -107,29 +110,37 @@ def convert(onnx_model_filename, save_dir,
    for name, weight in graph_weights(onnx_graph):
        value_info = graph_value_infos[name]
        value_info['embeded_as'] = []
        value_info['get_weight'] = (lambda w: lambda: w.tolist())(
            weight)  # lazy getter
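        # NOTE: the immediately-invoked outer lambda binds the current
        # `weight`; a plain `lambda: weight.tolist()` would capture the loop
        # variable by reference, so every getter would return the last weight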
    logger.info('conversion started')
    # op set conversion
    # topo = 'backward' if embed_params else 'forward'
    topo = 'forward'
    for name, domain, op_type, inputs, outputs, attrs in graph_ops(
            onnx_graph, topo=topo):
        logger.debug('translating op %s %s::%s ...', name, domain, op_type)
        if domain == DEFAULT_OP_DOMAIN:
            domain = ''
        try:
            fluid_writer.emit_op(
                fluid_program,
                name,
                domain,
                op_type,
                inputs,
                outputs,
                attrs,
                graph_value_infos,
                embed_params=embed_params,
            )
        except BaseException as e:
            logger.fatal('conversion failed for:\n\t%s -> %s::%s -> %s', inputs,
                         domain, op_type, outputs)
            raise e
    op_codes = fluid_program.codes
    fluid_program.codes = []
    logger.info('%d ops converted', len(fluid_program.op_descs))

    # weight writer
    for name, weight in graph_weights(onnx_graph):
@@ -138,18 +149,24 @@ def convert(onnx_model_filename, save_dir,
        var_names = value_info.get('embeded_as', [])
        if var_names:
            if len(var_names) > 1:
                logger.info(
                    'weight %s is shared between ops, more disk space will be consumed',
                    name)
            logger.debug(
                'saving weight %s with size of %d, in %d bytes, as %s ...',
                name, weight.size, weight.nbytes, var_names)
            for var_name in var_names:  # multiple references
                fluid_writer.write_weight(
                    weight, shutil.os.path.join(save_dir, var_name))
        else:
            logger.debug(
                'saving weight %s with size of %d, in %d bytes, to %s ...',
                name, weight.size, weight.nbytes, make_var_name(name))
            fluid_writer.write_weight(
                weight, shutil.os.path.join(save_dir, make_var_name(name)))
        fluid_writer.emit_param(fluid_program, name, value_info)
    param_codes = fluid_program.codes
    fluid_program.codes = []
    logger.info('%d weights converted', len(graph_params))

    # input writer
@@ -159,9 +176,11 @@ def convert(onnx_model_filename, save_dir,
        value_info = graph_value_infos[name]
        assert value_info['external']
        external_inputs.append(name)
    fluid_writer.emit_inputs(
        fluid_program, external_inputs, graph_value_infos,
        remove_batch=False)  # TODO:
    input_codes = fluid_program.codes
    fluid_program.codes = []
    logger.info('%d inputs converted', len(external_inputs))

    # output writer
@@ -171,49 +190,93 @@ def convert(onnx_model_filename, save_dir,
        value_info = graph_value_infos[name]
        assert value_info['external']
        external_outputs.append(name)
    fluid_writer.emit_outputs(fluid_program, external_outputs)
    output_codes = [''] + fluid_program.codes  # add an empty line
    fluid_program.codes = []
    logger.info('%d outputs converted', len(external_outputs))

    # code generation
    header_codes = fluid_writer.header_code(
        model_func_name, 'From: {}'.format(onnx_model_filename))
    code_filename = shutil.os.path.join(save_dir, model_basename)
    fluid_writer.write_code_file(code_filename, header_codes, input_codes,
                                 param_codes, op_codes, output_codes)
    logger.info('code saved to %s, factory function: %s', code_filename,
                model_func_name)

    # desc generation
    desc_filename = shutil.os.path.join(save_dir, '__model__')
    fluid_writer.write_desc_file(
        desc_filename,
        op_descs=fluid_program.op_descs,
        var_descs=fluid_program.var_descs,
    )
    logger.info('program saved to %s', desc_filename)

    logger.info('conversion finished')

    # globals().update(locals())


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='onnx2fluid.convert',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        'model',
        nargs=1,
        help='path to model.onnx',
    )
    parser.add_argument(
        '--debug',
        '-d',
        action='store_true',
        help='enable debug logging and checking',
    )
    parser.add_argument(
        '--output_dir',
        '-o',
        type=str,
        default='',
        help='output directory',
    )
    parser.add_argument(
        '--embed_params',
        '-e',
        action='store_true',
        help='try to embed parameters for trainable Paddle fluid layers',
    )
    parser.add_argument(
        '--pedantic',
        action='store_true',
        default=True,
        help='accept and convert only standard ONNX opset',
    )
    parser.add_argument(
        '--no-pedantic',
        '-x',
        action='store_false',
        dest='pedantic',
        help='process non-standard ONNX ops, this may lead to failures',
    )
    args = parser.parse_args()

    logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(format=logging_format, level=logging_level)

    debug = args.debug
    model_filename = args.model[0]
    save_dir = args.output_dir
    embed_params = args.embed_params
    pedantic = args.pedantic

    convert(
        model_filename,
        save_dir,
        embed_params=embed_params,
        onnx_opset_pedantic=pedantic,
        debug=debug)

# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: framework.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1'))
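# _b keeps the serialized descriptor literals portable: identity on Python 2
# (where str is already bytes), latin-1 encoding to bytes on Python 3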
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='framework.proto',
package='paddle.framework.proto',
syntax='proto2',
serialized_pb=_b(
        '\n\x0f\x66ramework.proto\x12\x16paddle.framework.proto\"\x1d\n\x07Version\x12\x12\n\x07version\x18\x01 \x01(\x03:\x01\x30\"\xec\x03\n\x06OpDesc\x12\x0c\n\x04type\x18\x03 \x02(\t\x12\x32\n\x06inputs\x18\x01 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x33\n\x07outputs\x18\x02 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x32\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32#.paddle.framework.proto.OpDesc.Attr\x12\x18\n\tis_target\x18\x05 \x01(\x08:\x05\x66\x61lse\x1a\xef\x01\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\t\n\x01i\x18\x03 \x01(\x05\x12\t\n\x01\x66\x18\x04 \x01(\x02\x12\t\n\x01s\x18\x05 \x01(\t\x12\x0c\n\x04ints\x18\x06 \x03(\x05\x12\x0e\n\x06\x66loats\x18\x07 \x03(\x02\x12\x0f\n\x07strings\x18\x08 \x03(\t\x12\t\n\x01\x62\x18\n \x01(\x08\x12\r\n\x05\x62ools\x18\x0b \x03(\x08\x12\x11\n\tblock_idx\x18\x0c \x01(\x05\x12\t\n\x01l\x18\r \x01(\x03\x12\x12\n\nblocks_idx\x18\x0e \x03(\x05\x12\r\n\x05longs\x18\x0f \x03(\x03\x1a+\n\x03Var\x12\x11\n\tparameter\x18\x01 \x02(\t\x12\x11\n\targuments\x18\x02 \x03(\t\"\xb3\x03\n\x07OpProto\x12\x0c\n\x04type\x18\x01 \x02(\t\x12\x33\n\x06inputs\x18\x02 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x34\n\x07outputs\x18\x03 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x33\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32$.paddle.framework.proto.OpProto.Attr\x12\x0f\n\x07\x63omment\x18\x05 \x02(\t\x1ax\n\x03Var\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0f\n\x07\x63omment\x18\x02 \x02(\t\x12\x19\n\nduplicable\x18\x03 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0cintermediate\x18\x04 \x01(\x08:\x05\x66\x61lse\x12\x1a\n\x0b\x64ispensable\x18\x05 \x01(\x08:\x05\x66\x61lse\x1ao\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\x0f\n\x07\x63omment\x18\x03 \x02(\t\x12\x18\n\tgenerated\x18\x04 \x01(\x08:\x05\x66\x61lse\"\xda\x08\n\x07VarType\x12\x32\n\x04type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x41\n\rselected_rows\x18\x02 \x01(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x41\n\nlod_tensor\x18\x03 \x01(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x12H\n\x0ctensor_array\x18\x04 \x01(\x0b\x32\x32.paddle.framework.proto.VarType.LoDTensorArrayDesc\x12:\n\x06reader\x18\x05 \x01(\x0b\x32*.paddle.framework.proto.VarType.ReaderDesc\x12\x34\n\x05tuple\x18\x07 \x01(\x0b\x32%.paddle.framework.proto.VarType.Tuple\x1aS\n\nTensorDesc\x12\x37\n\tdata_type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x0c\n\x04\x64ims\x18\x02 \x03(\x03\x1a\x61\n\rLoDTensorDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1a\x66\n\x12LoDTensorArrayDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1aO\n\nReaderDesc\x12\x41\n\nlod_tensor\x18\x01 \x03(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x1a\x43\n\x05Tuple\x12:\n\x0c\x65lement_type\x18\x01 '
        '\x03(\x0e\x32$.paddle.framework.proto.VarType.Type\"\xa2\x02\n\x04Type\x12\x08\n\x04\x42OOL\x10\x00\x12\t\n\x05INT16\x10\x01\x12\t\n\x05INT32\x10\x02\x12\t\n\x05INT64\x10\x03\x12\x08\n\x04\x46P16\x10\x04\x12\x08\n\x04\x46P32\x10\x05\x12\x08\n\x04\x46P64\x10\x06\x12\n\n\x06SIZE_T\x10\x13\x12\t\n\x05UINT8\x10\x14\x12\x08\n\x04INT8\x10\x15\x12\x0e\n\nLOD_TENSOR\x10\x07\x12\x11\n\rSELECTED_ROWS\x10\x08\x12\x12\n\x0e\x46\x45\x45\x44_MINIBATCH\x10\t\x12\x0e\n\nFETCH_LIST\x10\n\x12\x0f\n\x0bSTEP_SCOPES\x10\x0b\x12\x12\n\x0eLOD_RANK_TABLE\x10\x0c\x12\x14\n\x10LOD_TENSOR_ARRAY\x10\r\x12\x0e\n\nPLACE_LIST\x10\x0e\x12\n\n\x06READER\x10\x0f\x12\x07\n\x03RAW\x10\x11\x12\t\n\x05TUPLE\x10\x12\"b\n\x07VarDesc\x12\x0c\n\x04name\x18\x01 \x02(\t\x12-\n\x04type\x18\x02 \x02(\x0b\x32\x1f.paddle.framework.proto.VarType\x12\x1a\n\x0bpersistable\x18\x03 \x01(\x08:\x05\x66\x61lse\"\xa7\x01\n\tBlockDesc\x12\x0b\n\x03idx\x18\x01 \x02(\x05\x12\x12\n\nparent_idx\x18\x02 \x02(\x05\x12-\n\x04vars\x18\x03 \x03(\x0b\x32\x1f.paddle.framework.proto.VarDesc\x12+\n\x03ops\x18\x04 \x03(\x0b\x32\x1e.paddle.framework.proto.OpDesc\x12\x1d\n\x11\x66orward_block_idx\x18\x05 \x01(\x05:\x02-1\"r\n\x0bProgramDesc\x12\x31\n\x06\x62locks\x18\x01 \x03(\x0b\x32!.paddle.framework.proto.BlockDesc\x12\x30\n\x07version\x18\x02 \x01(\x0b\x32\x1f.paddle.framework.proto.Version*\x94\x01\n\x08\x41ttrType\x12\x07\n\x03INT\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x08\n\x04INTS\x10\x03\x12\n\n\x06\x46LOATS\x10\x04\x12\x0b\n\x07STRINGS\x10\x05\x12\x0b\n\x07\x42OOLEAN\x10\x06\x12\x0c\n\x08\x42OOLEANS\x10\x07\x12\t\n\x05\x42LOCK\x10\x08\x12\x08\n\x04LONG\x10\t\x12\n\n\x06\x42LOCKS\x10\n\x12\t\n\x05LONGS\x10\x0b\x42\x02H\x03'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_ATTRTYPE = _descriptor.EnumDescriptor(
name='AttrType',
full_name='paddle.framework.proto.AttrType',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='INT', index=0, number=0, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FLOAT', index=1, number=1, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='STRING', index=2, number=2, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='INTS', index=3, number=3, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FLOATS', index=4, number=4, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='STRINGS', index=5, number=5, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='BOOLEAN', index=6, number=6, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='BOOLEANS', index=7, number=7, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='BLOCK', index=8, number=8, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='LONG', index=9, number=9, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='BLOCKS', index=10, number=10, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='LONGS', index=11, number=11, options=None, type=None),
],
containing_type=None,
options=None,
serialized_start=2511,
serialized_end=2659,
)
_sym_db.RegisterEnumDescriptor(_ATTRTYPE)
AttrType = enum_type_wrapper.EnumTypeWrapper(_ATTRTYPE)
INT = 0
FLOAT = 1
STRING = 2
INTS = 3
FLOATS = 4
STRINGS = 5
BOOLEAN = 6
BOOLEANS = 7
BLOCK = 8
LONG = 9
BLOCKS = 10
LONGS = 11
_VARTYPE_TYPE = _descriptor.EnumDescriptor(
name='Type',
full_name='paddle.framework.proto.VarType.Type',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='BOOL', index=0, number=0, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='INT16', index=1, number=1, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='INT32', index=2, number=2, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='INT64', index=3, number=3, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FP16', index=4, number=4, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FP32', index=5, number=5, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FP64', index=6, number=6, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='SIZE_T', index=7, number=19, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='UINT8', index=8, number=20, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='INT8', index=9, number=21, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='LOD_TENSOR', index=10, number=7, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='SELECTED_ROWS', index=11, number=8, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FEED_MINIBATCH', index=12, number=9, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FETCH_LIST', index=13, number=10, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='STEP_SCOPES', index=14, number=11, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='LOD_RANK_TABLE', index=15, number=12, options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LOD_TENSOR_ARRAY',
index=16,
number=13,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PLACE_LIST', index=17, number=14, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='READER', index=18, number=15, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='RAW', index=19, number=17, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='TUPLE', index=20, number=18, options=None, type=None),
],
containing_type=None,
options=None,
serialized_start=1832,
serialized_end=2122,
)
_sym_db.RegisterEnumDescriptor(_VARTYPE_TYPE)
_VERSION = _descriptor.Descriptor(
name='Version',
full_name='paddle.framework.proto.Version',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='version',
full_name='paddle.framework.proto.Version.version',
index=0,
number=1,
type=3,
cpp_type=2,
label=1,
has_default_value=True,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=43,
serialized_end=72,
)
_OPDESC_ATTR = _descriptor.Descriptor(
name='Attr',
full_name='paddle.framework.proto.OpDesc.Attr',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle.framework.proto.OpDesc.Attr.name',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.framework.proto.OpDesc.Attr.type',
index=1,
number=2,
type=14,
cpp_type=8,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='i',
full_name='paddle.framework.proto.OpDesc.Attr.i',
index=2,
number=3,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='f',
full_name='paddle.framework.proto.OpDesc.Attr.f',
index=3,
number=4,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='s',
full_name='paddle.framework.proto.OpDesc.Attr.s',
index=4,
number=5,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ints',
full_name='paddle.framework.proto.OpDesc.Attr.ints',
index=5,
number=6,
type=5,
cpp_type=1,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='floats',
full_name='paddle.framework.proto.OpDesc.Attr.floats',
index=6,
number=7,
type=2,
cpp_type=6,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='strings',
full_name='paddle.framework.proto.OpDesc.Attr.strings',
index=7,
number=8,
type=9,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='b',
full_name='paddle.framework.proto.OpDesc.Attr.b',
index=8,
number=10,
type=8,
cpp_type=7,
label=1,
has_default_value=False,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='bools',
full_name='paddle.framework.proto.OpDesc.Attr.bools',
index=9,
number=11,
type=8,
cpp_type=7,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='block_idx',
full_name='paddle.framework.proto.OpDesc.Attr.block_idx',
index=10,
number=12,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='l',
full_name='paddle.framework.proto.OpDesc.Attr.l',
index=11,
number=13,
type=3,
cpp_type=2,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='blocks_idx',
full_name='paddle.framework.proto.OpDesc.Attr.blocks_idx',
index=12,
number=14,
type=5,
cpp_type=1,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='longs',
full_name='paddle.framework.proto.OpDesc.Attr.longs',
index=13,
number=15,
type=3,
cpp_type=2,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=283,
serialized_end=522,
)
_OPDESC_VAR = _descriptor.Descriptor(
name='Var',
full_name='paddle.framework.proto.OpDesc.Var',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='parameter',
full_name='paddle.framework.proto.OpDesc.Var.parameter',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='arguments',
full_name='paddle.framework.proto.OpDesc.Var.arguments',
index=1,
number=2,
type=9,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=524,
serialized_end=567,
)
_OPDESC = _descriptor.Descriptor(
name='OpDesc',
full_name='paddle.framework.proto.OpDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.framework.proto.OpDesc.type',
index=0,
number=3,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='inputs',
full_name='paddle.framework.proto.OpDesc.inputs',
index=1,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='outputs',
full_name='paddle.framework.proto.OpDesc.outputs',
index=2,
number=2,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='attrs',
full_name='paddle.framework.proto.OpDesc.attrs',
index=3,
number=4,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='is_target',
full_name='paddle.framework.proto.OpDesc.is_target',
index=4,
number=5,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[
_OPDESC_ATTR,
_OPDESC_VAR,
],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=75,
serialized_end=567,
)
_OPPROTO_VAR = _descriptor.Descriptor(
name='Var',
full_name='paddle.framework.proto.OpProto.Var',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle.framework.proto.OpProto.Var.name',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment',
full_name='paddle.framework.proto.OpProto.Var.comment',
index=1,
number=2,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='duplicable',
full_name='paddle.framework.proto.OpProto.Var.duplicable',
index=2,
number=3,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='intermediate',
full_name='paddle.framework.proto.OpProto.Var.intermediate',
index=3,
number=4,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dispensable',
full_name='paddle.framework.proto.OpProto.Var.dispensable',
index=4,
number=5,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=772,
serialized_end=892,
)
_OPPROTO_ATTR = _descriptor.Descriptor(
name='Attr',
full_name='paddle.framework.proto.OpProto.Attr',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle.framework.proto.OpProto.Attr.name',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.framework.proto.OpProto.Attr.type',
index=1,
number=2,
type=14,
cpp_type=8,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment',
full_name='paddle.framework.proto.OpProto.Attr.comment',
index=2,
number=3,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='generated',
full_name='paddle.framework.proto.OpProto.Attr.generated',
index=3,
number=4,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=894,
serialized_end=1005,
)
_OPPROTO = _descriptor.Descriptor(
name='OpProto',
full_name='paddle.framework.proto.OpProto',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.framework.proto.OpProto.type',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='inputs',
full_name='paddle.framework.proto.OpProto.inputs',
index=1,
number=2,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='outputs',
full_name='paddle.framework.proto.OpProto.outputs',
index=2,
number=3,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='attrs',
full_name='paddle.framework.proto.OpProto.attrs',
index=3,
number=4,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment',
full_name='paddle.framework.proto.OpProto.comment',
index=4,
number=5,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[
_OPPROTO_VAR,
_OPPROTO_ATTR,
],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=570,
serialized_end=1005,
)
_VARTYPE_TENSORDESC = _descriptor.Descriptor(
name='TensorDesc',
full_name='paddle.framework.proto.VarType.TensorDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='data_type',
full_name='paddle.framework.proto.VarType.TensorDesc.data_type',
index=0,
number=1,
type=14,
cpp_type=8,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dims',
full_name='paddle.framework.proto.VarType.TensorDesc.dims',
index=1,
number=2,
type=3,
cpp_type=2,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1393,
serialized_end=1476,
)
_VARTYPE_LODTENSORDESC = _descriptor.Descriptor(
name='LoDTensorDesc',
full_name='paddle.framework.proto.VarType.LoDTensorDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor',
full_name='paddle.framework.proto.VarType.LoDTensorDesc.tensor',
index=0,
number=1,
type=11,
cpp_type=10,
label=2,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_level',
full_name='paddle.framework.proto.VarType.LoDTensorDesc.lod_level',
index=1,
number=2,
type=5,
cpp_type=1,
label=1,
has_default_value=True,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1478,
serialized_end=1575,
)
_VARTYPE_LODTENSORARRAYDESC = _descriptor.Descriptor(
name='LoDTensorArrayDesc',
full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor',
full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc.tensor',
index=0,
number=1,
type=11,
cpp_type=10,
label=2,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_level',
full_name=
'paddle.framework.proto.VarType.LoDTensorArrayDesc.lod_level',
index=1,
number=2,
type=5,
cpp_type=1,
label=1,
has_default_value=True,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1577,
serialized_end=1679,
)
_VARTYPE_READERDESC = _descriptor.Descriptor(
name='ReaderDesc',
full_name='paddle.framework.proto.VarType.ReaderDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='lod_tensor',
full_name='paddle.framework.proto.VarType.ReaderDesc.lod_tensor',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1681,
serialized_end=1760,
)
_VARTYPE_TUPLE = _descriptor.Descriptor(
name='Tuple',
full_name='paddle.framework.proto.VarType.Tuple',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='element_type',
full_name='paddle.framework.proto.VarType.Tuple.element_type',
index=0,
number=1,
type=14,
cpp_type=8,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1762,
serialized_end=1829,
)
_VARTYPE = _descriptor.Descriptor(
name='VarType',
full_name='paddle.framework.proto.VarType',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.framework.proto.VarType.type',
index=0,
number=1,
type=14,
cpp_type=8,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='selected_rows',
full_name='paddle.framework.proto.VarType.selected_rows',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_tensor',
full_name='paddle.framework.proto.VarType.lod_tensor',
index=2,
number=3,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='tensor_array',
full_name='paddle.framework.proto.VarType.tensor_array',
index=3,
number=4,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='reader',
full_name='paddle.framework.proto.VarType.reader',
index=4,
number=5,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='tuple',
full_name='paddle.framework.proto.VarType.tuple',
index=5,
number=7,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[
_VARTYPE_TENSORDESC,
_VARTYPE_LODTENSORDESC,
_VARTYPE_LODTENSORARRAYDESC,
_VARTYPE_READERDESC,
_VARTYPE_TUPLE,
],
enum_types=[
_VARTYPE_TYPE,
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1008,
serialized_end=2122,
)
_VARDESC = _descriptor.Descriptor(
name='VarDesc',
full_name='paddle.framework.proto.VarDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle.framework.proto.VarDesc.name',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.framework.proto.VarDesc.type',
index=1,
number=2,
type=11,
cpp_type=10,
label=2,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='persistable',
full_name='paddle.framework.proto.VarDesc.persistable',
index=2,
number=3,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2124,
serialized_end=2222,
)
_BLOCKDESC = _descriptor.Descriptor(
name='BlockDesc',
full_name='paddle.framework.proto.BlockDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='idx',
full_name='paddle.framework.proto.BlockDesc.idx',
index=0,
number=1,
type=5,
cpp_type=1,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='parent_idx',
full_name='paddle.framework.proto.BlockDesc.parent_idx',
index=1,
number=2,
type=5,
cpp_type=1,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='vars',
full_name='paddle.framework.proto.BlockDesc.vars',
index=2,
number=3,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ops',
full_name='paddle.framework.proto.BlockDesc.ops',
index=3,
number=4,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='forward_block_idx',
full_name='paddle.framework.proto.BlockDesc.forward_block_idx',
index=4,
number=5,
type=5,
cpp_type=1,
label=1,
has_default_value=True,
default_value=-1,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2225,
serialized_end=2392,
)
_PROGRAMDESC = _descriptor.Descriptor(
name='ProgramDesc',
full_name='paddle.framework.proto.ProgramDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='blocks',
full_name='paddle.framework.proto.ProgramDesc.blocks',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='version',
full_name='paddle.framework.proto.ProgramDesc.version',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2394,
serialized_end=2508,
)
_OPDESC_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE
_OPDESC_ATTR.containing_type = _OPDESC
_OPDESC_VAR.containing_type = _OPDESC
_OPDESC.fields_by_name['inputs'].message_type = _OPDESC_VAR
_OPDESC.fields_by_name['outputs'].message_type = _OPDESC_VAR
_OPDESC.fields_by_name['attrs'].message_type = _OPDESC_ATTR
_OPPROTO_VAR.containing_type = _OPPROTO
_OPPROTO_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE
_OPPROTO_ATTR.containing_type = _OPPROTO
_OPPROTO.fields_by_name['inputs'].message_type = _OPPROTO_VAR
_OPPROTO.fields_by_name['outputs'].message_type = _OPPROTO_VAR
_OPPROTO.fields_by_name['attrs'].message_type = _OPPROTO_ATTR
_VARTYPE_TENSORDESC.fields_by_name['data_type'].enum_type = _VARTYPE_TYPE
_VARTYPE_TENSORDESC.containing_type = _VARTYPE
_VARTYPE_LODTENSORDESC.fields_by_name[
'tensor'].message_type = _VARTYPE_TENSORDESC
_VARTYPE_LODTENSORDESC.containing_type = _VARTYPE
_VARTYPE_LODTENSORARRAYDESC.fields_by_name[
'tensor'].message_type = _VARTYPE_TENSORDESC
_VARTYPE_LODTENSORARRAYDESC.containing_type = _VARTYPE
_VARTYPE_READERDESC.fields_by_name[
'lod_tensor'].message_type = _VARTYPE_LODTENSORDESC
_VARTYPE_READERDESC.containing_type = _VARTYPE
_VARTYPE_TUPLE.fields_by_name['element_type'].enum_type = _VARTYPE_TYPE
_VARTYPE_TUPLE.containing_type = _VARTYPE
_VARTYPE.fields_by_name['type'].enum_type = _VARTYPE_TYPE
_VARTYPE.fields_by_name['selected_rows'].message_type = _VARTYPE_TENSORDESC
_VARTYPE.fields_by_name['lod_tensor'].message_type = _VARTYPE_LODTENSORDESC
_VARTYPE.fields_by_name[
'tensor_array'].message_type = _VARTYPE_LODTENSORARRAYDESC
_VARTYPE.fields_by_name['reader'].message_type = _VARTYPE_READERDESC
_VARTYPE.fields_by_name['tuple'].message_type = _VARTYPE_TUPLE
_VARTYPE_TYPE.containing_type = _VARTYPE
_VARDESC.fields_by_name['type'].message_type = _VARTYPE
_BLOCKDESC.fields_by_name['vars'].message_type = _VARDESC
_BLOCKDESC.fields_by_name['ops'].message_type = _OPDESC
_PROGRAMDESC.fields_by_name['blocks'].message_type = _BLOCKDESC
_PROGRAMDESC.fields_by_name['version'].message_type = _VERSION
DESCRIPTOR.message_types_by_name['Version'] = _VERSION
DESCRIPTOR.message_types_by_name['OpDesc'] = _OPDESC
DESCRIPTOR.message_types_by_name['OpProto'] = _OPPROTO
DESCRIPTOR.message_types_by_name['VarType'] = _VARTYPE
DESCRIPTOR.message_types_by_name['VarDesc'] = _VARDESC
DESCRIPTOR.message_types_by_name['BlockDesc'] = _BLOCKDESC
DESCRIPTOR.message_types_by_name['ProgramDesc'] = _PROGRAMDESC
DESCRIPTOR.enum_types_by_name['AttrType'] = _ATTRTYPE
Version = _reflection.GeneratedProtocolMessageType(
'Version',
(_message.Message, ),
dict(
DESCRIPTOR=_VERSION,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.Version)
))
_sym_db.RegisterMessage(Version)
OpDesc = _reflection.GeneratedProtocolMessageType(
'OpDesc',
(_message.Message, ),
dict(
Attr=_reflection.GeneratedProtocolMessageType(
'Attr',
(_message.Message, ),
dict(
DESCRIPTOR=_OPDESC_ATTR,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Attr)
)),
Var=_reflection.GeneratedProtocolMessageType(
'Var',
(_message.Message, ),
dict(
DESCRIPTOR=_OPDESC_VAR,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Var)
)),
DESCRIPTOR=_OPDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc)
))
_sym_db.RegisterMessage(OpDesc)
_sym_db.RegisterMessage(OpDesc.Attr)
_sym_db.RegisterMessage(OpDesc.Var)
OpProto = _reflection.GeneratedProtocolMessageType(
'OpProto',
(_message.Message, ),
dict(
Var=_reflection.GeneratedProtocolMessageType(
'Var',
(_message.Message, ),
dict(
DESCRIPTOR=_OPPROTO_VAR,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Var)
)),
Attr=_reflection.GeneratedProtocolMessageType(
'Attr',
(_message.Message, ),
dict(
DESCRIPTOR=_OPPROTO_ATTR,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Attr)
)),
DESCRIPTOR=_OPPROTO,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto)
))
_sym_db.RegisterMessage(OpProto)
_sym_db.RegisterMessage(OpProto.Var)
_sym_db.RegisterMessage(OpProto.Attr)
VarType = _reflection.GeneratedProtocolMessageType(
'VarType',
(_message.Message, ),
dict(
TensorDesc=_reflection.GeneratedProtocolMessageType(
'TensorDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_VARTYPE_TENSORDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.TensorDesc)
)),
LoDTensorDesc=_reflection.GeneratedProtocolMessageType(
'LoDTensorDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_VARTYPE_LODTENSORDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorDesc)
)),
LoDTensorArrayDesc=_reflection.GeneratedProtocolMessageType(
'LoDTensorArrayDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_VARTYPE_LODTENSORARRAYDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorArrayDesc)
)),
ReaderDesc=_reflection.GeneratedProtocolMessageType(
'ReaderDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_VARTYPE_READERDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.ReaderDesc)
)),
Tuple=_reflection.GeneratedProtocolMessageType(
'Tuple',
(_message.Message, ),
dict(
DESCRIPTOR=_VARTYPE_TUPLE,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.Tuple)
)),
DESCRIPTOR=_VARTYPE,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType)
))
_sym_db.RegisterMessage(VarType)
_sym_db.RegisterMessage(VarType.TensorDesc)
_sym_db.RegisterMessage(VarType.LoDTensorDesc)
_sym_db.RegisterMessage(VarType.LoDTensorArrayDesc)
_sym_db.RegisterMessage(VarType.ReaderDesc)
_sym_db.RegisterMessage(VarType.Tuple)
VarDesc = _reflection.GeneratedProtocolMessageType(
'VarDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_VARDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc)
))
_sym_db.RegisterMessage(VarDesc)
BlockDesc = _reflection.GeneratedProtocolMessageType(
'BlockDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_BLOCKDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.BlockDesc)
))
_sym_db.RegisterMessage(BlockDesc)
ProgramDesc = _reflection.GeneratedProtocolMessageType(
'ProgramDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_PROGRAMDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.ProgramDesc)
))
_sym_db.RegisterMessage(ProgramDesc)
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(),
_b('H\003'))
# @@protoc_insertion_point(module_scope)
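Editor's note: the generated module above is ordinary protobuf plumbing; a minimal usage sketch follows, assuming the protobuf runtime is installed, the module is importable as framework_pb2, and the enum value names (LOD_TENSOR, FP32) from Paddle's framework.proto.

import framework_pb2

prog = framework_pb2.ProgramDesc()
block = prog.blocks.add()            # BlockDesc
block.idx = 0
block.parent_idx = -1                # root block has no parent

var = block.vars.add()               # VarDesc
var.name = 'image'
var.type.type = framework_pb2.VarType.LOD_TENSOR
tensor = var.type.lod_tensor.tensor  # VarType.TensorDesc
tensor.data_type = framework_pb2.VarType.FP32
tensor.dims.extend([-1, 3, 224, 224])  # -1 marks a dynamic batch dim

binary = prog.SerializeToString()    # bytes in the format fluid loads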
...@@ -18,28 +18,30 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE ...@@ -18,28 +18,30 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from onnx.numpy_helper import to_array from onnx.numpy_helper import to_array
from onnx.shape_inference import infer_shapes from onnx.shape_inference import infer_shapes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
__all__ = [ __all__ = [
'print_pb_structure', 'print_pb_structure',
'build_value_refs', 'build_value_refs',
'node_attrs', 'node_topo', 'node_iter', 'node_attrs',
'node_topo',
'node_iter',
'tensor_shape', 'tensor_shape',
'graph_ops', 'graph_weights', 'graph_ops',
'graph_weights',
'inferred_model_value_info', 'inferred_model_value_info',
'optimize_model_skip_op_for_inference', 'optimize_model_skip_op_for_inference',
'optimize_model_strip_initializer', 'optimize_model_strip_initializer',
'optimize_model_cast', 'optimize_model_slice', 'optimize_model_cast',
'optimize_model_slice',
] ]
ONNX_INT_MAX = 2 ** 63 - 1 ONNX_INT_MAX = 2**63 - 1
DEFAULT_OP_DOMAIN = 'ai.onnx' DEFAULT_OP_DOMAIN = 'ai.onnx'
def print_pb_structure(message, def print_pb_structure(message, loop_iterative=False, depth=0):
loop_iterative=False, depth=0):
""" """
    print pb fields in a tree structure     print pb fields in a tree structure
""" """
...@@ -47,14 +49,17 @@ def print_pb_structure(message, ...@@ -47,14 +49,17 @@ def print_pb_structure(message,
if hasattr(message, 'DESCRIPTOR') and hasattr(message.DESCRIPTOR, 'fields'): if hasattr(message, 'DESCRIPTOR') and hasattr(message.DESCRIPTOR, 'fields'):
for field in message.DESCRIPTOR.fields: for field in message.DESCRIPTOR.fields:
print('\t' * depth + '-', field.name) print('\t' * depth + '-', field.name)
print_pb_structure(getattr(message, field.name), print_pb_structure(
loop_iterative=loop_iterative, depth=(depth + 1)) getattr(message, field.name),
loop_iterative=loop_iterative,
depth=(depth + 1))
if loop_iterative and hasattr(message, 'MergeFrom') and hasattr(message, '__len__'): if loop_iterative and hasattr(message, 'MergeFrom') and hasattr(
message, '__len__'):
for idx, item in enumerate(message): for idx, item in enumerate(message):
print('\t' * depth + '-', idx) print('\t' * depth + '-', idx)
print_pb_structure(item, print_pb_structure(
loop_iterative=loop_iterative, depth=(depth + 1)) item, loop_iterative=loop_iterative, depth=(depth + 1))
def build_value_refs(nodes): def build_value_refs(nodes):
...@@ -80,7 +85,8 @@ def get_attribute_value2(attr): ...@@ -80,7 +85,8 @@ def get_attribute_value2(attr):
if attr.type == onnx.AttributeProto.TENSOR: if attr.type == onnx.AttributeProto.TENSOR:
dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type]) dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type])
data = attr.t.raw_data data = attr.t.raw_data
value = np.frombuffer(data, dtype=dtype, count=(len(data) // dtype.itemsize)) value = np.frombuffer(
data, dtype=dtype, count=(len(data) // dtype.itemsize))
else: else:
value = get_attribute_value(attr) value = get_attribute_value(attr)
return value return value
...@@ -91,7 +97,8 @@ def node_attrs(node): ...@@ -91,7 +97,8 @@ def node_attrs(node):
convert ONNX node attributes to dict convert ONNX node attributes to dict
""" """
return {attr.name: get_attribute_value2(attr) for attr in node.attribute} # dict return {attr.name: get_attribute_value2(attr)
for attr in node.attribute} # dict
def tensor_shape(tensor): def tensor_shape(tensor):
...@@ -168,8 +175,7 @@ def node_topo(nodes, topo='default'): ...@@ -168,8 +175,7 @@ def node_topo(nodes, topo='default'):
        raise ValueError('unknown topo: {}'.format(topo))         raise ValueError('unknown topo: {}'.format(topo))
def node_iter(nodes, def node_iter(nodes, indices=None):
indices=None):
""" """
generator for ONNX node graph with given indices generator for ONNX node graph with given indices
""" """
...@@ -194,8 +200,7 @@ def node_iter(nodes, ...@@ -194,8 +200,7 @@ def node_iter(nodes,
yield name, domain, op_type, inputs, outputs, attrs yield name, domain, op_type, inputs, outputs, attrs
def graph_ops(graph, def graph_ops(graph, topo='default'):
topo='default'):
""" """
generator for ONNX node graph with given topology generator for ONNX node graph with given topology
""" """
...@@ -244,7 +249,7 @@ def inferred_model_value_info(model): ...@@ -244,7 +249,7 @@ def inferred_model_value_info(model):
external=True, external=True,
) )
for item in graph.output: for item in graph.output:
# assert item.name not in value_info, 'bypass-model not supported' # assert item.name not in value_info, 'bypass-model not supported'
value_info[item.name] = dict( value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type], dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
shape=tensor_shape(item), shape=tensor_shape(item),
...@@ -283,9 +288,7 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs): ...@@ -283,9 +288,7 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
return processed return processed
def optimize_model_skip_op_for_inference( def optimize_model_skip_op_for_inference(model, op_list=None):
model,
op_list=None):
""" """
    skip ops that can be bypassed for inference     skip ops that can be bypassed for inference
""" """
...@@ -297,21 +300,23 @@ def optimize_model_skip_op_for_inference( ...@@ -297,21 +300,23 @@ def optimize_model_skip_op_for_inference(
ret = type(model)() ret = type(model)()
ret.CopyFrom(model) ret.CopyFrom(model)
ret.graph.ClearField('value_info') # WORKAROUND: onnx do not drop old value_info ret.graph.ClearField(
'value_info') # WORKAROUND: onnx do not drop old value_info
ret_nodes = ret.graph.node ret_nodes = ret.graph.node
nodes_to_remove = [] nodes_to_remove = []
for node_idx, node in enumerate(nodes): for node_idx, node in enumerate(nodes):
if not(node.domain == DEFAULT_OP_DOMAIN or node.domain == ''): if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
continue continue
op_type = node.op_type op_type = node.op_type
if not(op_type in op_list): if not (op_type in op_list):
continue continue
if op_type in ['Dropout']: if op_type in ['Dropout']:
input_name = node.input[0] input_name = node.input[0]
output_name = node.output[0] output_name = node.output[0]
elif not(len(node.input) == 1 and len(node.output) == 1): elif not (len(node.input) == 1 and len(node.output) == 1):
logger.warning('currently only 1-input-1-output op supported, skip required %d: %s', logger.warning(
'currently only 1-input-1-output op supported, skip required %d: %s',
node_idx, node.op_type) node_idx, node.op_type)
continue continue
else: else:
...@@ -319,16 +324,18 @@ def optimize_model_skip_op_for_inference( ...@@ -319,16 +324,18 @@ def optimize_model_skip_op_for_inference(
output_name = node.output[0] output_name = node.output[0]
if output_name in input_refs: if output_name in input_refs:
processed = skip_node_forward(ret_nodes, output_name, input_name, input_refs) processed = skip_node_forward(ret_nodes, output_name, input_name,
input_refs)
elif input_name in output_refs: elif input_name in output_refs:
processed = skip_node_backward(ret_nodes, input_name, output_name, output_refs) processed = skip_node_backward(ret_nodes, input_name, output_name,
output_refs)
else: else:
processed = -1 processed = -1
if processed > 0: if processed > 0:
nodes_to_remove.append(node_idx) nodes_to_remove.append(node_idx)
logger.debug('skip op %d: %s -> %s -> %s', logger.debug('skip op %d: %s -> %s -> %s', node_idx, input_name,
node_idx, input_name, node.op_type, output_name) node.op_type, output_name)
elif processed == 0: elif processed == 0:
logger.warning('weird, no node processed') logger.warning('weird, no node processed')
else: else:
...@@ -342,8 +349,7 @@ def optimize_model_skip_op_for_inference( ...@@ -342,8 +349,7 @@ def optimize_model_skip_op_for_inference(
return ret return ret
def optimize_model_strip_initializer(model, def optimize_model_strip_initializer(model, keep_input_only=True):
keep_input_only=True):
""" """
strip weights for inference strip weights for inference
""" """
...@@ -354,7 +360,8 @@ def optimize_model_strip_initializer(model, ...@@ -354,7 +360,8 @@ def optimize_model_strip_initializer(model,
ret = type(model)() ret = type(model)()
ret.CopyFrom(model) ret.CopyFrom(model)
ret.graph.ClearField('value_info') # WORKAROUND: onnx do not drop old value_info ret.graph.ClearField(
'value_info') # WORKAROUND: onnx do not drop old value_info
# strip initializers # strip initializers
ret.graph.ClearField('initializer') ret.graph.ClearField('initializer')
...@@ -366,8 +373,7 @@ def optimize_model_strip_initializer(model, ...@@ -366,8 +373,7 @@ def optimize_model_strip_initializer(model,
elif not keep_input_only and name in output_refs: elif not keep_input_only and name in output_refs:
ret_initializers.add().CopyFrom(initializer) ret_initializers.add().CopyFrom(initializer)
else: else:
logger.debug('initializer %s(%s[%d]) stripped', logger.debug('initializer %s(%s[%d]) stripped', name,
name,
TENSOR_TYPE_TO_NP_TYPE[initializer.data_type], TENSOR_TYPE_TO_NP_TYPE[initializer.data_type],
len(initializer.raw_data)) len(initializer.raw_data))
...@@ -379,8 +385,8 @@ def optimize_model_strip_initializer(model, ...@@ -379,8 +385,8 @@ def optimize_model_strip_initializer(model,
if name in input_refs or name in out_names: if name in input_refs or name in out_names:
ret_inputs.add().CopyFrom(item) ret_inputs.add().CopyFrom(item)
else: else:
logger.debug('input %s(%s%s) stripped', logger.debug(
name, 'input %s(%s%s) stripped', name,
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type], TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
tensor_shape(item)) tensor_shape(item))
return ret return ret
...@@ -397,13 +403,14 @@ def optimize_model_cast(model): ...@@ -397,13 +403,14 @@ def optimize_model_cast(model):
ret = type(model)() ret = type(model)()
ret.CopyFrom(model) ret.CopyFrom(model)
ret.graph.ClearField('value_info') # WORKAROUND: onnx do not drop old value_info ret.graph.ClearField(
'value_info') # WORKAROUND: onnx do not drop old value_info
ret_nodes = ret.graph.node ret_nodes = ret.graph.node
nodes_to_remove = [] nodes_to_remove = []
for node_idx, node in enumerate(nodes): for node_idx, node in enumerate(nodes):
if not(node.domain == DEFAULT_OP_DOMAIN or node.domain == ''): if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
continue continue
if not(node.op_type == 'Cast'): if not (node.op_type == 'Cast'):
continue continue
attrs = node_attrs(node) attrs = node_attrs(node)
output_dtype = TENSOR_TYPE_TO_NP_TYPE[attrs['to']] output_dtype = TENSOR_TYPE_TO_NP_TYPE[attrs['to']]
...@@ -417,21 +424,23 @@ def optimize_model_cast(model): ...@@ -417,21 +424,23 @@ def optimize_model_cast(model):
output_name = node.output[0] output_name = node.output[0]
if output_name in input_refs: if output_name in input_refs:
processed = skip_node_forward(ret_nodes, output_name, input_name, input_refs) processed = skip_node_forward(ret_nodes, output_name, input_name,
input_refs)
elif input_name in output_refs: elif input_name in output_refs:
processed = skip_node_backward(ret_nodes, input_name, output_name, output_refs) processed = skip_node_backward(ret_nodes, input_name, output_name,
output_refs)
else: else:
processed = -1 processed = -1
if processed > 0: if processed > 0:
nodes_to_remove.append(node_idx) nodes_to_remove.append(node_idx)
logger.debug('skip %s: %s -> %s Cast op', logger.debug('skip %s: %s -> %s Cast op', node.name, input_dtype,
node.name, input_dtype, output_dtype) output_dtype)
elif processed == 0: elif processed == 0:
logger.warning('weird, no node processed') logger.warning('weird, no node processed')
else: else:
logger.debug('keep standalone %s: %s -> %s Cast op', logger.debug('keep standalone %s: %s -> %s Cast op', node.name,
node.name, input_dtype, output_dtype) input_dtype, output_dtype)
nodes_to_remove.sort(reverse=True) nodes_to_remove.sort(reverse=True)
for node_idx in nodes_to_remove: for node_idx in nodes_to_remove:
...@@ -452,13 +461,14 @@ def optimize_model_slice(model): ...@@ -452,13 +461,14 @@ def optimize_model_slice(model):
chain = [] chain = []
while True: while True:
node = nodes[node_idx] node = nodes[node_idx]
if not(node.domain == DEFAULT_OP_DOMAIN or node.domain == ''): if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
return chain return chain
if not node.op_type == 'Slice': if not node.op_type == 'Slice':
return chain return chain
chain.append(node_idx) chain.append(node_idx)
output_name = node.output[0] output_name = node.output[0]
if output_name not in input_refs or len(input_refs[output_name]) != 1: if output_name not in input_refs or len(
input_refs[output_name]) != 1:
return chain return chain
node_idx = list(input_refs[output_name])[0] node_idx = list(input_refs[output_name])[0]
...@@ -468,7 +478,8 @@ def optimize_model_slice(model): ...@@ -468,7 +478,8 @@ def optimize_model_slice(model):
for slice_node_idx in slice_chain: for slice_node_idx in slice_chain:
node = nodes[slice_node_idx] node = nodes[slice_node_idx]
attrs = node_attrs(node) attrs = node_attrs(node)
for axis, start, end in zip(attrs['axes'], attrs['starts'], attrs['ends']): for axis, start, end in zip(attrs['axes'], attrs['starts'],
attrs['ends']):
if start == 0 and end == ONNX_INT_MAX: if start == 0 and end == ONNX_INT_MAX:
continue continue
if axis in merged_slice: if axis in merged_slice:
...@@ -480,7 +491,8 @@ def optimize_model_slice(model): ...@@ -480,7 +491,8 @@ def optimize_model_slice(model):
ret = type(model)() ret = type(model)()
ret.CopyFrom(model) ret.CopyFrom(model)
ret.graph.ClearField('value_info') # WORKAROUND: onnx do not drop old value_info ret.graph.ClearField(
'value_info') # WORKAROUND: onnx do not drop old value_info
ret_nodes = ret.graph.node ret_nodes = ret.graph.node
nodes_to_remove = [] nodes_to_remove = []
for node_idx in range(len(nodes)): for node_idx in range(len(nodes)):
...@@ -502,40 +514,48 @@ def optimize_model_slice(model): ...@@ -502,40 +514,48 @@ def optimize_model_slice(model):
output_name = last_node.output[0] output_name = last_node.output[0]
processed = -1 processed = -1
if output_name in input_refs: # 0, [1...] if output_name in input_refs: # 0, [1...]
new_input_name = first_node.output[0] if len(merged_slice) > 0 else input_name new_input_name = first_node.output[0] if len(
processed = skip_node_forward(ret_nodes, output_name, new_input_name, input_refs) merged_slice) > 0 else input_name
processed = skip_node_forward(ret_nodes, output_name,
new_input_name, input_refs)
if processed > 0: if processed > 0:
if len(merged_slice) > 0: if len(merged_slice) > 0:
remain_idx = slice_chain[0] remain_idx = slice_chain[0]
remove_chain = slice_chain[1:] remove_chain = slice_chain[1:]
slice_node = ret_nodes[remain_idx] slice_node = ret_nodes[remain_idx]
for attr in slice_node.attribute: for attr in slice_node.attribute:
attr.CopyFrom(make_attribute(attr.name, attrs[attr.name])) attr.CopyFrom(
make_attribute(attr.name, attrs[attr.name]))
logger.debug('merged slice chain %s -> %s%s -> %s', logger.debug('merged slice chain %s -> %s%s -> %s',
input_name, remain_idx, remove_chain, output_name) input_name, remain_idx, remove_chain,
output_name)
else: else:
remove_chain = slice_chain remove_chain = slice_chain
if processed < 0 and input_name in output_refs: if processed < 0 and input_name in output_refs:
new_output_name = last_node.input[0] if len(merged_slice) > 0 else output_name new_output_name = last_node.input[0] if len(
processed = skip_node_backward(ret_nodes, input_name, new_output_name, output_refs) merged_slice) > 0 else output_name
processed = skip_node_backward(ret_nodes, input_name,
new_output_name, output_refs)
if processed > 0: if processed > 0:
if len(merged_slice) > 0: if len(merged_slice) > 0:
remain_idx = slice_chain[-1] remain_idx = slice_chain[-1]
remove_chain = slice_chain[:-1] remove_chain = slice_chain[:-1]
slice_node = ret_nodes[remain_idx] slice_node = ret_nodes[remain_idx]
for attr in slice_node.attribute: for attr in slice_node.attribute:
attr.CopyFrom(make_attribute(attr.name, attrs[attr.name])) attr.CopyFrom(
make_attribute(attr.name, attrs[attr.name]))
logger.debug('merged slice chain %s -> %s%s -> %s', logger.debug('merged slice chain %s -> %s%s -> %s',
input_name, remove_chain, remain_idx, output_name) input_name, remove_chain, remain_idx,
output_name)
else: else:
remove_chain = slice_chain remove_chain = slice_chain
if processed > 0: if processed > 0:
nodes_to_remove.extend(remove_chain) nodes_to_remove.extend(remove_chain)
if len(merged_slice) == 0: if len(merged_slice) == 0:
logger.debug('skip slice chain %s -> %s -> %s', logger.debug('skip slice chain %s -> %s -> %s', input_name,
input_name, slice_chain, output_name) slice_chain, output_name)
        elif processed < 0:  # NEVERFIX: do not merge standalone slice chain         elif processed < 0:  # NEVERFIX: do not merge standalone slice chain
logger.debug('keep standalone slice chain %s -> %s -> %s', logger.debug('keep standalone slice chain %s -> %s -> %s',
input_name, slice_chain, output_name) input_name, slice_chain, output_name)
...@@ -549,7 +569,8 @@ def optimize_model_slice(model): ...@@ -549,7 +569,8 @@ def optimize_model_slice(model):
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig( logging.basicConfig(
format='[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s', format=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
level=logging.DEBUG, level=logging.DEBUG,
) )
......
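Editor's note: the optimization passes above compose into a simple preprocessing pipeline. A hedged sketch of how they might be chained (the file path and the print are placeholders; the helpers are the ones defined in this file):

import onnx

model = onnx.load('model.onnx')  # placeholder path
model = optimize_model_skip_op_for_inference(model)  # bypass e.g. Dropout
model = optimize_model_strip_initializer(model)      # drop unreferenced weights/inputs
model = optimize_model_cast(model)                   # elide same-dtype Cast ops
model = optimize_model_slice(model)                  # merge/elide Slice chains
for name, domain, op_type, inputs, outputs, attrs in graph_ops(model.graph):
    print(op_type, list(inputs), '->', list(outputs))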
#!/usr/bin/env python3 #!/usr/bin/env python3
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
ONNX to Paddle symbolic translation ONNX to Paddle fluid symbolic translation
Created on Mon Feb 25 09:33:43 2019 Created on Mon Feb 25 09:33:43 2019
...@@ -18,20 +18,23 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE ...@@ -18,20 +18,23 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
_logger = _logging.getLogger(__name__) _logger = _logging.getLogger(__name__)
ONNX_INT_MAX = 2**63 - 1
ONNX_INT_MAX = 2 ** 63 - 1 FLUID_INT_MAX = 2**31 - 1
DEFAULT_ONNX_OP_DOMAIN = '' DEFAULT_ONNX_OP_DOMAIN = ''
DEFAULT_PADDLE_OP_NAMESCOPE = '/' DEFAULT_FLUID_OP_NAMESCOPE = '/'
DEFAULT_OP_MAPPING_FIELD_VALUES = _dict() DEFAULT_OP_MAPPING_FIELD_VALUES = _dict()
DEFAULT_OP_MAPPING_FIELD_VALUES['PADDLE_OP'] = '' DEFAULT_OP_MAPPING_FIELD_VALUES['FLUID_OP'] = ''
DEFAULT_OP_MAPPING_FIELD_VALUES['PADDLE_INPUT_ARGS'] = None DEFAULT_OP_MAPPING_FIELD_VALUES['FLUID_INPUT_ARGS'] = None
DEFAULT_OP_MAPPING_FIELD_VALUES['PADDLE_OUTPUT_ARGS'] = None DEFAULT_OP_MAPPING_FIELD_VALUES['FLUID_OUTPUT_ARGS'] = None
DEFAULT_OP_MAPPING_FIELD_VALUES['ATTR_MAPPING'] = dict() # dict(onnx_attr_from=paddle_attr_to) DEFAULT_OP_MAPPING_FIELD_VALUES['ATTR_MAPPING'] = dict(
DEFAULT_OP_MAPPING_FIELD_VALUES['DEFAULTS'] = dict() # dict(paddle_attr=default) ) # dict(onnx_attr_from=fluid_attr_to)
DEFAULT_OP_MAPPING_FIELD_VALUES['INPUT_PERM'] = None # sampler: [idx_onnx_arg...] DEFAULT_OP_MAPPING_FIELD_VALUES['DEFAULTS'] = dict() # dict(fluid_attr=default)
DEFAULT_OP_MAPPING_FIELD_VALUES['OUTPUT_PERM'] = None # sampler: [idx_onnx_arg...] DEFAULT_OP_MAPPING_FIELD_VALUES[
'INPUT_PERM'] = None # sampler: [idx_onnx_arg...]
DEFAULT_OP_MAPPING_FIELD_VALUES[
'OUTPUT_PERM'] = None # sampler: [idx_onnx_arg...]
DEFAULT_OP_MAPPING_FIELD_VALUES['FILL_NAME_FIELD'] = True DEFAULT_OP_MAPPING_FIELD_VALUES['FILL_NAME_FIELD'] = True
DEFAULT_OP_MAPPING = { DEFAULT_OP_MAPPING = {
...@@ -60,7 +63,7 @@ DEFAULT_OP_MAPPING = { ...@@ -60,7 +63,7 @@ DEFAULT_OP_MAPPING = {
'Reciprocal': ['reciprocal', ['X'], ['Out']], 'Reciprocal': ['reciprocal', ['X'], ['Out']],
'Relu': ['relu', ['X'], ['Out']], 'Relu': ['relu', ['X'], ['Out']],
'Selu': ['selu', ['X'], ['Out'], dict(gamma='scale')], 'Selu': ['selu', ['X'], ['Out'], dict(gamma='scale')],
'Shape': ['shape', ['X'], ['Out']], # FIXME: out is int64 - int32 'Shape': ['shape', ['X'], ['Out']], # FIXME: out is int64 vs int32
        'Shrink': ['softshrink', ['X'], ['Out'], dict(bias='', lambd='')],         'Shrink': ['softshrink', ['X'], ['Out'], dict(bias='', lambd='')],
'Sigmoid': ['sigmoid', ['X'], ['Out']], 'Sigmoid': ['sigmoid', ['X'], ['Out']],
'Sin': ['sin', ['X'], ['Out']], 'Sin': ['sin', ['X'], ['Out']],
...@@ -74,25 +77,24 @@ DEFAULT_OP_MAPPING = { ...@@ -74,25 +77,24 @@ DEFAULT_OP_MAPPING = {
'Transpose': ['transpose', ['X'], ['Out']], # FIXME: emit transpose2 'Transpose': ['transpose', ['X'], ['Out']], # FIXME: emit transpose2
'Unsqueeze': ['unsqueeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit unsqueeze2 'Unsqueeze': ['unsqueeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit unsqueeze2
## binary ops ## ## binary ops ##
# FIXME: axis=-1 in Paddle is broken, refer it in specialization 'Add': ['elementwise_add', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Add': ['elementwise_add', ['X', 'Y'], ['Out'], dict(), dict(axis=0)], # 'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')],
# 'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')],
'And': ['logical_and', ['X', 'Y'], ['Out']], 'And': ['logical_and', ['X', 'Y'], ['Out']],
'Div': ['elementwise_div', ['X', 'Y'], ['Out'], dict(), dict(axis=0)], 'Div': ['elementwise_div', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Equal': ['equal', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False], 'Equal': ['equal', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False],
'Greater': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False], 'Greater': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False],
'Less': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False], 'Less': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False],
'MatMul': ['matmul', ['X', 'Y'], ['Out']], # defaults excluded for transpose_x - transpose_X 'MatMul': ['matmul', ['X', 'Y'], ['Out']], # defaults excluded for transpose_x vs transpose_X
'Max': ['elementwise_max', ['X', 'Y'], ['Out'], dict(), dict(axis=0)], 'Max': ['elementwise_max', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Min': ['elementwise_min', ['X', 'Y'], ['Out'], dict(), dict(axis=0)], 'Min': ['elementwise_min', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], dict(), dict(axis=0)], 'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Not': ['logical_not', ['X', 'Y'], ['Out']], 'Not': ['logical_not', ['X', 'Y'], ['Out']],
'OneHot': # assuming values=[0, 1], axis=-1 and drop them 'OneHot': # assuming values=[0, 1], axis=-1 and drop them
['one_hot', ['Input', 'Depth'], ['Out'], dict(axis=''), dict(), ['one_hot', ['Input', 'Depth'], ['Out'], dict(axis=''), dict(),
[0, 1], None, False], [0, 1], None, False],
'Or': ['logical_or', ['X', 'Y'], ['Out']], 'Or': ['logical_or', ['X', 'Y'], ['Out']],
'Pow': ['elementwise_pow', ['X', 'Y'], ['Out'], dict(), dict(axis=0)], # TODO: pow for scalar exponent 'Pow': ['elementwise_pow', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], # TODO: pow for scalar exponent
'Sub': ['elementwise_sub', ['X', 'Y'], ['Out'], dict(), dict(axis=0)], 'Sub': ['elementwise_sub', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Xor': ['logical_xor', ['X', 'Y'], ['Out']], 'Xor': ['logical_xor', ['X', 'Y'], ['Out']],
# reduce ops # reduce ops
'ReduceMax': ['reduce_max', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], 'ReduceMax': ['reduce_max', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')],
...@@ -106,29 +108,34 @@ DEFAULT_OP_MAPPING = { ...@@ -106,29 +108,34 @@ DEFAULT_OP_MAPPING = {
} }
DEFAULT_IOA_CONSTRAINT = { DEFAULT_IOA_CONSTRAINT = {
'ArgMax': 'ArgMax': [
[(lambda i, o, a: a.get('keepdims', 1) == 1, 'only keepdims = 0 is supported'), (lambda i, o, a: a.get('keepdims', 1) == 1,
'only keepdims = 0 is supported'),
], ],
'ArgMin': 'ArgMin': [
[(lambda i, o, a: a.get('keepdims', 1) == 1, 'only keepdims = 0 is supported'), (lambda i, o, a: a.get('keepdims', 1) == 1,
'only keepdims = 0 is supported'),
], ],
'Gather': 'Gather': [
[(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported'), (lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported'),
], ],
'Shrink': 'Shrink': [
[(lambda i, o, a: a.get('bias', 0) == a.get('lambd', 0.5), 'only SoftShrink with bias = lambd is supported'), (lambda i, o, a: a.get('bias', 0) == a.get('lambd', 0.5),
'only SoftShrink with bias = lambd is supported'),
], ],
# 'Softmax': # 'Softmax':
# [(lambda i, o, a: a.get('axis', 1) == -2, 'Paddle Softmax works on dim -2 only'), # [(lambda i, o, a: a.get('axis', 1) == -2, 'Paddle fluid Softmax works on dim -2 only'),
# ], # ],
'OneHot': 'OneHot': [
[(lambda i, o, a: a.get('axis', -1) == -1, 'only axis = -1 is supported'), (lambda i, o, a: a.get('axis', -1) == -1,
'only axis = -1 is supported'),
], ],
'Scatter': 'Scatter': [
[(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported'), (lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported'),
], ],
'TopK': 'TopK': [
[(lambda i, o, a: a.get('axis', -1) == -1, 'only axis = -1 is supported'), (lambda i, o, a: a.get('axis', -1) == -1,
'only axis = -1 is supported'),
], ],
} }
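Editor's sketch (not part of the diff) of how these tables are consumed by the _default handler defined below, using one entry:

# Fields follow DEFAULT_OP_MAPPING_FIELD_VALUES order; shorter entries are
# padded with the defaults above by _default via info.extend(...).
fluid_op, input_args, output_args, attr_mapping, default_attrs = \
    DEFAULT_OP_MAPPING['Mul']  # ['elementwise_mul', ['X', 'Y'], ['Out'], {}, {'axis': -1}]
# an ONNX node  Mul(a, b) -> c  therefore emits code roughly like
#   c = layers.elementwise_mul(a, b, axis=-1, name=...)
# plus a matching OpDesc with inputs X=[a], Y=[b] and output Out=[c];
# DEFAULT_IOA_CONSTRAINT predicates are asserted first (e.g. Gather only
# supports axis == 0).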
...@@ -142,7 +149,7 @@ def _make_var_name(name): ...@@ -142,7 +149,7 @@ def _make_var_name(name):
return '_' return '_'
if name[0].isdigit(): if name[0].isdigit():
return 'var_' + name return 'var_' + name
for s in ' *?\/-:': for s in ' *?\\/-:':
name = name.replace(s, '_') name = name.replace(s, '_')
if name.startswith('_'): if name.startswith('_'):
name = 'var' + name name = 'var' + name
...@@ -188,82 +195,91 @@ def _shape_or_none(value_infos, val_name): ...@@ -188,82 +195,91 @@ def _shape_or_none(value_infos, val_name):
# return value_info.get('const_value', var_name) # return value_info.get('const_value', var_name)
def _default(prog, op_type, inputs, outputs, attrs, def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
*args,
name='',
**kwargs):
info = DEFAULT_OP_MAPPING[op_type] info = DEFAULT_OP_MAPPING[op_type]
info.extend(list(DEFAULT_OP_MAPPING_FIELD_VALUES.values())[len(info):]) info.extend(list(DEFAULT_OP_MAPPING_FIELD_VALUES.values())[len(info):])
(paddle_op, (
paddle_input_args, paddle_output_args, fluid_op,
attr_mapping, default_attrs, fluid_input_args,
input_perm, output_perm, fluid_output_args,
attr_mapping,
default_attrs,
input_perm,
output_perm,
fill_name_field, fill_name_field,
) = info ) = info
if paddle_op in DEFAULT_IOA_CONSTRAINT: if fluid_op in DEFAULT_IOA_CONSTRAINT:
for predicate, message in DEFAULT_IOA_CONSTRAINT[paddle_op]: for predicate, message in DEFAULT_IOA_CONSTRAINT[fluid_op]:
assert predicate(inputs, outputs, attrs), message assert predicate(inputs, outputs, attrs), message
# bypass if key absent, drop if mapped key is '' or '_' # bypass if key absent, drop if mapped key is '' or '_'
mapped_attrs = {attr_mapping.get(key, key): value for key, value in attrs.items()} mapped_attrs = {
attr_mapping.get(key, key): value
for key, value in attrs.items()
}
if '' in mapped_attrs: if '' in mapped_attrs:
mapped_attrs.pop('') mapped_attrs.pop('')
if '_' in mapped_attrs: if '_' in mapped_attrs:
mapped_attrs.pop('_') mapped_attrs.pop('_')
paddle_attrs = default_attrs.copy() fluid_attrs = default_attrs.copy()
paddle_attrs.update(mapped_attrs) # as new attrs fluid_attrs.update(mapped_attrs) # as new attrs
val_inps = inputs if input_perm is None else map(lambda i: inputs[i], input_perm) val_inps = inputs if input_perm is None else map(lambda i: inputs[i],
val_outs = outputs if output_perm is None else map(lambda i: outputs[i], output_perm) input_perm)
val_outs = outputs if output_perm is None else map(lambda i: outputs[i],
output_perm)
var_inps = [_make_var_name(val) for val in val_inps] var_inps = [_make_var_name(val) for val in val_inps]
var_outs = [_make_var_name(val) for val in val_outs] var_outs = [_make_var_name(val) for val in val_outs]
arg_name = ', name={}'.format(repr(name)) if fill_name_field and name else '' arg_name = ', name={}'.format(
arg_attrs = [', {}={}'.format(key, value) for key, value in paddle_attrs.items()] repr(name)) if fill_name_field and name else ''
arg_attrs = [
prog.Code('{} = layers.{}({}{}{})' ', {}={}'.format(key, value) for key, value in fluid_attrs.items()
.format(', '.join(var_outs), ]
paddle_op,
prog.Code('{} = layers.{}({}{}{})'.format(
', '.join(var_outs),
fluid_op,
', '.join(var_inps), ', '.join(var_inps),
''.join(arg_attrs), ''.join(arg_attrs),
arg_name, arg_name,
)) ))
for val_out, var_out in zip(val_outs, var_outs): for var_out in var_outs:
prog.VarDesc(var_out) prog.VarDesc(var_out)
# dummy var_out # dummy var_out
num_vars = len(var_outs) num_vars = len(var_outs)
num_args = len(paddle_output_args) num_args = len(fluid_output_args)
if num_vars < num_args: if num_vars < num_args:
        assert fill_name_field, 'name required for naming dummy output variable'         assert fill_name_field, 'name required for naming dummy output variable'
for idx_out in range(num_vars, num_args): for idx_out in range(num_vars, num_args):
var_out = _make_var_name(name + '.' + paddle_output_args[idx_out].lower()) var_out = _make_var_name(name + '.' +
fluid_output_args[idx_out].lower())
var_outs.append(var_out) var_outs.append(var_out)
prog.VarDesc(var_out) prog.VarDesc(var_out)
prog.OpDesc(paddle_op, prog.OpDesc(fluid_op, (var_inps, *fluid_input_args),
(var_inps, *paddle_input_args), (var_outs, *fluid_output_args), fluid_attrs)
(var_outs, *paddle_output_args),
paddle_attrs)
def _assign(prog, attrs): def _assign(prog, attrs):
mapping = attrs['mapping'] # additional mapping = attrs['mapping'] # additional
paddle_op = 'assign' fluid_op = 'assign'
for val_dst, val_src in mapping.items(): for val_dst, val_src in mapping.items():
var_dst = _make_var_name(val_dst) var_dst = _make_var_name(val_dst)
var_src = _make_var_name(val_src) var_src = _make_var_name(val_src)
prog.Code('{} = {}'.format(var_dst, var_src)) prog.Code('{} = {}'.format(var_dst, var_src))
# prog.Code('{} = layers.{}({})' # prog.Code('{} = layers.{}({})'
# .format(var_dst, # .format(var_dst,
# paddle_op, # fluid_op,
# var_src, # var_src,
# )) # ))
prog.VarDesc(var_dst) prog.VarDesc(var_dst)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_src], 'X'), ([var_src], 'X'),
([var_dst], 'Out'), ([var_dst], 'Out'),
dict(), dict(),
...@@ -283,10 +299,13 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE ...@@ -283,10 +299,13 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
return pads[:ndims], None return pads[:ndims], None
val_padded = val_name + '_padded' val_padded = val_name + '_padded'
prog.Op('', 'Pad', prog.Op(
'',
'Pad',
[val_name], [val_name],
[val_padded], # val [val_padded], # val
dict(mode='constant', dict(
mode='constant',
value=0., value=0.,
pads=pads, pads=pads,
), ),
...@@ -295,7 +314,13 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE ...@@ -295,7 +314,13 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
) )
return [0] * ndims, val_padded return [0] * ndims, val_padded
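A worked example of the SSEE handling above (editor's note):

# ONNX pads are laid out starts-then-ends (SSEE). For a 2-D pool,
# pads = [1, 2, 1, 2] has begin == end per axis, so [1, 2] is returned
# and no extra op is emitted; pads = [0, 0, 1, 1] is asymmetric, so a
# 'Pad' node writing to val_name + '_padded' is emitted and the caller
# then pools over that value with zero padding ([0, 0]).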
def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
def _adaptive_pool(prog,
pool_type,
inputs,
outputs,
attrs,
value_infos,
name=''): name=''):
# I/O # I/O
val_x, = inputs val_x, = inputs
...@@ -312,11 +337,12 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos, ...@@ -312,11 +337,12 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
pool_size = attrs['output_size'] # required pool_size = attrs['output_size'] # required
output_shape = _shape_or_none(value_infos, val_y) output_shape = _shape_or_none(value_infos, val_y)
if output_shape is not None: if output_shape is not None:
        assert pool_size == output_shape[2:], 'pool_size does not match shape of Y' # NC...         assert pool_size == output_shape[
            2:], 'pool_size does not match shape of Y'  # NC...
poolnd = len(pool_size) poolnd = len(pool_size)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported' assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported'
paddle_op = 'adaptive_pool{}d'.format(poolnd) fluid_op = 'adaptive_pool{}d'.format(poolnd)
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
# generation # generation
...@@ -324,9 +350,10 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos, ...@@ -324,9 +350,10 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
', require_index={}' ', require_index={}'
', pool_size={}' ', pool_size={}'
', pool_type={}' ', pool_type={}'
'{})' '{})'.format(
.format(var_y, ', {}'.format(var_indices) if has_indices else '', var_y,
paddle_op, ', {}'.format(var_indices) if has_indices else '',
fluid_op,
var_x, var_x,
# attrs # attrs
has_indices, has_indices,
...@@ -334,14 +361,16 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos, ...@@ -334,14 +361,16 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
repr(pool_type), repr(pool_type),
name_attr, name_attr,
)) ))
paddle_op = 'pool{}d'.format(poolnd) fluid_op = 'pool{}d'.format(poolnd)
prog.VarDesc(var_y) prog.VarDesc(var_y)
if has_indices: if has_indices:
prog.VarDesc(var_indices) prog.VarDesc(var_indices)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_x], 'X'), ([var_x], 'X'),
([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'), ([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'),
dict(global_pooling=False, dict(
global_pooling=False,
adaptive=True, adaptive=True,
exclusive=True, exclusive=True,
require_index=has_indices, require_index=has_indices,
...@@ -351,8 +380,7 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos, ...@@ -351,8 +380,7 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
) )
def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
name=''):
# I/O # I/O
val_x, = inputs val_x, = inputs
val_y, = outputs val_y, = outputs
...@@ -369,33 +397,34 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, ...@@ -369,33 +397,34 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
poolnd = len(output_shape) - 2 # NC... poolnd = len(output_shape) - 2 # NC...
assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported' assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported'
paddle_op = 'pool{}d'.format(poolnd) fluid_op = 'pool{}d'.format(poolnd)
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
# generation # generation
prog.Code('{} = layers.{}({}, global_pooling=True' prog.Code('{} = layers.{}({}, global_pooling=True'
', pool_type={}' ', pool_type={}'
'{})' '{})'.format(
.format(var_y, var_y,
paddle_op, fluid_op,
var_x, var_x,
# attrs # attrs
repr(pool_type), repr(pool_type),
name_attr, name_attr,
)) ))
prog.VarDesc(var_y) prog.VarDesc(var_y)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_x], 'X'), ([var_x], 'X'),
([var_y], 'Out'), ([var_y], 'Out'),
dict(global_pooling=True, dict(
global_pooling=True,
adaptive=False, adaptive=False,
pooling_type=pool_type, pooling_type=pool_type,
), ),
) )
def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
name=''):
# I/O # I/O
val_x, = inputs val_x, = inputs
val_y, = outputs[:1] val_y, = outputs[:1]
...@@ -407,12 +436,14 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, ...@@ -407,12 +436,14 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
var_indices = _make_var_name(val_indices) var_indices = _make_var_name(val_indices)
# interpretation # interpretation
assert attrs.get('auto_pad', 'NOTSET') == 'NOTSET', 'only auto_pad = NOTSET supported' # optional assert attrs.get(
'auto_pad',
'NOTSET') == 'NOTSET', 'only auto_pad = NOTSET supported' # optional
pool_size = attrs['kernel_shape'] # required pool_size = attrs['kernel_shape'] # required
poolnd = len(pool_size) poolnd = len(pool_size)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported' assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported'
paddle_op = 'pool{}d'.format(poolnd) fluid_op = 'pool{}d'.format(poolnd)
strides = attrs.get('strides', [1] * poolnd) # optional strides = attrs.get('strides', [1] * poolnd) # optional
pads = attrs.get('pads', [0] * len(pool_size * 2)) # optional pads = attrs.get('pads', [0] * len(pool_size * 2)) # optional
paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos) paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
...@@ -429,9 +460,10 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, ...@@ -429,9 +460,10 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
', pool_stride={}' ', pool_stride={}'
', pool_padding={}' ', pool_padding={}'
', ceil_mode={}' ', ceil_mode={}'
'{})' '{})'.format(
.format(var_y, ', {}'.format(var_indices) if has_indices else '', var_y,
paddle_op, ', {}'.format(var_indices) if has_indices else '',
fluid_op,
var_x, var_x,
# attrs # attrs
pool_size, pool_size,
...@@ -444,23 +476,25 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, ...@@ -444,23 +476,25 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
prog.VarDesc(var_y) prog.VarDesc(var_y)
if has_indices: if has_indices:
prog.VarDesc(var_indices) prog.VarDesc(var_indices)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_x], 'X'), ([var_x], 'X'),
([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'), ([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'),
dict(global_pooling=False, dict(
global_pooling=False,
adaptive=False, adaptive=False,
exclusive=True, exclusive=True,
require_index=has_indices, require_index=has_indices,
pooling_type=pool_type, pooling_type=pool_type,
ksize=pool_size, ksize=pool_size,
strides=strides, strides=strides,
pool_padding=paddings, paddings=paddings,
ceil_mode=ceil_mode, ceil_mode=ceil_mode,
), ),
) )
def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name): def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
# I/O # I/O
val_x, val_rois = inputs val_x, val_rois = inputs
val_y, = outputs val_y, = outputs
...@@ -469,7 +503,7 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name): ...@@ -469,7 +503,7 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
var_y = _make_var_name(val_y) var_y = _make_var_name(val_y)
# interpretation # interpretation
spatial_scale=attrs['spatial_scale'] # required spatial_scale = attrs['spatial_scale'] # required
pooled_height, pooled_width = attrs['pooled_shape'] # required pooled_height, pooled_width = attrs['pooled_shape'] # required
od_attrs = dict( od_attrs = dict(
spatial_scale=spatial_scale, spatial_scale=spatial_scale,
...@@ -477,7 +511,7 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name): ...@@ -477,7 +511,7 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
pooled_width=pooled_width, pooled_width=pooled_width,
) )
feature_attr = '' feature_attr = ''
is_max_pool = paddle_op == 'roi_pool' is_max_pool = fluid_op == 'roi_pool'
if 'sampling_ratio' in attrs: if 'sampling_ratio' in attrs:
sampling_ratio = attrs['sampling_ratio'] sampling_ratio = attrs['sampling_ratio']
od_attrs['sampling_ratio'] = sampling_ratio od_attrs['sampling_ratio'] = sampling_ratio
...@@ -492,10 +526,11 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name): ...@@ -492,10 +526,11 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
', spatial_scale={}' ', spatial_scale={}'
', pooled_height={}' ', pooled_height={}'
', pooled_width={}' ', pooled_width={}'
'{})' '{})'.format(
.format(var_y, var_y,
paddle_op, fluid_op,
val_x, var_rois, val_x,
var_rois,
# attrs # attrs
spatial_scale, spatial_scale,
pooled_height, pooled_height,
...@@ -506,7 +541,8 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name): ...@@ -506,7 +541,8 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
if is_max_pool: if is_max_pool:
var_argmax = _make_var_name(name + '.argmax') # implicit variable var_argmax = _make_var_name(name + '.argmax') # implicit variable
prog.VarDesc(var_argmax) prog.VarDesc(var_argmax)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_x, var_rois], 'X', 'Rois'), ([var_x, var_rois], 'X', 'Rois'),
([var_y] + ([var_argmax] if is_max_pool else []), 'Out', 'Argmax'), ([var_y] + ([var_argmax] if is_max_pool else []), 'Out', 'Argmax'),
od_attrs, od_attrs,
...@@ -514,7 +550,9 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name): ...@@ -514,7 +550,9 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
def _zeros_like(prog, val_ref, val_out, value_infos): def _zeros_like(prog, val_ref, val_out, value_infos):
prog.Op('', 'Sub', prog.Op(
'',
'Sub',
[val_ref, val_ref], [val_ref, val_ref],
[val_out], # val [val_out], # val
dict(axis=0), dict(axis=0),
...@@ -522,47 +560,54 @@ def _zeros_like(prog, val_ref, val_out, value_infos): ...@@ -522,47 +560,54 @@ def _zeros_like(prog, val_ref, val_out, value_infos):
) )
def AdaptiveAveragePool( def AdaptiveAveragePool(prog,
prog, inputs, outputs, attrs, value_infos, inputs,
outputs,
attrs,
value_infos,
name='', name='',
*args, **kwargs): *args,
**kwargs):
""" """
aten::adaptive_avg_poolnd aten::adaptive_avg_poolnd
""" """
return _adaptive_pool(prog, 'avg', inputs, outputs, attrs, value_infos, return _adaptive_pool(
name=name) prog, 'avg', inputs, outputs, attrs, value_infos, name=name)
def AdaptiveMaxPool( def AdaptiveMaxPool(prog,
prog, inputs, outputs, attrs, value_infos, inputs,
outputs,
attrs,
value_infos,
name='', name='',
*args, **kwargs): *args,
**kwargs):
""" """
aten::adaptive_max_poolnd aten::adaptive_max_poolnd
""" """
return _adaptive_pool(prog, 'max', inputs, outputs, attrs, value_infos, return _adaptive_pool(
name=name) prog, 'max', inputs, outputs, attrs, value_infos, name=name)
def AveragePool( def AveragePool(prog,
prog, inputs, outputs, attrs, value_infos, inputs,
outputs,
attrs,
value_infos,
name='', name='',
*args, **kwargs): *args,
**kwargs):
""" """
onnx::AveragePool-10: onnx::AveragePool-10:
""" """
return _pool(prog, 'avg', inputs, outputs, attrs, value_infos, return _pool(prog, 'avg', inputs, outputs, attrs, value_infos, name=name)
name=name)
def AffineGrid( def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs):
prog, inputs, outputs, attrs,
*args,
name='',
**kwargs):
""" """
aten::affine_grid aten::affine_grid
""" """
...@@ -574,33 +619,39 @@ def AffineGrid( ...@@ -574,33 +619,39 @@ def AffineGrid(
var_grid = _make_var_name(val_grid) var_grid = _make_var_name(val_grid)
# interpretation # interpretation
paddle_op = 'affine_grid' fluid_op = 'affine_grid'
size = attrs['size'] # required size = attrs['size'] # required
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
# generation # generation
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', out_shape={}' ', out_shape={}'
'{})' '{})'.format(
.format(var_grid, var_grid,
paddle_op, fluid_op,
var_theta, var_theta,
# attrs # attrs
size, size,
name_attr, name_attr,
)) ))
prog.VarDesc(var_grid) prog.VarDesc(var_grid)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_theta], 'Theta'), ([var_theta], 'Theta'),
([var_grid], 'Output'), ([var_grid], 'Output'),
dict(output_shape=size), # NOTE: attr is output_shape, not out_shape dict(output_shape=size), # NOTE: attr is output_shape, not out_shape
) )
def BatchNormalization( def BatchNormalization(prog,
prog, inputs, outputs, attrs, value_infos, inputs,
name='', embed_params=False, outputs,
*args, **kwargs): attrs,
value_infos,
name='',
embed_params=False,
*args,
**kwargs):
""" """
onnx::BatchNormalization-9: onnx::BatchNormalization-9:
""" """
...@@ -612,7 +663,7 @@ def BatchNormalization( ...@@ -612,7 +663,7 @@ def BatchNormalization(
var_y = _make_var_name(val_y) var_y = _make_var_name(val_y)
# interpretation # interpretation
paddle_op = 'batch_norm' fluid_op = 'batch_norm'
momentum = attrs.get('momentum', .9) # optional momentum = attrs.get('momentum', .9) # optional
epsilon = attrs.get('epsilon', 1e-5) # optional epsilon = attrs.get('epsilon', 1e-5) # optional
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
...@@ -633,8 +684,9 @@ def BatchNormalization( ...@@ -633,8 +684,9 @@ def BatchNormalization(
var_mean = _make_var_name(val_mean) var_mean = _make_var_name(val_mean)
var_var = _make_var_name(val_var) var_var = _make_var_name(val_var)
param_attr = (', param_attr={}, bias_attr={}' param_attr = (', param_attr={}, bias_attr={}'
', moving_mean_name={}, moving_variance_name={}' ', moving_mean_name={}, moving_variance_name={}').format(
).format(repr(var_scale), repr(var_b), repr(var_mean), repr(var_var)) repr(var_scale), repr(var_b), repr(var_mean),
repr(var_var))
var_saved_mean = '{}.saved_mean'.format(name) # dropped var var_saved_mean = '{}.saved_mean'.format(name) # dropped var
var_saved_variance = '{}.saved_variance'.format(name) # dropped var var_saved_variance = '{}.saved_variance'.format(name) # dropped var
...@@ -642,24 +694,27 @@ def BatchNormalization( ...@@ -642,24 +694,27 @@ def BatchNormalization(
prog.Code('{} = layers.{}({}, is_test=True, data_layout="NCHW"' prog.Code('{} = layers.{}({}, is_test=True, data_layout="NCHW"'
', momentum={}' ', momentum={}'
', epsilon={}' ', epsilon={}'
'{}{})' '{}{})'.format(
.format(var_y, var_y,
paddle_op, fluid_op,
var_x, var_x,
# attrs # attrs
momentum, momentum,
epsilon, epsilon,
param_attr, name_attr, param_attr,
name_attr,
)) ))
prog.VarDesc(var_y) prog.VarDesc(var_y)
prog.VarDesc(var_saved_mean) prog.VarDesc(var_saved_mean)
prog.VarDesc(var_saved_variance) prog.VarDesc(var_saved_variance)
prog.OpDesc(paddle_op, prog.OpDesc(
([var_x, var_scale, var_b, var_mean, var_var], fluid_op,
'X', 'Scale', 'Bias', 'Mean', 'Variance'), ([var_x, var_scale, var_b, var_mean, var_var], 'X', 'Scale', 'Bias',
([var_y, var_mean, var_saved_mean, var_saved_variance, var_var], 'Mean', 'Variance'),
'Y', 'MeanOut', 'SavedMean', 'SavedVariance', 'VarianceOut'), ([var_y, var_mean, var_saved_mean, var_saved_variance, var_var], 'Y',
dict(is_test=1, 'MeanOut', 'SavedMean', 'SavedVariance', 'VarianceOut'),
dict(
is_test=1,
data_layout='NCHW', data_layout='NCHW',
use_global_stats=False, use_global_stats=False,
momentum=momentum, momentum=momentum,
...@@ -667,9 +722,7 @@ def BatchNormalization( ...@@ -667,9 +722,7 @@ def BatchNormalization(
) )
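For reference, a hedged NumPy sketch of what the converted op computes in is_test mode (momentum only matters during training, so it drops out here; all tensor values are hypothetical):

import numpy as np

def batch_norm_infer(x, scale, b, mean, var, epsilon=1e-5):
    # normalize per channel (axis 1 of NCHW), then apply the affine pair
    c = (1, -1, 1, 1)
    return (x - mean.reshape(c)) / np.sqrt(var.reshape(c) + epsilon) \
           * scale.reshape(c) + b.reshape(c)

x = np.random.rand(2, 3, 4, 4).astype(np.float32)
scale, b, mean, var = (np.random.rand(3).astype(np.float32) for _ in range(4))
assert batch_norm_infer(x, scale, b, mean, var).shape == x.shape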
def Cast( def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
prog, inputs, outputs, attrs, value_infos,
*args, **kwargs):
""" """
onnx::Cast-9: onnx::Cast-9:
""" """
...@@ -688,33 +741,31 @@ def Cast( ...@@ -688,33 +741,31 @@ def Cast(
if output_dtype: if output_dtype:
assert dtype == output_dtype, 'dtype of to does not match output' assert dtype == output_dtype, 'dtype of to does not match output'
paddle_op = 'cast' fluid_op = 'cast'
# generation # generation
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', dtype={}' ', dtype={}'
')' ')'.format(
.format(var_output, var_output,
paddle_op, fluid_op,
var_input, var_input,
# attrs # attrs
repr(dtype.name), repr(dtype.name),
)) ))
prog.VarDesc(var_output) prog.VarDesc(var_output)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_input], 'X'), ([var_input], 'X'),
([var_output], 'Out'), ([var_output], 'Out'),
dict(in_dtype=prog.Dtype(_dtype(value_infos, val_input)), # required dict(
in_dtype=prog.Dtype(_dtype(value_infos,
val_input)), # required
out_dtype=prog.Dtype(dtype), out_dtype=prog.Dtype(dtype),
) ))
)
def Concat( def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
prog, inputs, outputs, attrs,
*args,
name='',
**kwargs):
""" """
onnx::Concat-4: onnx::Concat-4:
""" """
...@@ -725,32 +776,31 @@ def Concat( ...@@ -725,32 +776,31 @@ def Concat(
var_concat_result = _make_var_name(val_concat_result) var_concat_result = _make_var_name(val_concat_result)
# interpretation # interpretation
paddle_op = 'concat' fluid_op = 'concat'
axis = attrs['axis'] # required axis = attrs['axis'] # required
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
# generation # generation
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', axis={}' ', axis={}'
'{})' '{})'.format(
.format(var_concat_result, var_concat_result,
paddle_op, fluid_op,
'[' + ', '.join(var_inps) + ']', '[' + ', '.join(var_inps) + ']',
# attrs # attrs
axis, axis,
name_attr, name_attr,
)) ))
prog.VarDesc(var_concat_result) prog.VarDesc(var_concat_result)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
(var_inps, *(['X'] * len(var_inps))), (var_inps, *(['X'] * len(var_inps))),
([var_concat_result], 'Out'), ([var_concat_result], 'Out'),
dict(axis=axis), dict(axis=axis),
) )
def Constant( def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
prog, inputs, outputs, attrs, value_infos,
*args, **kwargs):
""" """
onnx::Constant-9: onnx::Constant-9:
""" """
...@@ -766,36 +816,40 @@ def Constant( ...@@ -766,36 +816,40 @@ def Constant(
output_dtype = _dtype_or_none(value_infos, val_output) output_dtype = _dtype_or_none(value_infos, val_output)
if output_dtype: if output_dtype:
assert dtype == output_dtype, 'tensor dtype does not match storage dtype' assert dtype == output_dtype, 'tensor dtype does not match storage dtype'
# dtype = np.dtype('float32') # force to float32 # dtype = np.dtype('float32') # force to float32
shape = attrs.get('shape', None) # additional, maybe var_name shape = attrs.get('shape', None) # additional, maybe var_name
if shape is None: if shape is None:
shape = _shape_or_none(value_infos, val_output) shape = _shape_or_none(value_infos, val_output)
if shape is None: if shape is None:
shape = list(value.shape) shape = list(value.shape)
_logger.warning('shape of %s not inferred, using value as 1-D tensor may lead to failures', val_output) _logger.warning(
'shape of %s not inferred, using value as 1-D tensor may lead to failures',
val_output)
# generation # generation
if value.size == 1: # scalar if value.size == 1: # scalar
paddle_op = 'fill_constant' fluid_op = 'fill_constant'
prog.Code('{} = layers.{}(shape={}, dtype={}, value={})' prog.Code('{} = layers.{}(shape={}, dtype={}, value={})'.format(
.format(var_output, var_output,
paddle_op, fluid_op,
# attrs # attrs
shape, repr(dtype.name), value[0], # shape can be list or var_name shape,
repr(dtype.name),
value[0], # shape can be list or var_name
)) ))
value_infos[val_output]['const_value'] = value[0] value_infos[val_output]['const_value'] = value[0]
prog.VarDesc(var_output) prog.VarDesc(var_output)
else: # list parameter -> const_value else: # list parameter -> const_value
prog.Code('{} = {}' prog.Code('{} = {}'.format(
.format(var_output, var_output,
value.tolist(), value.tolist(),
)) ))
value_infos[val_output]['const_value'] = value.tolist() value_infos[val_output]['const_value'] = value.tolist()
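To make the branching above concrete, a small sketch with a hypothetical tensor: a one-element constant would be emitted as a fill_constant call, while a multi-element constant is inlined as a Python literal and kept in const_value for shape-consuming ops such as Reshape and Tile:

import numpy as np

value = np.array([2, 3, 4], dtype=np.int64)
if value.size == 1:
    # emitted code would read like: var = layers.fill_constant(shape=[1], dtype='int64', value=2)
    emitted = 'layers.fill_constant(shape=[1], dtype={}, value={})'.format(
        repr(value.dtype.name), value[0])
else:
    # emitted code is just a literal; const_value keeps it for later ops
    emitted = '{}'.format(value.tolist())
print(emitted)  # prints: [2, 3, 4]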
def ConstantOfShape( def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
prog, inputs, outputs, attrs, value_infos,
*args, **kwargs):
""" """
onnx::ConstantOfShape-9: onnx::ConstantOfShape-9:
""" """
...@@ -815,10 +869,15 @@ def ConstantOfShape( ...@@ -815,10 +869,15 @@ def ConstantOfShape(
Constant(prog, [], outputs, attrs, value_infos) Constant(prog, [], outputs, attrs, value_infos)
def Conv( def Conv(prog,
prog, inputs, outputs, attrs, value_infos, inputs,
name='', embed_params=False, outputs,
*args, **kwargs): attrs,
value_infos,
name='',
embed_params=False,
*args,
**kwargs):
""" """
onnx::Conv-1: onnx::Conv-1:
""" """
...@@ -833,14 +892,17 @@ def Conv( ...@@ -833,14 +892,17 @@ def Conv(
val_b, = inputs[2:] val_b, = inputs[2:]
# interpretation # interpretation
assert attrs.get('auto_pad', 'NOTSET') == 'NOTSET', 'only auto_pad == NOTSET supported' # optional assert attrs.get(
'auto_pad',
'NOTSET') == 'NOTSET', 'only auto_pad == NOTSET supported' # optional
kernel_shape = _shape(value_infos, val_w)[2:] # OI... kernel_shape = _shape(value_infos, val_w)[2:] # OI...
assert kernel_shape == attrs['kernel_shape'], 'kernel_shape in attr does not match value_info' # HW assert kernel_shape == attrs[
'kernel_shape'], 'kernel_shape in attr does not match value_info' # HW
convnd = len(kernel_shape) convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d and conv3d supported' assert 2 <= convnd <= 3, 'only conv2d and conv3d supported'
num_out_channels = _shape(value_infos, val_w)[0] # OI... num_out_channels = _shape(value_infos, val_w)[0] # OI...
paddle_op = 'conv{}d'.format(convnd) fluid_op = 'conv{}d'.format(convnd)
strides = attrs.get('strides', [1] * convnd) # optional strides = attrs.get('strides', [1] * convnd) # optional
pads = attrs.get('pads', [0] * convnd * 2) # optional pads = attrs.get('pads', [0] * convnd * 2) # optional
paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos) paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
...@@ -864,7 +926,8 @@ def Conv( ...@@ -864,7 +926,8 @@ def Conv(
var_w = _make_var_name(val_w) var_w = _make_var_name(val_w)
var_b = _make_var_name(val_b) if has_bias else False var_b = _make_var_name(val_b) if has_bias else False
param_attr = ', param_attr={}, bias_attr={}'.format( param_attr = ', param_attr={}, bias_attr={}'.format(
repr(var_w), repr(var_b) if var_b else False) repr(var_w),
repr(var_b) if var_b else False)
# generation # generation
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
...@@ -874,9 +937,9 @@ def Conv( ...@@ -874,9 +937,9 @@ def Conv(
', padding={}' ', padding={}'
', dilation={}' ', dilation={}'
', groups={}' ', groups={}'
'{}{})' '{}{})'.format(
.format(var_y, var_y,
paddle_op, fluid_op,
var_x, var_x,
# attrs # attrs
num_out_channels, num_out_channels,
...@@ -885,13 +948,16 @@ def Conv( ...@@ -885,13 +948,16 @@ def Conv(
paddings, paddings,
dilations, dilations,
num_groups, num_groups,
param_attr, name_attr, param_attr,
name_attr,
)) ))
var_conv = _make_var_name(name + '.conv') # hidden variable var_conv = _make_var_name(name + '.conv') # hidden variable
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_x, var_w], 'Input', 'Filter'), # , 'Bias', 'ResidualData' ([var_x, var_w], 'Input', 'Filter'), # , 'Bias', 'ResidualData'
([var_conv if has_bias else var_y], 'Output'), ([var_conv if has_bias else var_y], 'Output'),
dict(strides=strides, dict(
strides=strides,
paddings=paddings, paddings=paddings,
dilations=dilations, dilations=dilations,
groups=num_groups, groups=num_groups,
...@@ -899,7 +965,8 @@ def Conv( ...@@ -899,7 +965,8 @@ def Conv(
if has_bias: if has_bias:
prog.VarDesc(var_conv) prog.VarDesc(var_conv)
prog.IntermediateOp( prog.IntermediateOp(
'', 'Add', '',
'Add',
[var_conv, var_b], [var_conv, var_b],
[var_y], # var [var_y], # var
dict(axis=1), dict(axis=1),
...@@ -910,10 +977,15 @@ def Conv( ...@@ -910,10 +977,15 @@ def Conv(
prog.VarDesc(var_y) prog.VarDesc(var_y)
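One detail worth spelling out: ONNX stores pads as all begin values followed by all end values, so _pad_if_asymmetric (defined elsewhere in this module) only needs to fall back to an explicit Pad op when the two halves differ. A hedged sketch of that check:

import numpy as np

def is_symmetric(pads):
    # ONNX layout: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
    begins, ends = np.split(np.asarray(pads), 2)
    return bool((begins == ends).all())

assert is_symmetric([1, 1, 1, 1])      # the conv2d test case below
assert not is_symmetric([0, 1, 2, 3])  # asymmetric: needs a separate Pad op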
def ConvTranspose( def ConvTranspose(prog,
prog, inputs, outputs, attrs, value_infos, inputs,
name='', embed_params=False, outputs,
*args, **kwargs): attrs,
value_infos,
name='',
embed_params=False,
*args,
**kwargs):
""" """
onnx::ConvTranspose-1: onnx::ConvTranspose-1:
""" """
...@@ -928,15 +1000,20 @@ def ConvTranspose( ...@@ -928,15 +1000,20 @@ def ConvTranspose(
val_b, = inputs[2:] val_b, = inputs[2:]
# interpretation # interpretation
assert attrs.get('auto_pad', 'NOTSET') == 'NOTSET', 'only auto_pad == NOTSET supported' # optional assert attrs.get(
assert sum(attrs.get('output_padding', [])) == 0, 'only zero output_padding supported' # optional ? 'auto_pad',
'NOTSET') == 'NOTSET', 'only auto_pad == NOTSET supported' # optional
assert sum(
attrs.get('output_padding',
[])) == 0, 'only zero output_padding supported' # optional ?
kernel_shape = _shape(value_infos, val_w)[2:] # IO... kernel_shape = _shape(value_infos, val_w)[2:] # IO...
assert kernel_shape == attrs['kernel_shape'], 'kernel_shape in attr does not match value_info' # HW assert kernel_shape == attrs[
'kernel_shape'], 'kernel_shape in attr does not match value_info' # HW
convnd = len(kernel_shape) convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported' assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
num_out_channels = _shape(value_infos, val_w)[1] # IO... num_out_channels = _shape(value_infos, val_w)[1] # IO...
paddle_op = 'conv{}d_transpose'.format(convnd) fluid_op = 'conv{}d_transpose'.format(convnd)
strides = attrs.get('strides', [1] * convnd) # optional strides = attrs.get('strides', [1] * convnd) # optional
pads = attrs.get('pads', [0] * convnd * 2) # optional pads = attrs.get('pads', [0] * convnd * 2) # optional
paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos) paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
...@@ -960,20 +1037,21 @@ def ConvTranspose( ...@@ -960,20 +1037,21 @@ def ConvTranspose(
var_w = _make_var_name(val_w) var_w = _make_var_name(val_w)
var_b = _make_var_name(val_b) if has_bias else False var_b = _make_var_name(val_b) if has_bias else False
param_attr = ', param_attr={}, bias_attr={}'.format( param_attr = ', param_attr={}, bias_attr={}'.format(
repr(var_w), repr(var_b) if var_b else False) repr(var_w),
repr(var_b) if var_b else False)
# generation # generation
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', num_filters={}' ', num_filters={}'
# ', output_size={}' # ', output_size={}'
', filter_size={}' ', filter_size={}'
', padding={}' ', padding={}'
', stride={}' ', stride={}'
', dilation={}' ', dilation={}'
', groups={}' ', groups={}'
'{}{})' '{}{})'.format(
.format(var_y, var_y,
paddle_op, fluid_op,
var_x, var_x,
# attrs # attrs
num_out_channels, num_out_channels,
...@@ -982,13 +1060,16 @@ def ConvTranspose( ...@@ -982,13 +1060,16 @@ def ConvTranspose(
strides, strides,
dilations, dilations,
num_groups, num_groups,
param_attr, name_attr, param_attr,
name_attr,
)) ))
var_conv = _make_var_name(name + '.conv') # hidden variable var_conv = _make_var_name(name + '.conv') # hidden variable
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_x, var_w], 'Input', 'Filter'), # , 'Bias', 'ResidualData' ([var_x, var_w], 'Input', 'Filter'), # , 'Bias', 'ResidualData'
([var_conv if has_bias else var_y], 'Output'), ([var_conv if has_bias else var_y], 'Output'),
dict(strides=strides, dict(
strides=strides,
paddings=paddings, paddings=paddings,
dilations=dilations, dilations=dilations,
# output_size=output_size, # output_size=output_size,
...@@ -997,7 +1078,8 @@ def ConvTranspose( ...@@ -997,7 +1078,8 @@ def ConvTranspose(
if has_bias: if has_bias:
prog.VarDesc(var_conv) prog.VarDesc(var_conv)
prog.IntermediateOp( prog.IntermediateOp(
'', 'Add', '',
'Add',
[var_conv, var_b], [var_conv, var_b],
[var_y], # var [var_y], # var
dict(axis=1), dict(axis=1),
...@@ -1025,14 +1107,12 @@ def ConvTranspose( ...@@ -1025,14 +1107,12 @@ def ConvTranspose(
# ) # )
def Gemm( def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
prog, inputs, outputs, attrs, value_infos, name,
*args, **kwargs):
""" """
onnx::Gemm-9: onnx::Gemm-9:
""" """
# since paddle fc does not support transposed weights, we use matmul + ew_add # since fluid fc does not support transposed weights, we use matmul + ew_add
val_a, val_b, val_c = inputs val_a, val_b, val_c = inputs
val_y, = outputs val_y, = outputs
...@@ -1042,23 +1122,29 @@ def Gemm( ...@@ -1042,23 +1122,29 @@ def Gemm(
trans_b = bool(attrs.get('transB', 0)) # optional trans_b = bool(attrs.get('transB', 0)) # optional
val_mm = name + '_mm' # explicit variable val_mm = name + '_mm' # explicit variable
prog.Op('', 'MatMul', prog.Op(
'',
'MatMul',
[val_a, val_b], [val_a, val_b],
[val_mm], # val [val_mm], # val
dict(transpose_x=trans_a, dict(
transpose_x=trans_a,
transpose_y=trans_b, transpose_y=trans_b,
alpha=alpha, alpha=alpha,
), ),
value_infos=value_infos, value_infos=value_infos,
name=val_mm, name=val_mm,
) )
prog.op_descs[-1].attrs.extend(prog.OpDescAttrs(dict( prog.op_descs[-1].attrs.extend(
prog.OpDescAttrs(dict(
transpose_X=trans_a, transpose_X=trans_a,
transpose_Y=trans_b, transpose_Y=trans_b,
))) # NOTE: OpDesc attrs use transpose_X/transpose_Y casing ))) # NOTE: OpDesc attrs use transpose_X/transpose_Y casing
if beta != 0: if beta != 0:
if beta == 1.: # exactly if beta == 1.: # exactly
prog.Op('', 'Add', prog.Op(
'',
'Add',
[val_mm, val_c], [val_mm, val_c],
[val_y], # val [val_y], # val
dict(axis=1), dict(axis=1),
...@@ -1072,21 +1158,27 @@ def Gemm( ...@@ -1072,21 +1158,27 @@ def Gemm(
if vm_dtype is None: if vm_dtype is None:
vm_dtype = np.dtype('float32') vm_dtype = np.dtype('float32')
beta = np.dtype(vm_dtype).type(beta) beta = np.dtype(vm_dtype).type(beta)
prog.Op('', 'Constant', prog.Op(
'',
'Constant',
[], [],
[val_beta], # val [val_beta], # val
dict(value=beta), dict(value=beta),
value_infos=value_infos, value_infos=value_infos,
name=val_beta, name=val_beta,
) )
prog.Op('', 'Mul', prog.Op(
'',
'Mul',
[val_c, val_beta], [val_c, val_beta],
[val_vm], # val [val_vm], # val
dict(), dict(),
value_infos=value_infos, value_infos=value_infos,
name=(name + '_scale'), name=(name + '_scale'),
) )
prog.Op('', 'Add', prog.Op(
'',
'Add',
[val_mm, val_vm], [val_mm, val_vm],
[val_y], # val [val_y], # val
dict(axis=1), dict(axis=1),
...@@ -1094,28 +1186,36 @@ def Gemm( ...@@ -1094,28 +1186,36 @@ def Gemm(
) )
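The lowering above can be sanity-checked numerically; a minimal NumPy sketch of onnx::Gemm against the MatMul (+ Mul) + Add decomposition, with hypothetical shapes matching the transB=1 test case below:

import numpy as np

a = np.random.rand(2, 3).astype(np.float32)
b = np.random.rand(8, 3).astype(np.float32)   # transB=1, so B is (N, K)
c = np.random.rand(8).astype(np.float32)
alpha, beta = 1.5, 2.0

gemm = alpha * a.dot(b.T) + beta * c          # onnx::Gemm semantics
lowered = a.dot(b.T) * alpha + c * beta       # MatMul(alpha) + Mul(beta) + Add
np.testing.assert_allclose(gemm, lowered, rtol=1e-6)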
def GlobalAveragePool( def GlobalAveragePool(prog,
prog, inputs, outputs, attrs, value_infos, inputs,
outputs,
attrs,
value_infos,
name='', name='',
*args, **kwargs): *args,
**kwargs):
""" """
onnx::GlobalAveragePool-1: onnx::GlobalAveragePool-1:
""" """
return _global_pool(prog, 'avg', inputs, outputs, attrs, value_infos, return _global_pool(
name=name) prog, 'avg', inputs, outputs, attrs, value_infos, name=name)
def GlobalMaxPool( def GlobalMaxPool(prog,
prog, inputs, outputs, attrs, value_infos, inputs,
outputs,
attrs,
value_infos,
name='', name='',
*args, **kwargs): *args,
**kwargs):
""" """
onnx::GlobalMaxPool-1: onnx::GlobalMaxPool-1:
""" """
return _global_pool(prog, 'max', inputs, outputs, attrs, value_infos, return _global_pool(
name=name) prog, 'max', inputs, outputs, attrs, value_infos, name=name)
#def LRN( #def LRN(
...@@ -1132,7 +1232,7 @@ def GlobalMaxPool( ...@@ -1132,7 +1232,7 @@ def GlobalMaxPool(
# var_y = _make_var_name(val_y) # var_y = _make_var_name(val_y)
# #
# # interpretation # # interpretation
# paddle_op = 'lrn' # fluid_op = 'lrn'
# size = attrs['size'] # required # size = attrs['size'] # required
# alpha = attrs.get('alpha', 0.0001) # optional # alpha = attrs.get('alpha', 0.0001) # optional
# beta = attrs.get('beta', 0.75) # optional # beta = attrs.get('beta', 0.75) # optional
...@@ -1147,7 +1247,7 @@ def GlobalMaxPool( ...@@ -1147,7 +1247,7 @@ def GlobalMaxPool(
# ', beta={}' # ', beta={}'
# '{})' # '{})'
# .format(var_y, # .format(var_y,
# paddle_op, # fluid_op,
# var_x, # var_x,
# # attrs # # attrs
# size, # size,
...@@ -1159,7 +1259,7 @@ def GlobalMaxPool( ...@@ -1159,7 +1259,7 @@ def GlobalMaxPool(
# var_mid = name + '.mid' # hidden variable # var_mid = name + '.mid' # hidden variable
# prog.VarDesc(var_y) # prog.VarDesc(var_y)
# prog.VarDesc(var_mid) # prog.VarDesc(var_mid)
# prog.OpDesc(paddle_op, # prog.OpDesc(fluid_op,
# ([var_x], 'X'), # ([var_x], 'X'),
# ([var_y, var_mid], 'Out', 'MidOut'), # ([var_y, var_mid], 'Out', 'MidOut'),
# dict(n=size, # dict(n=size,
...@@ -1170,21 +1270,17 @@ def GlobalMaxPool( ...@@ -1170,21 +1270,17 @@ def GlobalMaxPool(
# ) # )
def MaxPool( def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args,
prog, inputs, outputs, attrs, value_infos, **kwargs):
name='',
*args, **kwargs):
""" """
onnx::MaxPool-10: onnx::MaxPool-10:
""" """
return _pool(prog, 'max', inputs, outputs, attrs, value_infos, return _pool(prog, 'max', inputs, outputs, attrs, value_infos, name=name)
name=name)
def MaxRoiPool( def MaxRoiPool(prog, inputs, outputs, attrs, value_infos, name, *args,
prog, inputs, outputs, attrs, value_infos, name, **kwargs):
*args, **kwargs):
""" """
onnx::MaxRoiPool-1: onnx::MaxRoiPool-1:
""" """
...@@ -1192,9 +1288,7 @@ def MaxRoiPool( ...@@ -1192,9 +1288,7 @@ def MaxRoiPool(
_roi_pool(prog, 'roi_pool', inputs, outputs, attrs, value_infos, name) _roi_pool(prog, 'roi_pool', inputs, outputs, attrs, value_infos, name)
def RoiAlign( def RoiAlign(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
prog, inputs, outputs, attrs, value_infos, name,
*args, **kwargs):
""" """
caffe2::RoiAlign caffe2::RoiAlign
""" """
...@@ -1202,10 +1296,7 @@ def RoiAlign( ...@@ -1202,10 +1296,7 @@ def RoiAlign(
_roi_pool(prog, 'roi_align', inputs, outputs, attrs, value_infos, name) _roi_pool(prog, 'roi_align', inputs, outputs, attrs, value_infos, name)
def Pad( def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
prog, inputs, outputs, attrs, value_infos,
name='',
*args, **kwargs):
""" """
onnx::Pad-2: onnx::Pad-2:
""" """
...@@ -1231,14 +1322,17 @@ def Pad( ...@@ -1231,14 +1322,17 @@ def Pad(
assume_pad2d |= bool(output_shape) and len(output_shape) == 4 # NCHW assume_pad2d |= bool(output_shape) and len(output_shape) == 4 # NCHW
od_attrs = dict(pad_value=value) od_attrs = dict(pad_value=value)
if assume_pad2d: if assume_pad2d:
paddle_op = 'pad2d' fluid_op = 'pad2d'
pad2d_attr = ', mode={}, data_format="NCHW"'.format(repr(mode)) pad2d_attr = ', mode={}, data_format="NCHW"'.format(repr(mode))
od_attrs['mode'] = mode od_attrs['mode'] = mode
od_attrs['data_format'] = "NCHW"
else: else:
assert mode == 'constant', 'mode {} is supported only in pad2d'.format(mode) assert mode == 'constant', 'mode {} is supported only in pad2d'.format(
paddle_op = 'pad' mode)
fluid_op = 'pad'
pad2d_attr = '' pad2d_attr = ''
paddings = np.array(pads).reshape((-1, 2)).transpose().flatten().tolist() # SSEE -> SESE paddings = np.array(pads).reshape(
(-1, 2)).transpose().flatten().tolist() # SSEE -> SESE
od_attrs['paddings'] = paddings od_attrs['paddings'] = paddings
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
...@@ -1246,27 +1340,34 @@ def Pad( ...@@ -1246,27 +1340,34 @@ def Pad(
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', paddings={}' ', paddings={}'
', pad_value={}' ', pad_value={}'
'{}{})' '{}{})'.format(
.format(var_output, var_output,
paddle_op, fluid_op,
var_data, var_data,
# attrs # attrs
paddings, paddings,
value, value,
pad2d_attr, name_attr, pad2d_attr,
name_attr,
)) ))
prog.VarDesc(var_output) prog.VarDesc(var_output)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_data], 'X'), ([var_data], 'X'),
([var_output], 'Out'), ([var_output], 'Out'),
od_attrs, od_attrs,
) )
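A worked example of the SSEE -> SESE reindexing above: ONNX lists all begin pads, then all end pads, while the fluid pad op expects begin/end pairs per axis.

import numpy as np

pads = [0, 1, 2, 3]  # ONNX: [begin_h, begin_w, end_h, end_w]
paddings = np.array(pads).reshape((-1, 2)).transpose().flatten().tolist()
assert paddings == [0, 2, 1, 3]  # fluid: [begin_h, end_h, begin_w, end_w]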
def PRelu( def PRelu(prog,
prog, inputs, outputs, attrs, value_infos, inputs,
name='', embed_params=False, outputs,
*args, **kwargs): attrs,
value_infos,
name='',
embed_params=False,
*args,
**kwargs):
""" """
onnx::PRelu-9: onnx::PRelu-9:
""" """
...@@ -1278,7 +1379,7 @@ def PRelu( ...@@ -1278,7 +1379,7 @@ def PRelu(
var_y = _make_var_name(val_y) var_y = _make_var_name(val_y)
# interpretation # interpretation
paddle_op = 'prelu' fluid_op = 'prelu'
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params: if embed_params:
assert name != '' assert name != ''
...@@ -1291,24 +1392,24 @@ def PRelu( ...@@ -1291,24 +1392,24 @@ def PRelu(
# generation # generation
prog.Code('{} = layers.{}({}, mode="all"' prog.Code('{} = layers.{}({}, mode="all"'
'{}{})' '{}{})'.format(
.format(var_y, var_y,
paddle_op, fluid_op,
var_x, var_x,
# attrs # attrs
param_attr, name_attr, param_attr,
name_attr,
)) ))
prog.VarDesc(var_y) prog.VarDesc(var_y)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_x], 'X'), ([var_x], 'X'),
([var_y], 'Out'), ([var_y], 'Out'),
dict(mode='all'), dict(mode='all'),
) )
def PsRoiPool( def PsRoiPool(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
prog, inputs, outputs, attrs, value_infos, name,
*args, **kwargs):
""" """
caffe2::PsRoiPool caffe2::PsRoiPool
""" """
...@@ -1316,9 +1417,7 @@ def PsRoiPool( ...@@ -1316,9 +1417,7 @@ def PsRoiPool(
_roi_pool(prog, 'psroi_pool', inputs, outputs, attrs, value_infos, name) _roi_pool(prog, 'psroi_pool', inputs, outputs, attrs, value_infos, name)
def Reshape( def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
prog, inputs, outputs, attrs, value_infos, name,
*args, **kwargs):
""" """
onnx::Reshape-5: onnx::Reshape-5:
""" """
...@@ -1330,7 +1429,7 @@ def Reshape( ...@@ -1330,7 +1429,7 @@ def Reshape(
var_reshaped = _make_var_name(val_reshaped) var_reshaped = _make_var_name(val_reshaped)
# interpretation # interpretation
paddle_op = 'reshape' fluid_op = 'reshape'
is_const_shape = 'const_value' in value_infos[val_shape] is_const_shape = 'const_value' in value_infos[val_shape]
var_shape = _make_var_name(val_shape) # for code var_shape = _make_var_name(val_shape) # for code
if is_const_shape: if is_const_shape:
...@@ -1343,9 +1442,9 @@ def Reshape( ...@@ -1343,9 +1442,9 @@ def Reshape(
if is_const_shape: if is_const_shape:
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', shape={}' ', shape={}'
'{})' '{})'.format(
.format(var_reshaped, var_reshaped,
paddle_op, fluid_op,
var_data, var_data,
# attrs # attrs
var_shape, var_shape,
...@@ -1353,7 +1452,9 @@ def Reshape( ...@@ -1353,7 +1452,9 @@ def Reshape(
)) ))
else: else:
var_shape_int32 = var_shape + '_int32' var_shape_int32 = var_shape + '_int32'
prog.Op('', 'Cast', prog.Op(
'',
'Cast',
[var_shape], [var_shape],
[var_shape_int32], # var [var_shape_int32], # var
dict(to=np.dtype('int32')), dict(to=np.dtype('int32')),
...@@ -1363,36 +1464,36 @@ def Reshape( ...@@ -1363,36 +1464,36 @@ def Reshape(
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', shape={}' ', shape={}'
', actual_shape={}' ', actual_shape={}'
'{})' '{})'.format(
.format(var_reshaped, var_reshaped,
paddle_op, fluid_op,
var_data, var_data,
# attrs # attrs
shape, shape,
var_shape_int32, var_shape_int32,
name_attr, name_attr,
)) ))
paddle_op = 'reshape2' fluid_op = 'reshape2'
var_xshape = _make_var_name(name + '.xshape') var_xshape = _make_var_name(name + '.xshape')
prog.VarDesc(var_reshaped) prog.VarDesc(var_reshaped)
prog.VarDesc(var_xshape) prog.VarDesc(var_xshape)
if is_const_shape: if is_const_shape:
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_data], 'X'), ([var_data], 'X'),
([var_reshaped, var_xshape], 'Out', 'XShape'), ([var_reshaped, var_xshape], 'Out', 'XShape'),
dict(shape=shape), dict(shape=shape),
) )
else: else:
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_data, var_shape_int32], 'X', 'Shape'), ([var_data, var_shape_int32], 'X', 'Shape'),
([var_reshaped, var_xshape], 'Out', 'XShape'), ([var_reshaped, var_xshape], 'Out', 'XShape'),
dict(shape=shape), dict(shape=shape),
) )
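For reference, onnx::Reshape-5 gives 0 and -1 special meanings that the emitted reshape has to reproduce; a short NumPy sketch of those semantics with hypothetical shapes:

import numpy as np

def onnx_reshape(x, shape):
    # 0 copies the corresponding input dimension; -1 lets reshape infer it
    shape = [x.shape[i] if s == 0 else s for i, s in enumerate(shape)]
    return x.reshape(shape)

x = np.random.rand(2, 3, 4)
assert onnx_reshape(x, [0, -1]).shape == (2, 12)
assert onnx_reshape(x, [-1, 4]).shape == (6, 4)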
def Slice( def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
prog, inputs, outputs, attrs, value_infos,
*args, **kwargs):
""" """
onnx::Slice-1:9 onnx::Slice-1:9
""" """
...@@ -1404,17 +1505,17 @@ def Slice( ...@@ -1404,17 +1505,17 @@ def Slice(
var_output = _make_var_name(val_output) var_output = _make_var_name(val_output)
# interpretation # interpretation
paddle_op = 'slice' fluid_op = 'slice'
axes = attrs['axes'] # required axes = attrs['axes'] # required
starts = attrs['starts'] # required starts = attrs['starts'] # required
ends = attrs['ends'] # required ends = attrs['ends'] # required
shape = _shape_or_none(value_infos, val_data) shape = _shape_or_none(value_infos, val_data)
if shape: if shape:
ndims = len(shape) # ndims = len(shape)
for idx, value in enumerate(axes): # for idx, value in enumerate(axes):
if value > ONNX_INT_MAX // 2: # if value > ONNX_INT_MAX // 2:
axes[idx] = ndims + value - ONNX_INT_MAX - 1 # axes[idx] = ndims + value - ONNX_INT_MAX - 1
# HINT: Paddle 1.3 doc: 'for slicing to the end of a dimension of unknown size, passing INT_MAX is recommended', but this does not seem to work? # FIXME: Paddle 1.3 doc: 'for slicing to the end of a dimension of unknown size, passing INT_MAX is recommended', but this does not seem to work?
for idx, value in enumerate(starts): for idx, value in enumerate(starts):
if value > ONNX_INT_MAX // 2: if value > ONNX_INT_MAX // 2:
value = value - ONNX_INT_MAX - 1 value = value - ONNX_INT_MAX - 1
...@@ -1429,9 +1530,9 @@ def Slice( ...@@ -1429,9 +1530,9 @@ def Slice(
', axes={}' ', axes={}'
', starts={}' ', starts={}'
', ends={}' ', ends={}'
')' ')'.format(
.format(var_output, var_output,
paddle_op, fluid_op,
var_data, var_data,
# attrs # attrs
axes, axes,
...@@ -1439,19 +1540,19 @@ def Slice( ...@@ -1439,19 +1540,19 @@ def Slice(
ends, ends,
)) ))
prog.VarDesc(var_output) prog.VarDesc(var_output)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_data], 'X'), ([var_data], 'X'),
([var_output], 'Out'), ([var_output], 'Out'),
dict(axes=axes, dict(
axes=axes,
starts=starts, starts=starts,
ends=ends, ends=ends,
), ),
) )
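A hedged numeric illustration of the clamping above, assuming ONNX_INT_MAX is the int64 maximum (2**63 - 1) as elsewhere in this module: oversized ends such as INT_MAX fold into small negative offsets, which fluid then interprets relative to the dimension size; the FIXME above records that this edge may still be off by one.

ONNX_INT_MAX = 2 ** 63 - 1  # assumed definition

ends = [1, ONNX_INT_MAX]
ends = [v - ONNX_INT_MAX - 1 if v > ONNX_INT_MAX // 2 else v for v in ends]
assert ends == [1, -1]  # 'slice to the end' becomes a negative index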
def Sum( def Sum(prog, inputs, outputs, *args, **kwargs):
prog, inputs, outputs,
*args, **kwargs):
""" """
onnx::Sum-8: onnx::Sum-8:
""" """
...@@ -1462,27 +1563,25 @@ def Sum( ...@@ -1462,27 +1563,25 @@ def Sum(
var_sum = _make_var_name(val_sum) var_sum = _make_var_name(val_sum)
# interpretation # interpretation
paddle_op = 'sums' fluid_op = 'sums'
# generation # generation
prog.Code('{} = layers.{}({})' prog.Code('{} = layers.{}({})'.format(
.format(var_sum, var_sum,
paddle_op, fluid_op,
'[' + ', '.join(var_inps) + ']', '[' + ', '.join(var_inps) + ']',
# attrs # attrs
)) ))
prog.VarDesc(var_sum) prog.VarDesc(var_sum)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
(var_inps, *(['X'] * len(var_inps))), (var_inps, *(['X'] * len(var_inps))),
([var_sum], 'Out'), ([var_sum], 'Out'),
dict(), dict(),
) )
def Tile( def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
prog, inputs, outputs, attrs, value_infos,
name='',
*args, **kwargs):
""" """
onnx::Tile-6: onnx::Tile-6:
""" """
...@@ -1494,7 +1593,7 @@ def Tile( ...@@ -1494,7 +1593,7 @@ def Tile(
var_output = _make_var_name(val_output) var_output = _make_var_name(val_output)
# interpretation # interpretation
paddle_op = 'expand' fluid_op = 'expand'
is_const_repeats = 'const_value' in value_infos[val_repeats] is_const_repeats = 'const_value' in value_infos[val_repeats]
if is_const_repeats: if is_const_repeats:
code_repeats = _make_var_name(val_repeats) # for code code_repeats = _make_var_name(val_repeats) # for code
...@@ -1507,16 +1606,17 @@ def Tile( ...@@ -1507,16 +1606,17 @@ def Tile(
# generation # generation
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', expand_times={}' ', expand_times={}'
'{})' '{})'.format(
.format(var_output, var_output,
paddle_op, fluid_op,
var_input, var_input,
# attrs # attrs
code_repeats, code_repeats,
name_attr, name_attr,
)) ))
prog.VarDesc(var_output) prog.VarDesc(var_output)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_input], 'X'), ([var_input], 'X'),
([var_output], 'Out'), ([var_output], 'Out'),
dict(expand_times=repeats), dict(expand_times=repeats),
...@@ -1537,29 +1637,25 @@ def Tile( ...@@ -1537,29 +1637,25 @@ def Tile(
# var_shape = _make_var_name(val_shape) # var_shape = _make_var_name(val_shape)
# #
# # interpretation # # interpretation
# paddle_op = 'shape' # fluid_op = 'shape'
## value_infos[val_shape]['remove_batch'] = False ## value_infos[val_shape]['remove_batch'] = False
# #
# # generation # # generation
# prog.Code('{} = layers.{}({})' # prog.Code('{} = layers.{}({})'
# .format(var_shape, # .format(var_shape,
# paddle_op, # fluid_op,
# var_data, # var_data,
# # attrs # # attrs
# )) # ))
# prog.VarDesc(var_shape) # , _value_info_or_none(value_infos, val_shape)) # prog.VarDesc(var_shape) # , _value_info_or_none(value_infos, val_shape))
# prog.OpDesc(paddle_op, # prog.OpDesc(fluid_op,
# ([var_data], 'X'), # ([var_data], 'X'),
# ([var_shape], 'Out'), # ([var_shape], 'Out'),
# dict(), # dict(),
# ) # )
def Split( def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
prog, inputs, outputs, attrs,
*args,
name='',
**kwargs):
""" """
onnx::Split-2: onnx::Split-2:
""" """
...@@ -1570,7 +1666,7 @@ def Split( ...@@ -1570,7 +1666,7 @@ def Split(
var_input = _make_var_name(val_input) var_input = _make_var_name(val_input)
# interpretation # interpretation
paddle_op = 'split' fluid_op = 'split'
split = attrs['split'] # required split = attrs['split'] # required
axis = attrs.get('axis', 0) # optional axis = attrs.get('axis', 0) # optional
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
...@@ -1578,21 +1674,23 @@ def Split( ...@@ -1578,21 +1674,23 @@ def Split(
# generation # generation
prog.Code('{} = layers.{}({}, {}' prog.Code('{} = layers.{}({}, {}'
', dim={}' ', dim={}'
'{})' '{})'.format(
.format(', '.join(var_outs), ', '.join(var_outs),
paddle_op, fluid_op,
var_input, var_input,
split, split,
# attrs # attrs
axis, axis,
name_attr, name_attr,
)) ))
for val_out, var_out in zip(outputs, var_outs): for var_out in var_outs:
prog.VarDesc(var_out) prog.VarDesc(var_out)
prog.OpDesc(paddle_op, prog.OpDesc(
fluid_op,
([var_input], 'X'), ([var_input], 'X'),
(var_outs, *(['Out'] * len(var_outs))), (var_outs, *(['Out'] * len(var_outs))),
dict(axis=axis, dict(
axis=axis,
sections=split, sections=split,
), ),
) )
...@@ -1600,7 +1698,8 @@ def Split( ...@@ -1600,7 +1698,8 @@ def Split(
if __name__ == '__main__': if __name__ == '__main__':
_logging.basicConfig( _logging.basicConfig(
format='[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s', format=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
level=_logging.DEBUG, level=_logging.DEBUG,
) )
logger = _logging.getLogger('symbolic_test') logger = _logging.getLogger('symbolic_test')
...@@ -1608,7 +1707,10 @@ if __name__ == '__main__': ...@@ -1608,7 +1707,10 @@ if __name__ == '__main__':
from writer import Program from writer import Program
prog = Program() prog = Program()
AdaptiveAveragePool(prog, ['X'], ['Y'], AdaptiveAveragePool(
prog,
['X'],
['Y'],
dict(output_size=[3, 3]), dict(output_size=[3, 3]),
dict(Y=dict(shape=(2, 3, 3, 3), dtype=np.float32)), dict(Y=dict(shape=(2, 3, 3, 3), dtype=np.float32)),
name='AdaptiveAveragePool2d', name='AdaptiveAveragePool2d',
...@@ -1616,7 +1718,10 @@ if __name__ == '__main__': ...@@ -1616,7 +1718,10 @@ if __name__ == '__main__':
logger.info('AdaptiveAveragePool2d program:\n%s', prog) logger.info('AdaptiveAveragePool2d program:\n%s', prog)
prog = Program() prog = Program()
AdaptiveAveragePool(prog, ['X'], ['Y'], AdaptiveAveragePool(
prog,
['X'],
['Y'],
dict(output_size=[3, 3, 3]), dict(output_size=[3, 3, 3]),
dict(Y=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32)), dict(Y=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32)),
name='AdaptiveAveragePool3d', name='AdaptiveAveragePool3d',
...@@ -1624,18 +1729,26 @@ if __name__ == '__main__': ...@@ -1624,18 +1729,26 @@ if __name__ == '__main__':
logger.info('AdaptiveAveragePool3d program:\n%s', prog) logger.info('AdaptiveAveragePool3d program:\n%s', prog)
prog = Program() prog = Program()
AffineGrid(prog, ['Theta'], ['Grid'], AffineGrid(
prog,
['Theta'],
['Grid'],
dict(size=[2, 2, 8, 8]), dict(size=[2, 2, 8, 8]),
dict(Grid=dict(shape=(2, 8, 8, 2), dtype=np.float32)), dict(Grid=dict(shape=(2, 8, 8, 2), dtype=np.float32)),
) )
logger.info('AffineGrid program:\n%s', prog) logger.info('AffineGrid program:\n%s', prog)
prog = Program() prog = Program()
BatchNormalization(prog, ['X', 'scale', 'B', 'mean', 'var'], ['Y'], BatchNormalization(
dict(epsilon=1e-5, prog,
['X', 'scale', 'B', 'mean', 'var'],
['Y'],
dict(
epsilon=1e-5,
momentum=.9, momentum=.9,
), ),
dict(scale=dict(shape=(3, ), dtype=np.float32), dict(
scale=dict(shape=(3, ), dtype=np.float32),
B=dict(shape=(3, ), dtype=np.float32), B=dict(shape=(3, ), dtype=np.float32),
mean=dict(shape=(3, ), dtype=np.float32), mean=dict(shape=(3, ), dtype=np.float32),
var=dict(shape=(3, ), dtype=np.float32), var=dict(shape=(3, ), dtype=np.float32),
...@@ -1647,30 +1760,43 @@ if __name__ == '__main__': ...@@ -1647,30 +1760,43 @@ if __name__ == '__main__':
logger.info('BatchNormalization program:\n%s', prog) logger.info('BatchNormalization program:\n%s', prog)
prog = Program() prog = Program()
Cast(prog, ['input'], ['output'], Cast(
prog,
['input'],
['output'],
dict(to=2), # TensorProto.UINT8 dict(to=2), # TensorProto.UINT8
dict(input=dict(shape=(2, 3), dtype=np.float32), dict(
input=dict(shape=(2, 3), dtype=np.float32),
output=dict(shape=(2, 3), dtype=np.uint8)), output=dict(shape=(2, 3), dtype=np.uint8)),
) )
logger.info('Cast program:\n%s', prog) logger.info('Cast program:\n%s', prog)
prog = Program() prog = Program()
_default(prog, 'Clip', ['input'], ['output'], _default(
prog,
'Clip',
['input'],
['output'],
dict(min=-1., max=1.), dict(min=-1., max=1.),
dict(output=dict(shape=(2, 3), dtype=np.float32)), dict(output=dict(shape=(2, 3), dtype=np.float32)),
) )
logger.info('Clip program:\n%s', prog) logger.info('Clip program:\n%s', prog)
prog = Program() prog = Program()
Conv(prog, ['X', 'W'], ['Y'], Conv(
dict(auto_pad='NOTSET', prog,
['X', 'W'],
['Y'],
dict(
auto_pad='NOTSET',
dilations=[1, 1], dilations=[1, 1],
group=1, group=1,
kernel_shape=[3, 3], kernel_shape=[3, 3],
pads=[1, 1, 1, 1], pads=[1, 1, 1, 1],
strides=[1, 1], strides=[1, 1],
), ),
dict(W=dict(shape=(2, 3, 3, 3), dtype=np.float32), dict(
W=dict(shape=(2, 3, 3, 3), dtype=np.float32),
Y=dict(shape=(2, 2, 4, 6), dtype=np.float32), Y=dict(shape=(2, 2, 4, 6), dtype=np.float32),
), ),
name='ConvNoBias2d', name='ConvNoBias2d',
...@@ -1679,15 +1805,20 @@ if __name__ == '__main__': ...@@ -1679,15 +1805,20 @@ if __name__ == '__main__':
logger.info('ConvNoBias2d program:\n%s', prog) logger.info('ConvNoBias2d program:\n%s', prog)
prog = Program() prog = Program()
Conv(prog, ['X', 'W', 'B'], ['Y'], Conv(
dict(auto_pad='NOTSET', prog,
['X', 'W', 'B'],
['Y'],
dict(
auto_pad='NOTSET',
dilations=[1, 1], dilations=[1, 1],
group=1, group=1,
kernel_shape=[3, 3], kernel_shape=[3, 3],
pads=[1, 1, 1, 1], pads=[1, 1, 1, 1],
strides=[1, 1], strides=[1, 1],
), ),
dict(W=dict(shape=(2, 3, 3, 3), dtype=np.float32), dict(
W=dict(shape=(2, 3, 3, 3), dtype=np.float32),
B=dict(shape=(2), dtype=np.float32), B=dict(shape=(2), dtype=np.float32),
Y=dict(shape=(2, 2, 4, 6), dtype=np.float32), Y=dict(shape=(2, 2, 4, 6), dtype=np.float32),
), ),
...@@ -1697,17 +1828,22 @@ if __name__ == '__main__': ...@@ -1697,17 +1828,22 @@ if __name__ == '__main__':
logger.info('Conv2d program:\n%s', prog) logger.info('Conv2d program:\n%s', prog)
prog = Program() prog = Program()
ConvTranspose(prog, ['X', 'W', 'B'], ['Y'], ConvTranspose(
dict(auto_pad='NOTSET', prog,
['X', 'W', 'B'],
['Y'],
dict(
auto_pad='NOTSET',
dilations=[1, 1], dilations=[1, 1],
group=1, group=1,
kernel_shape=[3, 3], kernel_shape=[3, 3],
# output_padding=[1, 1, 1, 1], # output_padding=[1, 1, 1, 1],
# output_shape=[6, 8], # output_shape=[6, 8],
pads=[1, 1, 1, 1], pads=[1, 1, 1, 1],
strides=[1, 1], strides=[1, 1],
), ),
dict(W=dict(shape=(2, 3, 3, 3), dtype=np.float32), dict(
W=dict(shape=(2, 3, 3, 3), dtype=np.float32),
B=dict(shape=(2), dtype=np.float32), B=dict(shape=(2), dtype=np.float32),
Y=dict(shape=(2, 2, 6, 8), dtype=np.float32), Y=dict(shape=(2, 2, 6, 8), dtype=np.float32),
), ),
...@@ -1717,15 +1853,20 @@ if __name__ == '__main__': ...@@ -1717,15 +1853,20 @@ if __name__ == '__main__':
logger.info('ConvTransposed2d program:\n%s', prog) logger.info('ConvTransposed2d program:\n%s', prog)
prog = Program() prog = Program()
Conv(prog, ['X', 'W'], ['Y'], Conv(
dict(auto_pad='NOTSET', prog,
['X', 'W'],
['Y'],
dict(
auto_pad='NOTSET',
dilations=[1, 1, 1], dilations=[1, 1, 1],
group=1, group=1,
kernel_shape=[3, 3, 3], kernel_shape=[3, 3, 3],
pads=[1, 1, 1, 1, 1, 1], pads=[1, 1, 1, 1, 1, 1],
strides=[1, 1, 1], strides=[1, 1, 1],
), ),
dict(W=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32), dict(
W=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32),
Y=dict(shape=(2, 2, 4, 6, 8), dtype=np.float32), Y=dict(shape=(2, 2, 4, 6, 8), dtype=np.float32),
), ),
name='ConvNoBias3d', name='ConvNoBias3d',
...@@ -1734,15 +1875,20 @@ if __name__ == '__main__': ...@@ -1734,15 +1875,20 @@ if __name__ == '__main__':
logger.info('ConvNoBias3d program:\n%s', prog) logger.info('ConvNoBias3d program:\n%s', prog)
prog = Program() prog = Program()
Conv(prog, ['X', 'W', 'B'], ['Y'], Conv(
dict(auto_pad='NOTSET', prog,
['X', 'W', 'B'],
['Y'],
dict(
auto_pad='NOTSET',
dilations=[1, 1, 1], dilations=[1, 1, 1],
group=1, group=1,
kernel_shape=[3, 3, 3], kernel_shape=[3, 3, 3],
pads=[1, 1, 1, 1, 1, 1], pads=[1, 1, 1, 1, 1, 1],
strides=[1, 1, 1], strides=[1, 1, 1],
), ),
dict(W=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32), dict(
W=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32),
B=dict(shape=(2), dtype=np.float32), B=dict(shape=(2), dtype=np.float32),
Y=dict(shape=(2, 2, 4, 6, 8), dtype=np.float32), Y=dict(shape=(2, 2, 4, 6, 8), dtype=np.float32),
), ),
...@@ -1752,17 +1898,22 @@ if __name__ == '__main__': ...@@ -1752,17 +1898,22 @@ if __name__ == '__main__':
logger.info('Conv3d program:\n%s', prog) logger.info('Conv3d program:\n%s', prog)
prog = Program() prog = Program()
ConvTranspose(prog, ['X', 'W', 'B'], ['Y'], ConvTranspose(
dict(auto_pad='NOTSET', prog,
['X', 'W', 'B'],
['Y'],
dict(
auto_pad='NOTSET',
dilations=[1, 1, 1], dilations=[1, 1, 1],
group=1, group=1,
kernel_shape=[3, 3, 3], kernel_shape=[3, 3, 3],
# output_padding=[1, 1, 1, 1], # output_padding=[1, 1, 1, 1],
# output_shape=[6, 8], # output_shape=[6, 8],
pads=[1, 1, 1, 1, 1, 1], pads=[1, 1, 1, 1, 1, 1],
strides=[1, 1, 1], strides=[1, 1, 1],
), ),
dict(W=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32), dict(
W=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32),
B=dict(shape=(2), dtype=np.float32), B=dict(shape=(2), dtype=np.float32),
Y=dict(shape=(2, 2, 6, 8, 9), dtype=np.float32), Y=dict(shape=(2, 2, 6, 8, 9), dtype=np.float32),
), ),
...@@ -1772,20 +1923,29 @@ if __name__ == '__main__': ...@@ -1772,20 +1923,29 @@ if __name__ == '__main__':
logger.info('ConvTransposed3d program:\n%s', prog) logger.info('ConvTransposed3d program:\n%s', prog)
prog = Program() prog = Program()
_default(prog, 'Equal', ['A', 'B'], ['C'], _default(
prog,
'Equal',
['A', 'B'],
['C'],
dict(), dict(),
dict(C=dict(shape=(2, 3), dtype=np.bool)), dict(C=dict(shape=(2, 3), dtype=np.bool)),
) )
logger.info('Equal program:\n%s', prog) logger.info('Equal program:\n%s', prog)
prog = Program() prog = Program()
Gemm(prog, ['A', 'B', 'C'], ['Y'], Gemm(
dict(alpha=1., prog,
['A', 'B', 'C'],
['Y'],
dict(
alpha=1.,
beta=1., beta=1.,
transA=0, transA=0,
transB=1, transB=1,
), ),
dict(B=dict(shape=(8, 3), dtype=np.float32), dict(
B=dict(shape=(8, 3), dtype=np.float32),
Y=dict(shape=(2, 8), dtype=np.float32), Y=dict(shape=(2, 8), dtype=np.float32),
), ),
name='Gemm', name='Gemm',
...@@ -1793,34 +1953,48 @@ if __name__ == '__main__': ...@@ -1793,34 +1953,48 @@ if __name__ == '__main__':
logger.info('Gemm program:\n%s', prog) logger.info('Gemm program:\n%s', prog)
prog = Program() prog = Program()
_default(prog, 'Less', ['A', 'B'], ['C'], _default(
prog,
'Less',
['A', 'B'],
['C'],
dict(), dict(),
dict(C=dict(shape=(2, 3), dtype=np.bool)), dict(C=dict(shape=(2, 3), dtype=np.bool)),
) )
logger.info('Less program:\n%s', prog) logger.info('Less program:\n%s', prog)
prog = Program() prog = Program()
_default(prog, 'MatMul', ['A', 'B'], ['Y'], _default(
prog,
'MatMul', ['A', 'B'], ['Y'],
dict(), dict(),
dict(Y=dict(shape=(2, 8), dtype=np.float32)), dict(Y=dict(shape=(2, 8), dtype=np.float32)),
name='MatMul' name='MatMul')
)
logger.info('MatMul program:\n%s', prog) logger.info('MatMul program:\n%s', prog)
prog = Program() prog = Program()
_default(prog, 'OneHot', ['indices', 'depth', 'values'], ['output'], _default(
prog,
'OneHot',
['indices', 'depth', 'values'],
['output'],
dict(axis=-1), dict(axis=-1),
dict(output=dict(shape=(2, 8), dtype=np.float32)), dict(output=dict(shape=(2, 8), dtype=np.float32)),
) )
logger.info('OneHot program:\n%s', prog) logger.info('OneHot program:\n%s', prog)
prog = Program() prog = Program()
Pad(prog, ['data'], ['output'], Pad(
dict(mode='constant', prog,
['data'],
['output'],
dict(
mode='constant',
pads=[0, 1], pads=[0, 1],
value=0., value=0.,
), ),
dict(data=dict(shape=(2, 7), dtype=np.float32), dict(
data=dict(shape=(2, 7), dtype=np.float32),
output=dict(shape=(2, 8), dtype=np.float32), output=dict(shape=(2, 8), dtype=np.float32),
), ),
name='Pad', name='Pad',
...@@ -1828,12 +2002,17 @@ if __name__ == '__main__': ...@@ -1828,12 +2002,17 @@ if __name__ == '__main__':
logger.info('Pad program:\n%s', prog) logger.info('Pad program:\n%s', prog)
prog = Program() prog = Program()
Pad(prog, ['data'], ['output'], Pad(
dict(mode='reflect', prog,
['data'],
['output'],
dict(
mode='reflect',
pads=[0, 1, 2, 3], pads=[0, 1, 2, 3],
value=0., value=0.,
), ),
dict(data=dict(shape=(2, 3, 3, 3), dtype=np.float32), dict(
data=dict(shape=(2, 3, 3, 3), dtype=np.float32),
output=dict(shape=(2, 3, 5, 7), dtype=np.float32), output=dict(shape=(2, 3, 5, 7), dtype=np.float32),
), ),
name='Pad2d', name='Pad2d',
...@@ -1841,7 +2020,10 @@ if __name__ == '__main__': ...@@ -1841,7 +2020,10 @@ if __name__ == '__main__':
logger.info('Pad2d program:\n%s', prog) logger.info('Pad2d program:\n%s', prog)
prog = Program() prog = Program()
PRelu(prog, ['X', 'slope'], ['Y'], PRelu(
prog,
['X', 'slope'],
['Y'],
dict(), dict(),
dict(Y=dict(shape=(2, 3), dtype=np.float32)), dict(Y=dict(shape=(2, 3), dtype=np.float32)),
name='PRelu', name='PRelu',
...@@ -1849,11 +2031,11 @@ if __name__ == '__main__': ...@@ -1849,11 +2031,11 @@ if __name__ == '__main__':
logger.info('PRelu program:\n%s', prog) logger.info('PRelu program:\n%s', prog)
prog = Program() prog = Program()
Tile(prog, ['input', 'repeats'], ['output'], Tile(
prog, ['input', 'repeats'], ['output'],
dict(), dict(),
dict(repeats=dict(const_value=[1, 2]), dict(
output=dict(shape=(2, 2, 4), dtype=np.float32) repeats=dict(const_value=[1, 2]),
), output=dict(shape=(2, 2, 4), dtype=np.float32)),
name='Tile' name='Tile')
)
logger.info('Tile program:\n%s', prog) logger.info('Tile program:\n%s', prog)
...@@ -24,8 +24,7 @@ def _ensure_tuple(obj): ...@@ -24,8 +24,7 @@ def _ensure_tuple(obj):
return (obj, ) return (obj, )
def _flatten_list(obj, def _flatten_list(obj, out=None):
out=None):
assert isinstance(obj, list) assert isinstance(obj, list)
if out is None: if out is None:
out = type(obj)() out = type(obj)()
...@@ -37,8 +36,7 @@ def _flatten_list(obj, ...@@ -37,8 +36,7 @@ def _flatten_list(obj,
return out return out
def export_data(state_dict, def export_data(state_dict, prefix=''):
prefix=''):
""" """
export binary data with meta text for raw C++ inference engines export binary data with meta text for raw C++ inference engines
""" """
...@@ -65,10 +63,14 @@ def export_data(state_dict, ...@@ -65,10 +63,14 @@ def export_data(state_dict,
fp.close() fp.close()
def export_onnx_with_validation(model, inputs, export_basepath, def export_onnx_with_validation(model,
input_names=None, output_names=None, inputs,
export_basepath,
input_names=None,
output_names=None,
use_npz=True, use_npz=True,
*args, **kwargs): *args,
**kwargs):
""" """
export a PyTorch model to an ONNX model, saving sample inputs and outputs in a NumPy file export a PyTorch model to an ONNX model, saving sample inputs and outputs in a NumPy file
""" """
...@@ -96,10 +98,14 @@ def export_onnx_with_validation(model, inputs, export_basepath, ...@@ -96,10 +98,14 @@ def export_onnx_with_validation(model, inputs, export_basepath,
return ret return ret
torch_inputs = _ensure_tuple(inputs) # WORKAROUND: for torch.onnx torch_inputs = _ensure_tuple(inputs) # WORKAROUND: for torch.onnx
outputs = torch.onnx.export(model, torch_inputs, export_basepath + '.onnx', outputs = torch.onnx.export(
model,
torch_inputs,
export_basepath + '.onnx',
input_names=_flatten_list(input_names), input_names=_flatten_list(input_names),
output_names=_flatten_list(output_names), output_names=_flatten_list(output_names),
*args, **kwargs) *args,
**kwargs)
if outputs is None: # WORKAROUND: for torch.onnx if outputs is None: # WORKAROUND: for torch.onnx
outputs = model(*inputs) outputs = model(*inputs)
torch_outputs = _ensure_tuple(outputs) torch_outputs = _ensure_tuple(outputs)
......
...@@ -13,8 +13,7 @@ import os ...@@ -13,8 +13,7 @@ import os
import sys import sys
def _flatten_dict(obj, def _flatten_dict(obj, out=None):
out=None):
assert isinstance(obj, dict) assert isinstance(obj, dict)
if out is None: if out is None:
out = type(obj)() out = type(obj)()
...@@ -34,12 +33,13 @@ def _ensure_list(obj): ...@@ -34,12 +33,13 @@ def _ensure_list(obj):
return [obj] return [obj]
def validate(paddle_model_filename, golden_data_filename, def validate(fluid_model_filename,
golden_data_filename,
model_func_name='inference', model_func_name='inference',
precision=1e-4, precision=1e-4,
save_inference_model=False): save_inference_model=False):
""" """
run inference on the converted Paddle model and validate against given golden data run inference on the converted Paddle fluid model and validate against given golden data
""" """
import numpy as np import numpy as np
...@@ -52,17 +52,17 @@ def validate(paddle_model_filename, golden_data_filename, ...@@ -52,17 +52,17 @@ def validate(paddle_model_filename, golden_data_filename,
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
# load model # load model
paddle_model_dir, basename = os.path.split(paddle_model_filename) fluid_model_dir, basename = os.path.split(fluid_model_filename)
if basename == '__model__': # is desc model if basename == '__model__': # is desc model
logger.debug('using desc file %s', basename) logger.debug('using desc file %s', basename)
prog, in_names, var_outs = fluid.io.load_inference_model(paddle_model_dir, exe) prog, _, var_outs = fluid.io.load_inference_model(fluid_model_dir, exe)
out_names = var_outs # HINT: pass var if fetch ops already created out_names = var_outs # HINT: pass var if fetch ops already created
logger.info('model load passed') logger.info('model load passed')
elif basename.endswith('.py'): # is python code elif basename.endswith('.py'): # is python code
logger.debug('using python code file %s', basename) logger.debug('using python code file %s', basename)
module_name, _ = os.path.splitext(basename) module_name, _ = os.path.splitext(basename)
sys_path = sys.path.copy() sys_path = sys.path.copy()
sys.path.append(paddle_model_dir) sys.path.append(fluid_model_dir)
try: try:
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
func = getattr(module, model_func_name) func = getattr(module, model_func_name)
...@@ -71,18 +71,21 @@ def validate(paddle_model_filename, golden_data_filename, ...@@ -71,18 +71,21 @@ def validate(paddle_model_filename, golden_data_filename,
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
func = getattr(module, model_func_name) func = getattr(module, model_func_name)
sys.path = sys_path sys.path = sys_path
logger.debug('from %s imported %s: %s', module_name, model_func_name, func) logger.debug('from %s imported %s: %s', module_name, model_func_name,
func)
var_outs = func() var_outs = func()
var_outs = _ensure_list(var_outs) var_outs = _ensure_list(var_outs)
out_names = [var.name for var in var_outs] # HINT: pass string to create fetch ops out_names = [var.name for var in var_outs
] # HINT: pass string to create fetch ops
logger.info('import passed') logger.info('import passed')
prog = fluid.default_main_program() prog = fluid.default_main_program()
fluid.io.load_persistables(executor=exe, dirname=paddle_model_dir, main_program=prog) fluid.io.load_persistables(
executor=exe, dirname=fluid_model_dir, main_program=prog)
logger.info('weight load passed') logger.info('weight load passed')
else: else:
raise ValueError('unsupported Paddle model') raise ValueError('unsupported Paddle fluid model')
# load data # load data
logger.info('using golden data %s', golden_data_filename) logger.info('using golden data %s', golden_data_filename)
...@@ -100,10 +103,15 @@ def validate(paddle_model_filename, golden_data_filename, ...@@ -100,10 +103,15 @@ def validate(paddle_model_filename, golden_data_filename,
# DEBUG: reload test for python code # DEBUG: reload test for python code
if basename.endswith('.py') and save_inference_model: if basename.endswith('.py') and save_inference_model:
fluid.io.save_inference_model(paddle_model_dir, input_data.keys(), var_outs, exe, fluid.io.save_inference_model(
main_program=prog, export_for_deployment=True) fluid_model_dir,
input_data.keys(),
var_outs,
exe,
main_program=prog,
export_for_deployment=True)
logger.info('model re-save passed') logger.info('model re-save passed')
fluid.io.load_inference_model(paddle_model_dir, exe) fluid.io.load_inference_model(fluid_model_dir, exe)
logger.info('model re-load passed') logger.info('model re-load passed')
# execute # execute
...@@ -124,49 +132,54 @@ def validate(paddle_model_filename, golden_data_filename, ...@@ -124,49 +132,54 @@ def validate(paddle_model_filename, golden_data_filename,
else: else:
logger.info('accuracy not passed') logger.info('accuracy not passed')
# globals().update(locals()) # globals().update(locals())
return passed return passed
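A minimal programmatic usage sketch, with hypothetical paths; the golden .npz is the file written by export_onnx_with_validation, and precision keeps its signature default:

# hypothetical paths, assuming validate is importable from this module
passed = validate('/tmp/export/model.py', '/tmp/export/sample.npz',
                  precision=1e-4, save_inference_model=False)
assert passed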
if __name__ == '__main__':
    logging.basicConfig(
        format='[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
        level=logging.DEBUG,
    )
    logger = logging.getLogger('validation_test')
    model_rc_list = [
        '../examples/t{}/model.py',
        '../examples/t{}/__model__',
        '../examples/t{}.embeded/model.py',
        '../examples/t{}.embeded/__model__',
    ]
    import numpy as np
    idx_model = np.random.randint(1, 7)
    model = np.random.choice(model_rc_list).format(idx_model)
    precision = 10 ** (np.random.rand() * -4 - 2)
    debug = False
    model = '/tmp/export/model.py'
    # model = '../examples/t1/__model__'
    # model = '../examples/t1.embeded/model.py'
    # model = '../examples/t1.embeded/__model__'
    debug = True
    logger.info('args: %s %.6f', model, precision)
    data_dir, dir_name = os.path.split(os.path.split(model)[0])
    data_pathname = os.path.splitext(dir_name)[0]
    # proto debug test
    from framework_pb2 import ProgramDesc
    pd = ProgramDesc()
    pd.ParseFromString(open(os.path.join(data_dir, dir_name, '__model__'), 'rb').read())
    # validate
    # validate(model, os.path.join(data_dir, data_pathname + '.npz'),
    #          precision=precision, save_inference_model=debug)
    validate(model, '../examples/bvlc_alexnet/test_data_0.npz',
             precision=precision, save_inference_model=debug)
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(
        description='onnx2fluid.validate',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        'model',
        nargs=1,
        help='path to model.py or __model__',
    )
    parser.add_argument(
        '--debug',
        '-d',
        action='store_true',
        help='enable debug logging and checking',
    )
    parser.add_argument(
        '--test_data',
        '-t',
        type=str,
        help='I/O golden data for validation, e.g. test.npy, test.npz',
    )
    parser.add_argument(
        '--precision',
        '-p',
        type=int,
        default=4,
        help='assertion decimal for validation',
    )
    args = parser.parse_args()
    logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(format=logging_format, level=logging_level)
    debug = args.debug
    fluid_model_filename = args.model[0]
    golden_data_filename = args.test_data
    precision = args.precision
    validate(
        fluid_model_filename,
        golden_data_filename,
        precision=precision,
        save_inference_model=debug)
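With the argparse interface above, a typical run would look like this (module name taken from the parser description; paths hypothetical):

python -m onnx2fluid.validate examples/t1/model.py -t examples/t1.npz -p 4 -d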
...@@ -34,15 +34,13 @@ except ImportError: ...@@ -34,15 +34,13 @@ except ImportError:
logger.warning('importing paddle.fluid.proto.framework_pb2 failed, ' logger.warning('importing paddle.fluid.proto.framework_pb2 failed, '
'using fallback framework_pb2') 'using fallback framework_pb2')
__all__ = [ __all__ = [
'Program', 'Program',
'Writer', 'Writer',
] ]
def _irepr(obj, def _irepr(obj, to='_'):
to='_'):
"""inline repr""" """inline repr"""
s = repr(obj) s = repr(obj)
...@@ -53,8 +51,7 @@ def _irepr(obj, ...@@ -53,8 +51,7 @@ def _irepr(obj,
return s return s
def _flatten_list(obj, def _flatten_list(obj, out=None):
out=None):
if out is None: if out is None:
out = type(obj)() out = type(obj)()
for item in obj: for item in obj:
...@@ -72,7 +69,7 @@ def make_attr_name(name): ...@@ -72,7 +69,7 @@ def make_attr_name(name):
if name == '': if name == '':
raise ValueError('name should not be empty') raise ValueError('name should not be empty')
for s in ' *?\/-:': # for s in ' *?\\/-:': #
name = name.replace(s, '_') name = name.replace(s, '_')
if not name.startswith('_'): if not name.startswith('_'):
name = '_' + name name = '_' + name
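The loop above maps arbitrary ONNX names onto safe Python identifiers; the same logic in isolation (sample value illustrative):

# sketch of make_attr_name-style sanitization
def sanitize(name):
    if name == '':
        raise ValueError('name should not be empty')
    for s in ' *?\\/-:':
        name = name.replace(s, '_')
    if not name.startswith('_'):
        name = '_' + name
    return name

print(sanitize('conv1/weights:0'))  # -> _conv1_weights_0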
...@@ -168,11 +165,8 @@ class Program(object): ...@@ -168,11 +165,8 @@ class Program(object):
return ('Program(code mutable: {}) with:\n' return ('Program(code mutable: {}) with:\n'
'codes: {}\n' 'codes: {}\n'
'op_descs: {}\n' 'op_descs: {}\n'
'var_descs: {}\n').format( 'var_descs: {}\n').format(self.code_mutable, self.codes,
self.code_mutable, self.op_descs, self.var_descs)
self.codes,
self.op_descs,
self.var_descs)
def __repr__(self): def __repr__(self):
return self.__str__() return self.__str__()
...@@ -185,8 +179,11 @@ class Program(object): ...@@ -185,8 +179,11 @@ class Program(object):
if self.code_mutable: if self.code_mutable:
self.codes.append(code) self.codes.append(code)
def OpDesc(self, name, def OpDesc(self,
input_val_keys=None, output_val_keys=None, attrs=None): name,
input_val_keys=None,
output_val_keys=None,
attrs=None):
""" """
add OpDesc add OpDesc
""" """
...@@ -202,10 +199,15 @@ class Program(object): ...@@ -202,10 +199,15 @@ class Program(object):
self.op_descs.append(desc) self.op_descs.append(desc)
return desc return desc
def VarDesc(self, name, def VarDesc(self,
persistable=False, value_info=None, remove_batch=None): name,
persistable=False,
value_info=None,
remove_batch=None,
dummy_dtype='float32'):
""" """
add VarDesc add VarDesc,
dummy_dtype: WORKAROUND for Netron viewer
""" """
var_desc = framework_pb2.VarDesc() var_desc = framework_pb2.VarDesc()
...@@ -213,6 +215,10 @@ class Program(object): ...@@ -213,6 +215,10 @@ class Program(object):
var_desc.persistable = persistable var_desc.persistable = persistable
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
# REMOVEIT: WORKAROUND: Netron: null.tensor error
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(dummy_dtype) # required
if value_info and 'dtype' in value_info: if value_info and 'dtype' in value_info:
tensor_desc = var_desc.type.lod_tensor.tensor tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(value_info['dtype']) # required tensor_desc.data_type = self.Dtype(value_info['dtype']) # required
...@@ -220,7 +226,8 @@ class Program(object): ...@@ -220,7 +226,8 @@ class Program(object):
tensor_desc.dims.extend(value_info['shape']) tensor_desc.dims.extend(value_info['shape'])
if len(value_info['shape']) > 0: # skip scalars if len(value_info['shape']) > 0: # skip scalars
if remove_batch is None: if remove_batch is None:
remove_batch = value_info.get('remove_batch', not persistable) remove_batch = value_info.get('remove_batch',
not persistable)
if remove_batch: if remove_batch:
tensor_desc.dims[0] = -1 tensor_desc.dims[0] = -1
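Assuming the bundled framework_pb2 module is importable, the dummy-dtype workaround plus batch-dim handling can be reproduced standalone (name, dtype and shape are illustrative):

# sketch: what VarDesc() assembles for a non-persistable float32 input
import framework_pb2

var_desc = framework_pb2.VarDesc()
var_desc.name = '_x'
var_desc.persistable = False
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = framework_pb2.VarType.FP32  # dummy dtype, kept for Netron
tensor_desc.dims.extend([2, 3, 224, 224])
tensor_desc.dims[0] = -1  # remove_batch: leading dim marked dynamic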
...@@ -240,8 +247,8 @@ class Program(object): ...@@ -240,8 +247,8 @@ class Program(object):
fn = getattr(symbolic, op_type) fn = getattr(symbolic, op_type)
fn(self, *args, **kwargs) fn(self, *args, **kwargs)
else: else:
raise ValueError('conversion for {}::{} not supported' raise ValueError('conversion for {}::{} not supported'.format(
.format(domain, op_type)) domain, op_type))
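The lookup above is a plain getattr dispatch over the symbolic module; a self-contained stand-in showing the pattern (converter body is a placeholder):

# stand-in for the getattr-based converter dispatch in Program.Op
class FakeSymbolic(object):  # plays the role of the symbolic module
    @staticmethod
    def Conv(prog, *args, **kwargs):
        print('emitting Conv into', prog)

def dispatch(prog, domain, op_type, *args, **kwargs):
    fn = getattr(FakeSymbolic, op_type, None)
    if fn is None:
        raise ValueError('conversion for {}::{} not supported'.format(
            domain, op_type))
    fn(prog, *args, **kwargs)

dispatch('prog', '', 'Conv')  # prints: emitting Conv into prog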
def IntermediateOp(self, domain, op_type, *args, **kwargs): def IntermediateOp(self, domain, op_type, *args, **kwargs):
""" """
...@@ -267,14 +274,15 @@ class Writer(object): ...@@ -267,14 +274,15 @@ class Writer(object):
CODE_INDENT = ' ' * 4 CODE_INDENT = ' ' * 4
@staticmethod @staticmethod
def header_code(func_name): def header_code(func_name, info=''):
""" """
Python header codes Python header codes
""" """
codes = list() codes = list()
codes.append('"""') codes.append('"""')
codes.append('This code is generated by onnx2paddle.') codes.append('This code is generated by onnx2fluid.')
codes.append('{}'.format(info))
codes.append('"""') codes.append('"""')
codes.append('') codes.append('')
codes.append('from __future__ import division') codes.append('from __future__ import division')
...@@ -287,16 +295,25 @@ class Writer(object): ...@@ -287,16 +295,25 @@ class Writer(object):
return codes return codes
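Given the appends above, the emitted module header should read roughly as follows (info text illustrative; the remaining imports are truncated in this hunk):

codes = Writer.header_code('inference', info='exported from model.onnx')
print('\n'.join(codes))
# """
# This code is generated by onnx2fluid.
# exported from model.onnx
# """
#
# from __future__ import division
# ...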
@staticmethod @staticmethod
def emit_op(prog, name, domain, op_type, inputs, outputs, attrs, value_infos, *args, **kwargs): def emit_op(prog, name, domain, op_type, inputs, outputs, attrs,
value_infos, *args, **kwargs):
""" """
emit an ONNX op into program emit an ONNX op into program
""" """
prog.Code('# {}, {}::{}: {} -> {}, {}'
          .format(name, domain, op_type, inputs, outputs, _irepr(attrs, to=', ')))
prog.Op(domain, op_type, inputs, outputs, attrs,
        value_infos=value_infos, name=name,
        *args, **kwargs)
prog.Code('# {}, {}::{}: {} -> {}, {}'.format(name, domain, op_type,
                                              inputs, outputs,
                                              _irepr(attrs, to=', ')))
prog.Op(
    domain,
    op_type,
    inputs,
    outputs,
    attrs,
    value_infos=value_infos,
    name=name,
    *args,
    **kwargs)
@staticmethod @staticmethod
def emit_param(prog, name, value_info): def emit_param(prog, name, value_info):
...@@ -315,16 +332,16 @@ class Writer(object): ...@@ -315,16 +332,16 @@ class Writer(object):
prog.Code('# parameter: {}'.format(name)) prog.Code('# parameter: {}'.format(name))
prog.Code('{} = ParamAttr(name={})' # , trainable=True prog.Code('{} = ParamAttr(name={})' # , trainable=True
.format(attr_name, repr(var_name))) .format(attr_name, repr(var_name)))
prog.Code('{} = layers.create_parameter(shape={}, dtype={}, name={}, attr={}'
          ', default_initializer=initializer.Constant(0))' #, is_bias={}
          .format(var_name,
                  value_info['shape'], repr(value_info['dtype'].name),
                  repr(name), attr_name)) #, value_info.get('is_bias', False)))
prog.Code(
    '{} = layers.create_parameter(shape={}, dtype={}, name={}, attr={}'
    ', default_initializer=initializer.Constant(0))' #, is_bias={}
    .format(var_name, value_info['shape'],
            repr(value_info['dtype'].name), repr(name),
            attr_name)) #, value_info.get('is_bias', False)))
prog.VarDesc(var_name, persistable=True, value_info=value_info) prog.VarDesc(var_name, persistable=True, value_info=value_info)
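For a parameter named conv1.weight, emit_param would write lines of this shape into the generated module (variable/attr names and shape are illustrative):

# parameter: conv1.weight
_conv1_weight_attr = ParamAttr(name='_conv1_weight')
_conv1_weight = layers.create_parameter(
    shape=[64, 3, 3, 3], dtype='float32', name='conv1.weight',
    attr=_conv1_weight_attr, default_initializer=initializer.Constant(0))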
@staticmethod @staticmethod
def emit_inputs(prog, names, value_infos, def emit_inputs(prog, names, value_infos, remove_batch=None):
remove_batch=None):
""" """
emit ONNX inputs into program emit ONNX inputs into program
""" """
...@@ -334,24 +351,30 @@ class Writer(object): ...@@ -334,24 +351,30 @@ class Writer(object):
value_info = value_infos[name] value_info = value_infos[name]
shape = value_info['shape'] shape = value_info['shape']
if remove_batch is None: if remove_batch is None:
remove_batch = value_info.get('remove_batch', True) # HINT: True by default ? remove_batch = value_info.get('remove_batch',
True) # HINT: True by default ?
if remove_batch: if remove_batch:
shape = shape[1:] shape = shape[1:]
prog.Code('# input: {}'.format(name)) prog.Code('# input: {}'.format(name))
prog.Code(('{} = layers.data(name={}, shape={}, dtype={}, '
           'append_batch_size={})' # , stop_gradient=True
           ).format(var_name, repr(name),
                    shape,
                    repr(value_info['dtype'].name),
                    remove_batch,
                    ))
prog.OpDesc('feed',
            (['feed'], 'X'),
            ([var_name], 'Out'),
            dict(col=idx),
            )
prog.VarDesc(var_name, value_info=value_info, remove_batch=remove_batch)
prog.Code((
    '{} = layers.data(name={}, shape={}, dtype={}, '
    'append_batch_size={})' # , stop_gradient=True
).format(
    var_name,
    repr(name),
    shape,
    repr(value_info['dtype'].name),
    remove_batch,
))
prog.OpDesc(
    'feed',
    (['feed'], 'X'),
    ([var_name], 'Out'),
    dict(col=idx),
)
prog.VarDesc(
    var_name, value_info=value_info, remove_batch=remove_batch)
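For an ONNX input x of shape [1, 3, 224, 224] with remove_batch=True, the loop above emits roughly the following into the generated module, together with a matching 'feed' OpDesc (col=0) and a VarDesc whose leading dim becomes -1:

# input: x
_x = layers.data(name='x', shape=[3, 224, 224], dtype='float32',
                 append_batch_size=True)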
@staticmethod @staticmethod
def emit_outputs(prog, names): #, value_infos def emit_outputs(prog, names): #, value_infos
...@@ -364,7 +387,8 @@ class Writer(object): ...@@ -364,7 +387,8 @@ class Writer(object):
var_name = make_var_name(name) var_name = make_var_name(name)
code += var_name + ', ' code += var_name + ', '
prog.OpDesc('fetch', prog.OpDesc(
'fetch',
([var_name], 'X'), ([var_name], 'X'),
(['fetch'], 'Out'), (['fetch'], 'Out'),
dict(col=idx), dict(col=idx),
......
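The output side mirrors the feed path with 'fetch' ops; for two outputs the generated descs would pair each fetched variable with its column index (names illustrative):

# fetch OpDesc per output: ([var], 'X') -> (['fetch'], 'Out'), col = position
prog.OpDesc('fetch', (['_ya'], 'X'), (['fetch'], 'Out'), dict(col=0))
prog.OpDesc('fetch', (['_yb'], 'X'), (['fetch'], 'Out'), dict(col=1))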
...@@ -2,14 +2,14 @@ ...@@ -2,14 +2,14 @@
# https://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files # https://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
[metadata] [metadata]
# 项目名称,发布、安装时以此作为包名 # 项目名称,发布、安装时以此作为包名
name = onnx2paddle name = onnx2fluid
# 作者姓名和邮箱地址 # 作者姓名和邮箱地址
author = Macrobull author = Macrobull
# author_email = .Github@github.com # author_email = .Github@github.com
# 项目版本号,1.0以上版本才视为正式版 # 项目版本号,1.0以上版本才视为正式版
version = 0.1.0 version = 0.1.0
# 项目概要描述信息,一句话让用户明白项目概要,不支持中文 # 项目概要描述信息,一句话让用户明白项目概要,不支持中文
description = Inference model conversion from ONNX/PyTorch to Paddle description = Inference model conversion from ONNX/PyTorch to Paddle fluid
# 项目的详细描述内容和格式,包括readme和changelog等,通常使用md或rst等格式 # 项目的详细描述内容和格式,包括readme和changelog等,通常使用md或rst等格式
long_description = file: README.md, CHANGELOG.md long_description = file: README.md, CHANGELOG.md
long_description_content_type = text/markdown long_description_content_type = text/markdown
...@@ -25,7 +25,7 @@ classifier = ...@@ -25,7 +25,7 @@ classifier =
Programming Language :: Python :: 3.5 Programming Language :: Python :: 3.5
# 关键字,用于检索,方便用户搜索到你的项目 # 关键字,用于检索,方便用户搜索到你的项目
keywords = keywords =
onnx paddle onnx paddlepaddle
[options] [options]
# 包名称,find:表示自动寻找,可在options.packages.find中进行详细配置 # 包名称,find:表示自动寻找,可在options.packages.find中进行详细配置
...@@ -44,21 +44,21 @@ install_requires = ...@@ -44,21 +44,21 @@ install_requires =
# mock # mock
# 单测代码目录 # 单测代码目录
#test_suite = onnx2paddle.tests #test_suite = onnx2fluid.tests
# 自动添加被版本控制的数据文件 # 自动添加被版本控制的数据文件
include_package_data = True include_package_data = True
# 项目是纯py项目,可以直接执行zip源码包 # 项目是纯py项目,可以直接执行zip源码包
zip_safe = False zip_safe = False
# 可以通过以下配置将指定的函数变成命令行工具,允许用户直接执行 # 可以通过以下配置将指定的函数变成命令行工具,允许用户直接执行
#[options.entry_points] [options.entry_points]
#console_scripts = console_scripts =
# onnx2paddle = onnx2paddle.cmdline:main onnx2fluid = onnx2fluid.cmdline:main
# 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下 # 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下
# 仅支持文件,不支持目录,但可以使用通配 # 仅支持文件,不支持目录,但可以使用通配
#[options.package_data] #[options.package_data]
#onnx2paddle = #onnx2fluid =
# conf/* # conf/*
# data/* # data/*
......
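With the entry point now enabled, an editable install exposes the converter as a console command (install flow shown as a typical example):

pip install -e .
onnx2fluid -h  # dispatches to onnx2fluid.cmdline:main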
...@@ -15,4 +15,3 @@ Date: 2019/02/22 10:25:46 ...@@ -15,4 +15,3 @@ Date: 2019/02/22 10:25:46
import setuptools import setuptools
setuptools.setup() setuptools.setup()
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: framework.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='framework.proto',
package='paddle.framework.proto',
syntax='proto2',
serialized_pb=_b('\n\x0f\x66ramework.proto\x12\x16paddle.framework.proto\"\x1d\n\x07Version\x12\x12\n\x07version\x18\x01 \x01(\x03:\x01\x30\"\xec\x03\n\x06OpDesc\x12\x0c\n\x04type\x18\x03 \x02(\t\x12\x32\n\x06inputs\x18\x01 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x33\n\x07outputs\x18\x02 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x32\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32#.paddle.framework.proto.OpDesc.Attr\x12\x18\n\tis_target\x18\x05 \x01(\x08:\x05\x66\x61lse\x1a\xef\x01\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\t\n\x01i\x18\x03 \x01(\x05\x12\t\n\x01\x66\x18\x04 \x01(\x02\x12\t\n\x01s\x18\x05 \x01(\t\x12\x0c\n\x04ints\x18\x06 \x03(\x05\x12\x0e\n\x06\x66loats\x18\x07 \x03(\x02\x12\x0f\n\x07strings\x18\x08 \x03(\t\x12\t\n\x01\x62\x18\n \x01(\x08\x12\r\n\x05\x62ools\x18\x0b \x03(\x08\x12\x11\n\tblock_idx\x18\x0c \x01(\x05\x12\t\n\x01l\x18\r \x01(\x03\x12\x12\n\nblocks_idx\x18\x0e \x03(\x05\x12\r\n\x05longs\x18\x0f \x03(\x03\x1a+\n\x03Var\x12\x11\n\tparameter\x18\x01 \x02(\t\x12\x11\n\targuments\x18\x02 \x03(\t\"\xb3\x03\n\x07OpProto\x12\x0c\n\x04type\x18\x01 \x02(\t\x12\x33\n\x06inputs\x18\x02 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x34\n\x07outputs\x18\x03 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x33\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32$.paddle.framework.proto.OpProto.Attr\x12\x0f\n\x07\x63omment\x18\x05 \x02(\t\x1ax\n\x03Var\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0f\n\x07\x63omment\x18\x02 \x02(\t\x12\x19\n\nduplicable\x18\x03 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0cintermediate\x18\x04 \x01(\x08:\x05\x66\x61lse\x12\x1a\n\x0b\x64ispensable\x18\x05 \x01(\x08:\x05\x66\x61lse\x1ao\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\x0f\n\x07\x63omment\x18\x03 \x02(\t\x12\x18\n\tgenerated\x18\x04 \x01(\x08:\x05\x66\x61lse\"\xda\x08\n\x07VarType\x12\x32\n\x04type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x41\n\rselected_rows\x18\x02 \x01(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x41\n\nlod_tensor\x18\x03 \x01(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x12H\n\x0ctensor_array\x18\x04 \x01(\x0b\x32\x32.paddle.framework.proto.VarType.LoDTensorArrayDesc\x12:\n\x06reader\x18\x05 \x01(\x0b\x32*.paddle.framework.proto.VarType.ReaderDesc\x12\x34\n\x05tuple\x18\x07 \x01(\x0b\x32%.paddle.framework.proto.VarType.Tuple\x1aS\n\nTensorDesc\x12\x37\n\tdata_type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x0c\n\x04\x64ims\x18\x02 \x03(\x03\x1a\x61\n\rLoDTensorDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1a\x66\n\x12LoDTensorArrayDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1aO\n\nReaderDesc\x12\x41\n\nlod_tensor\x18\x01 \x03(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x1a\x43\n\x05Tuple\x12:\n\x0c\x65lement_type\x18\x01 \x03(\x0e\x32$.paddle.framework.proto.VarType.Type\"\xa2\x02\n\x04Type\x12\x08\n\x04\x42OOL\x10\x00\x12\t\n\x05INT16\x10\x01\x12\t\n\x05INT32\x10\x02\x12\t\n\x05INT64\x10\x03\x12\x08\n\x04\x46P16\x10\x04\x12\x08\n\x04\x46P32\x10\x05\x12\x08\n\x04\x46P64\x10\x06\x12\n\n\x06SIZE_T\x10\x13\x12\t\n\x05UINT8\x10\x14\x12\x08\n\x04INT8\x10\x15\x12\x0e\n\nLOD_TENSOR\x10\x07\x12\x11\n\rSELECTED_ROWS\x10\x08\x12\x12\n\x0e\x46\x45\x45\x44_MINIBATCH\x10\t\x12\x0e\n\nFETCH_LIST\x10\n\x12\x0f\n\x0bSTEP_SCOPES\x10\x0b\x12\x12\n\x0eLOD_RANK_TABLE\x10\x0c\x12\x14\n\x10LOD_TENSOR_ARRAY\x10\r\x12\x0e\n\nPLACE_LIST\x10\x0e\x12\n\n\x06READER\x10\x0f\x12\x07\n\x03RAW\x10\x11\x12\t\n\x05TUPLE\x10\x12\"b\n\x07VarDesc\x12\x0c\n\x04name\x18\x01 \x02(\t\x12-\n\x04type\x18\x02 \x02(\x0b\x32\x1f.paddle.framework.proto.VarType\x12\x1a\n\x0bpersistable\x18\x03 \x01(\x08:\x05\x66\x61lse\"\xa7\x01\n\tBlockDesc\x12\x0b\n\x03idx\x18\x01 \x02(\x05\x12\x12\n\nparent_idx\x18\x02 \x02(\x05\x12-\n\x04vars\x18\x03 \x03(\x0b\x32\x1f.paddle.framework.proto.VarDesc\x12+\n\x03ops\x18\x04 \x03(\x0b\x32\x1e.paddle.framework.proto.OpDesc\x12\x1d\n\x11\x66orward_block_idx\x18\x05 \x01(\x05:\x02-1\"r\n\x0bProgramDesc\x12\x31\n\x06\x62locks\x18\x01 \x03(\x0b\x32!.paddle.framework.proto.BlockDesc\x12\x30\n\x07version\x18\x02 \x01(\x0b\x32\x1f.paddle.framework.proto.Version*\x94\x01\n\x08\x41ttrType\x12\x07\n\x03INT\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x08\n\x04INTS\x10\x03\x12\n\n\x06\x46LOATS\x10\x04\x12\x0b\n\x07STRINGS\x10\x05\x12\x0b\n\x07\x42OOLEAN\x10\x06\x12\x0c\n\x08\x42OOLEANS\x10\x07\x12\t\n\x05\x42LOCK\x10\x08\x12\x08\n\x04LONG\x10\t\x12\n\n\x06\x42LOCKS\x10\n\x12\t\n\x05LONGS\x10\x0b\x42\x02H\x03')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_ATTRTYPE = _descriptor.EnumDescriptor(
name='AttrType',
full_name='paddle.framework.proto.AttrType',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='INT', index=0, number=0,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FLOAT', index=1, number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='STRING', index=2, number=2,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INTS', index=3, number=3,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FLOATS', index=4, number=4,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='STRINGS', index=5, number=5,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='BOOLEAN', index=6, number=6,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='BOOLEANS', index=7, number=7,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='BLOCK', index=8, number=8,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LONG', index=9, number=9,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='BLOCKS', index=10, number=10,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LONGS', index=11, number=11,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=2511,
serialized_end=2659,
)
_sym_db.RegisterEnumDescriptor(_ATTRTYPE)
AttrType = enum_type_wrapper.EnumTypeWrapper(_ATTRTYPE)
INT = 0
FLOAT = 1
STRING = 2
INTS = 3
FLOATS = 4
STRINGS = 5
BOOLEAN = 6
BOOLEANS = 7
BLOCK = 8
LONG = 9
BLOCKS = 10
LONGS = 11
_VARTYPE_TYPE = _descriptor.EnumDescriptor(
name='Type',
full_name='paddle.framework.proto.VarType.Type',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='BOOL', index=0, number=0,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT16', index=1, number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT32', index=2, number=2,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT64', index=3, number=3,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP16', index=4, number=4,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP32', index=5, number=5,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP64', index=6, number=6,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='SIZE_T', index=7, number=19,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='UINT8', index=8, number=20,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT8', index=9, number=21,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LOD_TENSOR', index=10, number=7,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='SELECTED_ROWS', index=11, number=8,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FEED_MINIBATCH', index=12, number=9,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FETCH_LIST', index=13, number=10,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='STEP_SCOPES', index=14, number=11,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LOD_RANK_TABLE', index=15, number=12,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LOD_TENSOR_ARRAY', index=16, number=13,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PLACE_LIST', index=17, number=14,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='READER', index=18, number=15,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='RAW', index=19, number=17,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='TUPLE', index=20, number=18,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=1832,
serialized_end=2122,
)
_sym_db.RegisterEnumDescriptor(_VARTYPE_TYPE)
_VERSION = _descriptor.Descriptor(
name='Version',
full_name='paddle.framework.proto.Version',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='version', full_name='paddle.framework.proto.Version.version', index=0,
number=1, type=3, cpp_type=2, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=43,
serialized_end=72,
)
_OPDESC_ATTR = _descriptor.Descriptor(
name='Attr',
full_name='paddle.framework.proto.OpDesc.Attr',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='paddle.framework.proto.OpDesc.Attr.name', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.OpDesc.Attr.type', index=1,
number=2, type=14, cpp_type=8, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='i', full_name='paddle.framework.proto.OpDesc.Attr.i', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='f', full_name='paddle.framework.proto.OpDesc.Attr.f', index=3,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='s', full_name='paddle.framework.proto.OpDesc.Attr.s', index=4,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ints', full_name='paddle.framework.proto.OpDesc.Attr.ints', index=5,
number=6, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='floats', full_name='paddle.framework.proto.OpDesc.Attr.floats', index=6,
number=7, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='strings', full_name='paddle.framework.proto.OpDesc.Attr.strings', index=7,
number=8, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='b', full_name='paddle.framework.proto.OpDesc.Attr.b', index=8,
number=10, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='bools', full_name='paddle.framework.proto.OpDesc.Attr.bools', index=9,
number=11, type=8, cpp_type=7, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='block_idx', full_name='paddle.framework.proto.OpDesc.Attr.block_idx', index=10,
number=12, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='l', full_name='paddle.framework.proto.OpDesc.Attr.l', index=11,
number=13, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='blocks_idx', full_name='paddle.framework.proto.OpDesc.Attr.blocks_idx', index=12,
number=14, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='longs', full_name='paddle.framework.proto.OpDesc.Attr.longs', index=13,
number=15, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=283,
serialized_end=522,
)
_OPDESC_VAR = _descriptor.Descriptor(
name='Var',
full_name='paddle.framework.proto.OpDesc.Var',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='parameter', full_name='paddle.framework.proto.OpDesc.Var.parameter', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='arguments', full_name='paddle.framework.proto.OpDesc.Var.arguments', index=1,
number=2, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=524,
serialized_end=567,
)
_OPDESC = _descriptor.Descriptor(
name='OpDesc',
full_name='paddle.framework.proto.OpDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.OpDesc.type', index=0,
number=3, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='inputs', full_name='paddle.framework.proto.OpDesc.inputs', index=1,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='outputs', full_name='paddle.framework.proto.OpDesc.outputs', index=2,
number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='attrs', full_name='paddle.framework.proto.OpDesc.attrs', index=3,
number=4, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='is_target', full_name='paddle.framework.proto.OpDesc.is_target', index=4,
number=5, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[_OPDESC_ATTR, _OPDESC_VAR, ],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=75,
serialized_end=567,
)
_OPPROTO_VAR = _descriptor.Descriptor(
name='Var',
full_name='paddle.framework.proto.OpProto.Var',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='paddle.framework.proto.OpProto.Var.name', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment', full_name='paddle.framework.proto.OpProto.Var.comment', index=1,
number=2, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='duplicable', full_name='paddle.framework.proto.OpProto.Var.duplicable', index=2,
number=3, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='intermediate', full_name='paddle.framework.proto.OpProto.Var.intermediate', index=3,
number=4, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dispensable', full_name='paddle.framework.proto.OpProto.Var.dispensable', index=4,
number=5, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=772,
serialized_end=892,
)
_OPPROTO_ATTR = _descriptor.Descriptor(
name='Attr',
full_name='paddle.framework.proto.OpProto.Attr',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='paddle.framework.proto.OpProto.Attr.name', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.OpProto.Attr.type', index=1,
number=2, type=14, cpp_type=8, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment', full_name='paddle.framework.proto.OpProto.Attr.comment', index=2,
number=3, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='generated', full_name='paddle.framework.proto.OpProto.Attr.generated', index=3,
number=4, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=894,
serialized_end=1005,
)
_OPPROTO = _descriptor.Descriptor(
name='OpProto',
full_name='paddle.framework.proto.OpProto',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.OpProto.type', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='inputs', full_name='paddle.framework.proto.OpProto.inputs', index=1,
number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='outputs', full_name='paddle.framework.proto.OpProto.outputs', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='attrs', full_name='paddle.framework.proto.OpProto.attrs', index=3,
number=4, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment', full_name='paddle.framework.proto.OpProto.comment', index=4,
number=5, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[_OPPROTO_VAR, _OPPROTO_ATTR, ],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=570,
serialized_end=1005,
)
_VARTYPE_TENSORDESC = _descriptor.Descriptor(
name='TensorDesc',
full_name='paddle.framework.proto.VarType.TensorDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='data_type', full_name='paddle.framework.proto.VarType.TensorDesc.data_type', index=0,
number=1, type=14, cpp_type=8, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dims', full_name='paddle.framework.proto.VarType.TensorDesc.dims', index=1,
number=2, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1393,
serialized_end=1476,
)
_VARTYPE_LODTENSORDESC = _descriptor.Descriptor(
name='LoDTensorDesc',
full_name='paddle.framework.proto.VarType.LoDTensorDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor', full_name='paddle.framework.proto.VarType.LoDTensorDesc.tensor', index=0,
number=1, type=11, cpp_type=10, label=2,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_level', full_name='paddle.framework.proto.VarType.LoDTensorDesc.lod_level', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1478,
serialized_end=1575,
)
_VARTYPE_LODTENSORARRAYDESC = _descriptor.Descriptor(
name='LoDTensorArrayDesc',
full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor', full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc.tensor', index=0,
number=1, type=11, cpp_type=10, label=2,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_level', full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc.lod_level', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1577,
serialized_end=1679,
)
_VARTYPE_READERDESC = _descriptor.Descriptor(
name='ReaderDesc',
full_name='paddle.framework.proto.VarType.ReaderDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='lod_tensor', full_name='paddle.framework.proto.VarType.ReaderDesc.lod_tensor', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1681,
serialized_end=1760,
)
_VARTYPE_TUPLE = _descriptor.Descriptor(
name='Tuple',
full_name='paddle.framework.proto.VarType.Tuple',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='element_type', full_name='paddle.framework.proto.VarType.Tuple.element_type', index=0,
number=1, type=14, cpp_type=8, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1762,
serialized_end=1829,
)
_VARTYPE = _descriptor.Descriptor(
name='VarType',
full_name='paddle.framework.proto.VarType',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.VarType.type', index=0,
number=1, type=14, cpp_type=8, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='selected_rows', full_name='paddle.framework.proto.VarType.selected_rows', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_tensor', full_name='paddle.framework.proto.VarType.lod_tensor', index=2,
number=3, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='tensor_array', full_name='paddle.framework.proto.VarType.tensor_array', index=3,
number=4, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='reader', full_name='paddle.framework.proto.VarType.reader', index=4,
number=5, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='tuple', full_name='paddle.framework.proto.VarType.tuple', index=5,
number=7, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[_VARTYPE_TENSORDESC, _VARTYPE_LODTENSORDESC, _VARTYPE_LODTENSORARRAYDESC, _VARTYPE_READERDESC, _VARTYPE_TUPLE, ],
enum_types=[
_VARTYPE_TYPE,
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1008,
serialized_end=2122,
)
_VARDESC = _descriptor.Descriptor(
name='VarDesc',
full_name='paddle.framework.proto.VarDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='paddle.framework.proto.VarDesc.name', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.VarDesc.type', index=1,
number=2, type=11, cpp_type=10, label=2,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='persistable', full_name='paddle.framework.proto.VarDesc.persistable', index=2,
number=3, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=2124,
serialized_end=2222,
)
_BLOCKDESC = _descriptor.Descriptor(
name='BlockDesc',
full_name='paddle.framework.proto.BlockDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='idx', full_name='paddle.framework.proto.BlockDesc.idx', index=0,
number=1, type=5, cpp_type=1, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='parent_idx', full_name='paddle.framework.proto.BlockDesc.parent_idx', index=1,
number=2, type=5, cpp_type=1, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='vars', full_name='paddle.framework.proto.BlockDesc.vars', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ops', full_name='paddle.framework.proto.BlockDesc.ops', index=3,
number=4, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='forward_block_idx', full_name='paddle.framework.proto.BlockDesc.forward_block_idx', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=-1,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=2225,
serialized_end=2392,
)
_PROGRAMDESC = _descriptor.Descriptor(
name='ProgramDesc',
full_name='paddle.framework.proto.ProgramDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='blocks', full_name='paddle.framework.proto.ProgramDesc.blocks', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='version', full_name='paddle.framework.proto.ProgramDesc.version', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=2394,
serialized_end=2508,
)
_OPDESC_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE
_OPDESC_ATTR.containing_type = _OPDESC
_OPDESC_VAR.containing_type = _OPDESC
_OPDESC.fields_by_name['inputs'].message_type = _OPDESC_VAR
_OPDESC.fields_by_name['outputs'].message_type = _OPDESC_VAR
_OPDESC.fields_by_name['attrs'].message_type = _OPDESC_ATTR
_OPPROTO_VAR.containing_type = _OPPROTO
_OPPROTO_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE
_OPPROTO_ATTR.containing_type = _OPPROTO
_OPPROTO.fields_by_name['inputs'].message_type = _OPPROTO_VAR
_OPPROTO.fields_by_name['outputs'].message_type = _OPPROTO_VAR
_OPPROTO.fields_by_name['attrs'].message_type = _OPPROTO_ATTR
_VARTYPE_TENSORDESC.fields_by_name['data_type'].enum_type = _VARTYPE_TYPE
_VARTYPE_TENSORDESC.containing_type = _VARTYPE
_VARTYPE_LODTENSORDESC.fields_by_name['tensor'].message_type = _VARTYPE_TENSORDESC
_VARTYPE_LODTENSORDESC.containing_type = _VARTYPE
_VARTYPE_LODTENSORARRAYDESC.fields_by_name['tensor'].message_type = _VARTYPE_TENSORDESC
_VARTYPE_LODTENSORARRAYDESC.containing_type = _VARTYPE
_VARTYPE_READERDESC.fields_by_name['lod_tensor'].message_type = _VARTYPE_LODTENSORDESC
_VARTYPE_READERDESC.containing_type = _VARTYPE
_VARTYPE_TUPLE.fields_by_name['element_type'].enum_type = _VARTYPE_TYPE
_VARTYPE_TUPLE.containing_type = _VARTYPE
_VARTYPE.fields_by_name['type'].enum_type = _VARTYPE_TYPE
_VARTYPE.fields_by_name['selected_rows'].message_type = _VARTYPE_TENSORDESC
_VARTYPE.fields_by_name['lod_tensor'].message_type = _VARTYPE_LODTENSORDESC
_VARTYPE.fields_by_name['tensor_array'].message_type = _VARTYPE_LODTENSORARRAYDESC
_VARTYPE.fields_by_name['reader'].message_type = _VARTYPE_READERDESC
_VARTYPE.fields_by_name['tuple'].message_type = _VARTYPE_TUPLE
_VARTYPE_TYPE.containing_type = _VARTYPE
_VARDESC.fields_by_name['type'].message_type = _VARTYPE
_BLOCKDESC.fields_by_name['vars'].message_type = _VARDESC
_BLOCKDESC.fields_by_name['ops'].message_type = _OPDESC
_PROGRAMDESC.fields_by_name['blocks'].message_type = _BLOCKDESC
_PROGRAMDESC.fields_by_name['version'].message_type = _VERSION
DESCRIPTOR.message_types_by_name['Version'] = _VERSION
DESCRIPTOR.message_types_by_name['OpDesc'] = _OPDESC
DESCRIPTOR.message_types_by_name['OpProto'] = _OPPROTO
DESCRIPTOR.message_types_by_name['VarType'] = _VARTYPE
DESCRIPTOR.message_types_by_name['VarDesc'] = _VARDESC
DESCRIPTOR.message_types_by_name['BlockDesc'] = _BLOCKDESC
DESCRIPTOR.message_types_by_name['ProgramDesc'] = _PROGRAMDESC
DESCRIPTOR.enum_types_by_name['AttrType'] = _ATTRTYPE
Version = _reflection.GeneratedProtocolMessageType('Version', (_message.Message,), dict(
DESCRIPTOR = _VERSION,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.Version)
))
_sym_db.RegisterMessage(Version)
OpDesc = _reflection.GeneratedProtocolMessageType('OpDesc', (_message.Message,), dict(
Attr = _reflection.GeneratedProtocolMessageType('Attr', (_message.Message,), dict(
DESCRIPTOR = _OPDESC_ATTR,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Attr)
))
,
Var = _reflection.GeneratedProtocolMessageType('Var', (_message.Message,), dict(
DESCRIPTOR = _OPDESC_VAR,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Var)
))
,
DESCRIPTOR = _OPDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc)
))
_sym_db.RegisterMessage(OpDesc)
_sym_db.RegisterMessage(OpDesc.Attr)
_sym_db.RegisterMessage(OpDesc.Var)
OpProto = _reflection.GeneratedProtocolMessageType('OpProto', (_message.Message,), dict(
Var = _reflection.GeneratedProtocolMessageType('Var', (_message.Message,), dict(
DESCRIPTOR = _OPPROTO_VAR,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Var)
))
,
Attr = _reflection.GeneratedProtocolMessageType('Attr', (_message.Message,), dict(
DESCRIPTOR = _OPPROTO_ATTR,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Attr)
))
,
DESCRIPTOR = _OPPROTO,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto)
))
_sym_db.RegisterMessage(OpProto)
_sym_db.RegisterMessage(OpProto.Var)
_sym_db.RegisterMessage(OpProto.Attr)
VarType = _reflection.GeneratedProtocolMessageType('VarType', (_message.Message,), dict(
TensorDesc = _reflection.GeneratedProtocolMessageType('TensorDesc', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_TENSORDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.TensorDesc)
))
,
LoDTensorDesc = _reflection.GeneratedProtocolMessageType('LoDTensorDesc', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_LODTENSORDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorDesc)
))
,
LoDTensorArrayDesc = _reflection.GeneratedProtocolMessageType('LoDTensorArrayDesc', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_LODTENSORARRAYDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorArrayDesc)
))
,
ReaderDesc = _reflection.GeneratedProtocolMessageType('ReaderDesc', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_READERDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.ReaderDesc)
))
,
Tuple = _reflection.GeneratedProtocolMessageType('Tuple', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_TUPLE,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.Tuple)
))
,
DESCRIPTOR = _VARTYPE,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType)
))
_sym_db.RegisterMessage(VarType)
_sym_db.RegisterMessage(VarType.TensorDesc)
_sym_db.RegisterMessage(VarType.LoDTensorDesc)
_sym_db.RegisterMessage(VarType.LoDTensorArrayDesc)
_sym_db.RegisterMessage(VarType.ReaderDesc)
_sym_db.RegisterMessage(VarType.Tuple)
VarDesc = _reflection.GeneratedProtocolMessageType('VarDesc', (_message.Message,), dict(
DESCRIPTOR = _VARDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc)
))
_sym_db.RegisterMessage(VarDesc)
BlockDesc = _reflection.GeneratedProtocolMessageType('BlockDesc', (_message.Message,), dict(
DESCRIPTOR = _BLOCKDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.BlockDesc)
))
_sym_db.RegisterMessage(BlockDesc)
ProgramDesc = _reflection.GeneratedProtocolMessageType('ProgramDesc', (_message.Message,), dict(
DESCRIPTOR = _PROGRAMDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.ProgramDesc)
))
_sym_db.RegisterMessage(ProgramDesc)
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('H\003'))
# @@protoc_insertion_point(module_scope)
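This generated module is the fallback imported when paddle.fluid.proto.framework_pb2 is unavailable; a short sketch of inspecting a serialized fluid program with it (path hypothetical):

# parse a fluid __model__ file with the fallback proto bindings
from framework_pb2 import ProgramDesc

prog_desc = ProgramDesc()
with open('path/to/__model__', 'rb') as f:
    prog_desc.ParseFromString(f.read())
for block in prog_desc.blocks:
    print(block.idx, [op.type for op in block.ops])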