Commit f0dede1f authored by: M Macrobull

rename onnx2paddle to onnx2fluid

Parent b483d12e
......@@ -57,3 +57,4 @@ coverage.xml
/examples/*.aria2
/examples/*.onnx
/examples/*.np?
**/.*
Onnx2paddle
Onnx2Fluid
===
Inference model conversion from ONNX/PyTorch to Paddle
Inference model conversion from ONNX/PyTorch to Paddle fluid
Quick start
---
......
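For orientation, a minimal conversion sketch in Python, assuming an importable onnx2fluid package and a placeholder model path; the keyword arguments mirror the conversion.convert signature that appears later in this diff:

```python
# Minimal sketch, assuming onnx2fluid is importable; the kwargs mirror
# conversion.convert() as defined in onnx2fluid/conversion.py below.
from onnx2fluid import conversion

conversion.convert(
    'model.onnx',               # placeholder path to an ONNX model
    '/tmp/export/',             # output dir for model.py and __model__
    model_basename='model.py',
    model_func_name='inference',
    embed_params=True,          # embed weights into fluid layers
    onnx_opset_pedantic=True)   # accept only the standard ONNX opset
```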
......@@ -22,4 +22,4 @@ output_data = data['outputs']
inputs = Dict(zip(input_names, [input_data]))
outputs = Dict(zip(output_name, [output_data]))
np.savez(fn, inputs=inputs, outputs=outputs) # overwrite
np.savez(fn, inputs=inputs, outputs=outputs) # overwrite
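Reading the golden data back follows from the layout above; a minimal sketch (the I/O dicts are stored as 0-d object arrays, so .item() unwraps them, and allow_pickle=True is needed on recent NumPy):

```python
import numpy as np

# Sketch: load the golden data written by np.savez(fn, ...) above.
data = np.load(fn, allow_pickle=True)
inputs = data['inputs'].item()    # {input_name: ndarray}
outputs = data['outputs'].item()  # {output_name: ndarray}
```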
......@@ -6,7 +6,7 @@ Created on Fri Mar 22 11:19:45 2019
@author: Macrobull
Not all ops in this file are supported by both PyTorch and ONNX
This only demonstrates the conversion/validation workflow from PyTorch to ONNX to Paddle
This only demonstrates the conversion/validation workflow from PyTorch to ONNX to Paddle fluid
"""
......@@ -16,12 +16,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from onnx2paddle.torch_export_helper import export_onnx_with_validation
from onnx2fluid.torch_export_helper import export_onnx_with_validation
idx = 0
######### example: RNN ########
#
#class Model(nn.Module):
......@@ -44,7 +42,6 @@ idx = 0
# ['x'], ['y'],
# verbose=True, training=False)
######### example: random ########
#
#class Model(nn.Module):
......@@ -66,9 +63,9 @@ idx = 0
# ['x'], ['y'],
# verbose=True, training=False)
######## example: fc ########
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
......@@ -85,13 +82,12 @@ xb = torch.rand((2, 3))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, (xb, ), 't' + str(idx),
['x'], ['y'],
verbose=True, training=False)
export_onnx_with_validation(
model, (xb, ), 't' + str(idx), ['x'], ['y'], verbose=True, training=False)
######## example: compare ########
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
......@@ -110,12 +106,15 @@ xb1 = torch.rand((2, 3))
ya, yb, yc = model(xb0, xb1)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, (xb0, xb1), 't' + str(idx),
['x0', 'x1'], ['ya', 'yb', 'yc'],
verbose=True, training=False)
export_onnx_with_validation(
model, (xb0, xb1),
't' + str(idx), ['x0', 'x1'], ['ya', 'yb', 'yc'],
verbose=True,
training=False)
######## example: affine_grid ########
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
......@@ -130,13 +129,15 @@ theta = torch.rand((2, 2, 3))
grid = model(theta)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, (theta, ), 't' + str(idx),
['theta'], ['grid'],
verbose=True, training=False)
export_onnx_with_validation(
model, (theta, ),
't' + str(idx), ['theta'], ['grid'],
verbose=True,
training=False)
######## example: conv2d_transpose ########
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
......@@ -155,12 +156,12 @@ xb = torch.rand((2, 3, 4, 5))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, (xb, ), 't' + str(idx),
['x'], ['y'],
verbose=True, training=False)
export_onnx_with_validation(
model, (xb, ), 't' + str(idx), ['x'], ['y'], verbose=True, training=False)
######## example: conv2d ########
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
......@@ -181,10 +182,8 @@ xb = torch.rand((2, 3, 4, 5))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, (xb, ), 't' + str(idx),
['x'], ['y'],
verbose=True, training=False)
export_onnx_with_validation(
model, (xb, ), 't' + str(idx), ['x'], ['y'], verbose=True, training=False)
######### example: conv1d ########
#
......@@ -210,6 +209,7 @@ export_onnx_with_validation(model, (xb, ), 't' + str(idx),
######## example: empty ########
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
......@@ -223,6 +223,5 @@ xb = torch.rand((2, 3))
yp = model(xb)
idx += 1
print('index: ', idx)
export_onnx_with_validation(model, (xb, ), 't' + str(idx),
['y'], ['y'],
verbose=True, training=False)
export_onnx_with_validation(
model, (xb, ), 't' + str(idx), ['y'], ['y'], verbose=True, training=False)
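torch_export_helper.export_onnx_with_validation itself is not part of this diff; a rough sketch of what such a helper plausibly does (ONNX export plus golden I/O capture in the npz layout used for validation), with all internals assumed:

```python
import numpy as np
import torch

def export_onnx_with_validation(model, inputs, export_basepath,
                                input_names, output_names,
                                verbose=False, training=False):
    # Assumed sketch, not the real helper: run the model to capture
    # golden outputs, export to ONNX, and save I/O pairs for validation.
    outputs = model(*inputs)
    if not isinstance(outputs, tuple):
        outputs = (outputs, )
    torch.onnx.export(
        model, inputs, export_basepath + '.onnx',
        input_names=input_names, output_names=output_names,
        verbose=verbose, training=training)
    np.savez(
        export_basepath + '.npz',
        inputs=dict(zip(input_names, (t.numpy() for t in inputs))),
        outputs=dict(zip(output_names,
                         (t.detach().numpy() for t in outputs))))
```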
#! /usr/bin/env sh
get_url="proxychains4 aria2c -c -s8 -x8"
get_url="aria2c -c -s8 -x8"
base_url="https://s3.amazonaws.com/download.onnx/models/opset_9/"
flags="-de -o /tmp/export/"
flags="-e -o /tmp/export/"
bvlc_alexnet()
{
......@@ -18,13 +18,13 @@ bvlc_alexnet()
do
echo "converting $npz ..."
python convert_data_npz_0.py "$npz" "data_0" "prob_1"
python -m onnx2paddle $flags "$fn_model" -t $npz
python -m onnx2fluid $flags "$fn_model" -t $npz
done
for pb_dir in $bn_tar/*/
do
echo "converting $pb_dir ..."
python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
python -m onnx2paddle $flags "$fn_model" -t echo $(dirname "$pb_dir/x").npz
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
done
}
......@@ -42,7 +42,7 @@ bvlc_googlenet()
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
done
}
......@@ -60,7 +60,7 @@ bvlc_reference_caffenet()
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
done
}
......@@ -69,7 +69,7 @@ bvlc_reference_rcnn_ilsvrc13()
bn_tar="bvlc_reference_rcnn_ilsvrc13"
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -77,8 +77,8 @@ bvlc_reference_rcnn_ilsvrc13()
for pb_dir in $bn_tar/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "softmaxout_1"
python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python convert_data_pb_0.py "$pb_dir" "data_0" "fc_rcnn_1"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
done
}
......@@ -87,7 +87,7 @@ inception_v1()
bn_tar="inception_v1"
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -96,14 +96,14 @@ inception_v1()
do
echo "converting $npz ..."
python convert_data_npz_0.py "$npz" "data_0" "prob_1"
python -m onnx2paddle $flags "$fn_model" -t $npz
python -m onnx2fluid $flags "$fn_model" -t $npz
done
for pb_dir in $bn_tar/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
done
}
......@@ -112,7 +112,7 @@ inception_v2()
bn_tar="inception_v2"
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -121,14 +121,14 @@ inception_v2()
do
echo "converting $npz ..."
python convert_data_npz_0.py "$npz" "data_0" "prob_1"
python -m onnx2paddle $flags "$fn_model" -t $npz
python -m onnx2fluid $flags "$fn_model" -t $npz
done
for pb_dir in $bn_tar/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
done
}
......@@ -137,7 +137,7 @@ resnet50()
bn_tar="resnet50"
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -146,14 +146,14 @@ resnet50()
do
echo "converting $npz ..."
python convert_data_npz_0.py "$npz" "gpu_0/data_0" "gpu_0/softmaxout_1"
python -m onnx2paddle $flags "$fn_model" -t $npz
python -m onnx2fluid $flags "$fn_model" -t $npz
done
for pb_dir in $bn_tar/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "gpu_0/data_0" "gpu_0/softmaxout_1"
python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
done
}
......@@ -162,7 +162,7 @@ shufflenet()
bn_tar="shufflenet"
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -171,7 +171,7 @@ shufflenet()
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "gpu_0/data_0" "gpu_0/softmaxout_1"
python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
done
}
......@@ -180,7 +180,7 @@ squeezenet()
bn_tar="squeezenet"
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -189,7 +189,7 @@ squeezenet()
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "softmaxout_1"
python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
done
}
......@@ -198,7 +198,7 @@ tiny_yolov2()
bn_tar="tiny_yolov2"
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "https://onnxzoo.blob.core.windows.net/models/opset_8/tiny_yolov2/$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -207,7 +207,7 @@ tiny_yolov2()
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "image" "grid"
python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz -x
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz -x
done
}
......@@ -216,7 +216,7 @@ vgg19()
bn_tar="vgg19"
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -225,7 +225,7 @@ vgg19()
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
done
}
......@@ -234,7 +234,7 @@ zfnet512()
bn_tar="zfnet512"
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -243,20 +243,20 @@ zfnet512()
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "gpu_0/data_0" "gpu_0/softmax_1"
python -m onnx2paddle $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
done
}
bvlc_alexnet # data error
bvlc_googlenet # desc error
bvlc_alexnet
bvlc_googlenet
bvlc_reference_caffenet
bvlc_reference_rcnn_ilsvrc13
inception_v1 ###
inception_v2 ###
resnet50 # data error
shufflenet ###
inception_v1
inception_v2
resnet50
shufflenet
squeezenet
tiny_yolov2 # not supported
vgg19
zfnet512 # data error
zfnet512
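convert_data_pb_0.py is not shown in this diff; a rough sketch of the model-zoo conversion it presumably performs, turning a test_data_set_*/ directory of protobuf tensors into the npz golden-data layout above (the shell idiom $(dirname "$pb_dir/x").npz just strips the trailing slash and appends .npz):

```python
import glob
import sys

import numpy as np
import onnx
from onnx import numpy_helper

# Assumed sketch of convert_data_pb_0.py: pack one test_data_set_*/
# directory of protobuf tensors into the npz golden-data layout.
pb_dir, input_name, output_name = sys.argv[1:4]

def load_pb(filename):
    tensor = onnx.TensorProto()
    with open(filename, 'rb') as f:
        tensor.ParseFromString(f.read())
    return numpy_helper.to_array(tensor)

inputs = {input_name: load_pb(sorted(glob.glob(pb_dir + '/input_*.pb'))[0])}
outputs = {output_name: load_pb(sorted(glob.glob(pb_dir + '/output_*.pb'))[0])}
# same target as $(dirname "$pb_dir/x").npz in the shell script
np.savez(pb_dir.rstrip('/') + '.npz', inputs=inputs, outputs=outputs)
```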
......@@ -5,7 +5,7 @@
#
################################################################################
"""
This file allows the module package to be executed directly via python -m onnx2paddle.
This file allows the module package to be executed directly via python -m onnx2fluid.
Authors: Macrobull
Date: 2019/02/22 10:25:46
......@@ -21,43 +21,67 @@ import argparse
import logging
import sys
parser = argparse.ArgumentParser(description='onnx2paddle',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('model', nargs=1,
help='path to model.onnx',
)
parser.add_argument('--debug', '-d', action='store_true',
help='enable debug logging and checking',
)
parser.add_argument('--output-dir', '-o', type=str, default='',
help='output directory',
)
parser.add_argument('--test_data', '-t', type=str, default='',
help='I/O golden data for validation, e.g. test.npy, test.npz',
)
parser.add_argument('--embed_params', '-e', action='store_true',
help='try to embed parameters for trainable Paddle layers',
)
parser.add_argument('--pedantic', action='store_true', default=True,
help='accept and convert only standard ONNX opset',
)
parser.add_argument('--no-pedantic', '-x', action='store_false',
dest='pedantic',
help='process non-standard ONNX ops, this may lead to failures',
)
parser.add_argument('--precision', '-p', type=int, default=4,
help='assertion decimal for validation',
)
parser = argparse.ArgumentParser(
description='onnx2fluid',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
'model',
nargs=1,
help='path to model.onnx',
)
parser.add_argument(
'--debug',
'-d',
action='store_true',
help='enable debug logging and checking',
)
parser.add_argument(
'--output_dir',
'-o',
type=str,
default='',
help='output directory',
)
parser.add_argument(
'--test_data',
'-t',
type=str,
default='',
help='I/O golden data for validation, e.g. test.npy, test.npz',
)
parser.add_argument(
'--embed_params',
'-e',
action='store_true',
help='try to embed parameters for trainable Paddle fluid layers',
)
parser.add_argument(
'--pedantic',
action='store_true',
default=True,
help='accept and convert only standard ONNX opset',
)
parser.add_argument(
'--no-pedantic',
'-x',
action='store_false',
dest='pedantic',
help='process non-standard ONNX ops, this may lead to failures',
)
parser.add_argument(
'--precision',
'-p',
type=int,
default=4,
help='assertion decimal for validation',
)
args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(format=logging_format, level=logging_level)
try:
from . import cmdline
except ImportError:
......@@ -66,5 +90,4 @@ except ImportError:
# imports
main = cmdline.main
sys.exit(main(**args.__dict__))
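Besides the python -m onnx2fluid entry point, cmdline.main can also be driven programmatically; a sketch mirroring the kwargs it consumes below (paths are placeholders, and model is a one-element list because of nargs=1):

```python
# Sketch: drive the converter without argparse; kwargs mirror the CLI flags.
from onnx2fluid import cmdline

cmdline.main(
    model=['model.onnx'],             # placeholder path, one-element list
    output_dir='/tmp/export/',
    embed_params=True,
    pedantic=False,
    test_data='test_data_set_0.npz',  # optional golden data for validation
    debug=False)
```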
......@@ -21,7 +21,6 @@ import logging
import shutil
import zipfile
__all__ = [
'main',
]
......@@ -42,7 +41,7 @@ def main(**kwargs):
# imports
convert = conversion.convert
logger = logging.getLogger('onnx2paddle')
logger = logging.getLogger('onnx2fluid')
debug = kwargs.get('debug', False)
# prepare arguments
......@@ -58,13 +57,15 @@ def main(**kwargs):
onnx_opset_pedantic = kwargs.get('pedantic', True)
# convert
convert(filename, save_dir,
model_basename=model_basename,
model_func_name=model_func_name,
embed_params=embed_params,
onnx_opset_version=onnx_opset_version,
onnx_opset_pedantic=onnx_opset_pedantic,
debug=debug)
convert(
filename,
save_dir,
model_basename=model_basename,
model_func_name=model_func_name,
embed_params=embed_params,
onnx_opset_version=onnx_opset_version,
onnx_opset_pedantic=onnx_opset_pedantic,
debug=debug)
# validate
passed = True
......@@ -80,21 +81,23 @@ def main(**kwargs):
# in fact fluid cannot fully clear the context
# continuous validation may be inaccurate
precision = 10 ** -kwargs.get('precision', 4)
precision = 10**-kwargs.get('precision', 4)
logger.info('starting validation on desc ...')
passed &= validate(shutil.os.path.join(save_dir, '__model__'),
golden_data_filename,
precision=precision,
)
passed &= validate(
shutil.os.path.join(save_dir, '__model__'),
golden_data_filename,
precision=precision,
)
logger.info('starting validation on code ...')
passed &= validate(shutil.os.path.join(save_dir, model_basename),
golden_data_filename,
model_func_name=model_func_name,
precision=precision,
save_inference_model=debug, # this overwrites the desc file for testing
)
passed &= validate(
shutil.os.path.join(save_dir, model_basename),
golden_data_filename,
model_func_name=model_func_name,
precision=precision,
save_inference_model=debug,  # this overwrites the desc file for testing
)
if not passed:
logger.error('validation failed, exit')
......@@ -112,20 +115,22 @@ def main(**kwargs):
if __name__ == '__main__':
logging.basicConfig(
format='[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
level=logging.DEBUG,
)
# main(model=['../examples/t5.onnx'],
# output_dir='/tmp/export/',
# embed_params=False,
# pedantic=False,
# test_data='../examples/t5.npz',
# debug=True)
main(model=['../examples/shufflenet/model.onnx'],
output_dir='/tmp/export/',
embed_params=True,
pedantic=False,
test_data='../examples/shufflenet/test_data_set_0.npz',
debug=True)
format=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
level=logging.DEBUG,
)
# main(model=['../examples/t5.onnx'],
# output_dir='/tmp/export/',
# embed_params=False,
# pedantic=False,
# test_data='../examples/t5.npz',
# debug=True)
main(
model=['../examples/inception_v2/model.onnx'],
output_dir='/tmp/export/',
embed_params=True,
pedantic=False,
test_data='../examples/inception_v2/test_data_set_2.npz',
debug=True)
......@@ -12,19 +12,21 @@ from __future__ import division
import logging
import shutil
__all__ = [
'convert',
]
def convert(onnx_model_filename, save_dir,
model_basename='model.py', model_func_name='inference',
def convert(onnx_model_filename,
save_dir,
model_basename='model.py',
model_func_name='inference',
embed_params=False,
onnx_opset_version=9, onnx_opset_pedantic=True,
onnx_opset_version=9,
onnx_opset_pedantic=True,
debug=False):
"""
convert an ONNX model to Paddle Python code and desc pb
convert an ONNX model to Paddle fluid Python code and desc pb
"""
import onnx
......@@ -59,10 +61,11 @@ def convert(onnx_model_filename, save_dir,
logger.info('checking model ...')
check_model(onnx_model)
logger.debug('using opset version: %d', onnx_opset_version)
if onnx_opset_pedantic: # WORKAROUND: RuntimeError: No Adapter For OP
if onnx_opset_pedantic: # WORKAROUND: RuntimeError: No Adapter For OP
onnx_model = convert_version(onnx_model, onnx_opset_version)
else: # TODO: add new argument for this option
logger.warning('opset conversion skipped because onnx_opset_pedantic is OFF')
else: # TODO: add new argument for this option
logger.warning(
'opset conversion skipped because onnx_opset_pedantic is OFF')
onnx_model = polish_model(onnx_model)
except ValidationError as e:
if onnx_opset_pedantic:
......@@ -90,13 +93,13 @@ def convert(onnx_model_filename, save_dir,
onnx.save(model, debug_model_filename + '.optimized_and_inferred.onnx')
# onnx.save(model, '/tmp/export/optimized_and_inferred.onnx')
# I/O instances
# I/O instances
onnx_graph = onnx_model.graph
paddle_program = Program()
paddle_writer = Writer()
fluid_program = Program()
fluid_writer = Writer()
# model components
# graph_name = onnx_graph.name
# graph_name = onnx_graph.name
graph_inputs = [value.name for value in onnx_graph.input]
graph_outputs = [value.name for value in onnx_graph.output]
graph_params = []
......@@ -107,29 +110,37 @@ def convert(onnx_model_filename, save_dir,
for name, weight in graph_weights(onnx_graph):
value_info = graph_value_infos[name]
value_info['embeded_as'] = []
value_info['get_weight'] = lambda: weight.tolist() # lazy getter
value_info['get_weight'] = (lambda w: lambda: w.tolist())(
weight) # lazy getter
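The rewritten lazy getter above works around Python's late-binding closures: a bare lambda: weight.tolist() would capture the loop variable weight by reference, so every value_info would end up returning the last weight. A minimal illustration of the bug and of the immediately-applied-lambda fix used in the diff:

```python
# Late binding: every closure sees the final value of w.
getters = [lambda: w for w in range(3)]
print([g() for g in getters])  # [2, 2, 2]

# Fix, as in the diff: bind the current value with an outer lambda.
getters = [(lambda v: lambda: v)(w) for w in range(3)]
print([g() for g in getters])  # [0, 1, 2]
```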
logger.info('conversion started')
# op set conversion
# topo = 'backward' if embed_params else 'forward'
# topo = 'backward' if embed_params else 'forward'
topo = 'forward'
for name, domain, op_type, inputs, outputs, attrs in graph_ops(onnx_graph, topo=topo):
for name, domain, op_type, inputs, outputs, attrs in graph_ops(
onnx_graph, topo=topo):
logger.debug('translating op %s %s::%s ...', name, domain, op_type)
if domain == DEFAULT_OP_DOMAIN:
domain = ''
try:
paddle_writer.emit_op(paddle_program, name, domain, op_type,
inputs, outputs, attrs,
graph_value_infos,
embed_params=embed_params,
)
fluid_writer.emit_op(
fluid_program,
name,
domain,
op_type,
inputs,
outputs,
attrs,
graph_value_infos,
embed_params=embed_params,
)
except BaseException as e:
logger.fatal('conversion failed for:\n\t%s -> %s::%s -> %s',
inputs, domain, op_type, outputs)
logger.fatal('conversion failed for:\n\t%s -> %s::%s -> %s', inputs,
domain, op_type, outputs)
raise e
op_codes = paddle_program.codes
paddle_program.codes = []
logger.info('%d ops converted', len(paddle_program.op_descs))
op_codes = fluid_program.codes
fluid_program.codes = []
logger.info('%d ops converted', len(fluid_program.op_descs))
# weight writer
for name, weight in graph_weights(onnx_graph):
......@@ -138,18 +149,24 @@ def convert(onnx_model_filename, save_dir,
var_names = value_info.get('embeded_as', [])
if var_names:
if len(var_names) > 1:
logger.info('weight %s is shared between ops, more disk space will be consumed', name)
logger.debug('saving weight %s with size of %d, in %d bytes, as %s ...',
name, weight.size, weight.nbytes, var_names)
for var_name in var_names: # multiple references
paddle_writer.write_weight(weight, shutil.os.path.join(save_dir, var_name))
logger.info(
'weight %s is shared between ops, more disk space will be consumed',
name)
logger.debug(
'saving weight %s with size of %d, in %d bytes, as %s ...',
name, weight.size, weight.nbytes, var_names)
for var_name in var_names: # multiple references
fluid_writer.write_weight(
weight, shutil.os.path.join(save_dir, var_name))
else:
logger.debug('saving weight %s with size of %d, in %d bytes, to %s ...',
name, weight.size, weight.nbytes, make_var_name(name))
paddle_writer.write_weight(weight, shutil.os.path.join(save_dir, make_var_name(name)))
paddle_writer.emit_param(paddle_program, name, value_info)
param_codes = paddle_program.codes
paddle_program.codes = []
logger.debug(
'saving weight %s with size of %d, in %d bytes, to %s ...',
name, weight.size, weight.nbytes, make_var_name(name))
fluid_writer.write_weight(
weight, shutil.os.path.join(save_dir, make_var_name(name)))
fluid_writer.emit_param(fluid_program, name, value_info)
param_codes = fluid_program.codes
fluid_program.codes = []
logger.info('%d weights converted', len(graph_params))
# input writer
......@@ -159,9 +176,11 @@ def convert(onnx_model_filename, save_dir,
value_info = graph_value_infos[name]
assert value_info['external']
external_inputs.append(name)
paddle_writer.emit_inputs(paddle_program, external_inputs, graph_value_infos, remove_batch=False) # TODO:
input_codes = paddle_program.codes
paddle_program.codes = []
fluid_writer.emit_inputs(
fluid_program, external_inputs, graph_value_infos,
remove_batch=False) # TODO:
input_codes = fluid_program.codes
fluid_program.codes = []
logger.info('%d inputs converted', len(external_inputs))
# output writer
......@@ -171,49 +190,93 @@ def convert(onnx_model_filename, save_dir,
value_info = graph_value_infos[name]
assert value_info['external']
external_outputs.append(name)
paddle_writer.emit_outputs(paddle_program, external_outputs)
output_codes = [''] + paddle_program.codes # add an empty line
paddle_program.codes = []
fluid_writer.emit_outputs(fluid_program, external_outputs)
output_codes = [''] + fluid_program.codes # add an empty line
fluid_program.codes = []
logger.info('%d outputs converted', len(external_outputs))
# code generation
header_codes = fluid_writer.header_code(
model_func_name, 'From: {}'.format(onnx_model_filename))
code_filename = shutil.os.path.join(save_dir, model_basename)
paddle_writer.write_code_file(code_filename, paddle_writer.header_code(model_func_name),
input_codes, param_codes, op_codes, output_codes)
logger.info('code saved to %s, factory function: %s', code_filename, model_func_name)
fluid_writer.write_code_file(code_filename, header_codes, input_codes,
param_codes, op_codes, output_codes)
logger.info('code saved to %s, factory function: %s', code_filename,
model_func_name)
# desc generation
desc_filename = shutil.os.path.join(save_dir, '__model__')
paddle_writer.write_desc_file(desc_filename,
op_descs=paddle_program.op_descs,
var_descs=paddle_program.var_descs,
)
fluid_writer.write_desc_file(
desc_filename,
op_descs=fluid_program.op_descs,
var_descs=fluid_program.var_descs,
)
logger.info('program saved to %s', desc_filename)
logger.info('conversion finished')
# globals().update(locals())
# globals().update(locals())
if __name__ == '__main__':
logging.basicConfig(
format='[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
level=logging.DEBUG,
)
import argparse
parser = argparse.ArgumentParser(
description='onnx2fluid.convert',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
'model',
nargs=1,
help='path to model.onnx',
)
parser.add_argument(
'--debug',
'-d',
action='store_true',
help='enable debug logging and checking',
)
parser.add_argument(
'--output_dir',
'-o',
type=str,
default='',
help='output directory',
)
parser.add_argument(
'--embed_params',
'-e',
action='store_true',
help='try to embed parameters for trainable Paddle fluid layers',
)
parser.add_argument(
'--pedantic',
action='store_true',
default=True,
help='accept and convert only standard ONNX opset',
)
parser.add_argument(
'--no-pedantic',
'-x',
action='store_false',
dest='pedantic',
help='process non-standard ONNX ops, this may lead to failures',
)
args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(format=logging_format, level=logging_level)
debug = args.debug
model_filename = args.model[0]
save_dir = args.output_dir
embed_params = args.embed_params
pedantic = args.pedantic
model_list = [
'../examples/t1.onnx',
'../examples/t2.onnx',
'../examples/t3.onnx',
'../examples/t4.onnx',
'../examples/t5.onnx',
'../examples/t6.onnx',
# '../examples/t7.onnx',
# '../examples/t8.onnx',
]
for model in model_list:
pathname, _ = shutil.os.path.splitext(model)
convert(model, pathname,
onnx_opset_pedantic=False, debug=True)
convert(model, pathname + '.embeded',
embed_params=True, onnx_opset_pedantic=False, debug=True)
convert(
model_filename,
save_dir,
embed_params=embed_params,
onnx_opset_pedantic=pedantic,
debug=debug)
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: framework.proto
import sys
_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='framework.proto',
package='paddle.framework.proto',
syntax='proto2',
serialized_pb=_b(
'\n\x0f\x66ramework.proto\x12\x16paddle.framework.proto\"\x1d\n\x07Version\x12\x12\n\x07version\x18\x01 \x01(\x03:\x01\x30\"\xec\x03\n\x06OpDesc\x12\x0c\n\x04type\x18\x03 \x02(\t\x12\x32\n\x06inputs\x18\x01 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x33\n\x07outputs\x18\x02 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x32\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32#.paddle.framework.proto.OpDesc.Attr\x12\x18\n\tis_target\x18\x05 \x01(\x08:\x05\x66\x61lse\x1a\xef\x01\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\t\n\x01i\x18\x03 \x01(\x05\x12\t\n\x01\x66\x18\x04 \x01(\x02\x12\t\n\x01s\x18\x05 \x01(\t\x12\x0c\n\x04ints\x18\x06 \x03(\x05\x12\x0e\n\x06\x66loats\x18\x07 \x03(\x02\x12\x0f\n\x07strings\x18\x08 \x03(\t\x12\t\n\x01\x62\x18\n \x01(\x08\x12\r\n\x05\x62ools\x18\x0b \x03(\x08\x12\x11\n\tblock_idx\x18\x0c \x01(\x05\x12\t\n\x01l\x18\r \x01(\x03\x12\x12\n\nblocks_idx\x18\x0e \x03(\x05\x12\r\n\x05longs\x18\x0f \x03(\x03\x1a+\n\x03Var\x12\x11\n\tparameter\x18\x01 \x02(\t\x12\x11\n\targuments\x18\x02 \x03(\t\"\xb3\x03\n\x07OpProto\x12\x0c\n\x04type\x18\x01 \x02(\t\x12\x33\n\x06inputs\x18\x02 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x34\n\x07outputs\x18\x03 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x33\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32$.paddle.framework.proto.OpProto.Attr\x12\x0f\n\x07\x63omment\x18\x05 \x02(\t\x1ax\n\x03Var\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0f\n\x07\x63omment\x18\x02 \x02(\t\x12\x19\n\nduplicable\x18\x03 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0cintermediate\x18\x04 \x01(\x08:\x05\x66\x61lse\x12\x1a\n\x0b\x64ispensable\x18\x05 \x01(\x08:\x05\x66\x61lse\x1ao\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\x0f\n\x07\x63omment\x18\x03 \x02(\t\x12\x18\n\tgenerated\x18\x04 \x01(\x08:\x05\x66\x61lse\"\xda\x08\n\x07VarType\x12\x32\n\x04type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x41\n\rselected_rows\x18\x02 \x01(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x41\n\nlod_tensor\x18\x03 \x01(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x12H\n\x0ctensor_array\x18\x04 \x01(\x0b\x32\x32.paddle.framework.proto.VarType.LoDTensorArrayDesc\x12:\n\x06reader\x18\x05 \x01(\x0b\x32*.paddle.framework.proto.VarType.ReaderDesc\x12\x34\n\x05tuple\x18\x07 \x01(\x0b\x32%.paddle.framework.proto.VarType.Tuple\x1aS\n\nTensorDesc\x12\x37\n\tdata_type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x0c\n\x04\x64ims\x18\x02 \x03(\x03\x1a\x61\n\rLoDTensorDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1a\x66\n\x12LoDTensorArrayDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1aO\n\nReaderDesc\x12\x41\n\nlod_tensor\x18\x01 \x03(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x1a\x43\n\x05Tuple\x12:\n\x0c\x65lement_type\x18\x01 \x03(\x0e\x32$.paddle.framework.proto.VarType.Type\"\xa2\x02\n\x04Type\x12\x08\n\x04\x42OOL\x10\x00\x12\t\n\x05INT16\x10\x01\x12\t\n\x05INT32\x10\x02\x12\t\n\x05INT64\x10\x03\x12\x08\n\x04\x46P16\x10\x04\x12\x08\n\x04\x46P32\x10\x05\x12\x08\n\x04\x46P64\x10\x06\x12\n\n\x06SIZE_T\x10\x13\x12\t\n\x05UINT8\x10\x14\x12\x08\n\x04INT8\x10\x15\x12\x0e\n\nLOD_TENSOR\x10\x07\x12\x11\n\rSELECTED_ROWS\x10\x08\x12\x12\n\x0e\x46\x45\x45\x44_MINIBATCH\x10\t\x12\x0e\n\nFETCH_LIST\x10\n\x12\x0f\n\x0bSTEP_SCOPES\x10\x0b\x12\x12\n\x0eLOD_RANK_TABLE\x10\x0c\x12\x14\n\x10LOD_TENSOR_ARRAY\x10\r\x12\x0e\n\nPLACE_LIST\x10\x0e\x12\n\n\x06READER\x10\x0f\x12\x07\n\x03RAW\x10\x11\x12\t\n\x05TUPLE\x10\x12\"b\n\x07VarDesc\x12\x0c\n\x04name\x18\x01 \x02(\t\x12-\n\x04type\x18\x02 \x02(\x0b\x32\x1f.paddle.framework.proto.VarType\x12\x1a\n\x0bpersistable\x18\x03 \x01(\x08:\x05\x66\x61lse\"\xa7\x01\n\tBlockDesc\x12\x0b\n\x03idx\x18\x01 \x02(\x05\x12\x12\n\nparent_idx\x18\x02 \x02(\x05\x12-\n\x04vars\x18\x03 \x03(\x0b\x32\x1f.paddle.framework.proto.VarDesc\x12+\n\x03ops\x18\x04 \x03(\x0b\x32\x1e.paddle.framework.proto.OpDesc\x12\x1d\n\x11\x66orward_block_idx\x18\x05 \x01(\x05:\x02-1\"r\n\x0bProgramDesc\x12\x31\n\x06\x62locks\x18\x01 \x03(\x0b\x32!.paddle.framework.proto.BlockDesc\x12\x30\n\x07version\x18\x02 \x01(\x0b\x32\x1f.paddle.framework.proto.Version*\x94\x01\n\x08\x41ttrType\x12\x07\n\x03INT\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x08\n\x04INTS\x10\x03\x12\n\n\x06\x46LOATS\x10\x04\x12\x0b\n\x07STRINGS\x10\x05\x12\x0b\n\x07\x42OOLEAN\x10\x06\x12\x0c\n\x08\x42OOLEANS\x10\x07\x12\t\n\x05\x42LOCK\x10\x08\x12\x08\n\x04LONG\x10\t\x12\n\n\x06\x42LOCKS\x10\n\x12\t\n\x05LONGS\x10\x0b\x42\x02H\x03'
))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
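For context, a sketch of how this generated module is typically consumed when emitting a __model__ desc file; the concrete message classes (ProgramDesc, BlockDesc, OpDesc, ...) are built further down in the generated file, beyond the portion shown here:

```python
# Sketch, assuming the message classes generated later in this module.
from onnx2fluid import framework_pb2

prog = framework_pb2.ProgramDesc()
block = prog.blocks.add()   # repeated BlockDesc
block.idx = 0
block.parent_idx = -1       # required fields in proto2
op = block.ops.add()        # repeated OpDesc
op.type = 'feed'
serialized = prog.SerializeToString()  # bytes suitable for a __model__ file
```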
_ATTRTYPE = _descriptor.EnumDescriptor(
name='AttrType',
full_name='paddle.framework.proto.AttrType',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='INT', index=0, number=0, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FLOAT', index=1, number=1, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='STRING', index=2, number=2, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='INTS', index=3, number=3, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FLOATS', index=4, number=4, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='STRINGS', index=5, number=5, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='BOOLEAN', index=6, number=6, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='BOOLEANS', index=7, number=7, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='BLOCK', index=8, number=8, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='LONG', index=9, number=9, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='BLOCKS', index=10, number=10, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='LONGS', index=11, number=11, options=None, type=None),
],
containing_type=None,
options=None,
serialized_start=2511,
serialized_end=2659,
)
_sym_db.RegisterEnumDescriptor(_ATTRTYPE)
AttrType = enum_type_wrapper.EnumTypeWrapper(_ATTRTYPE)
INT = 0
FLOAT = 1
STRING = 2
INTS = 3
FLOATS = 4
STRINGS = 5
BOOLEAN = 6
BOOLEANS = 7
BLOCK = 8
LONG = 9
BLOCKS = 10
LONGS = 11
_VARTYPE_TYPE = _descriptor.EnumDescriptor(
name='Type',
full_name='paddle.framework.proto.VarType.Type',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='BOOL', index=0, number=0, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='INT16', index=1, number=1, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='INT32', index=2, number=2, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='INT64', index=3, number=3, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FP16', index=4, number=4, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FP32', index=5, number=5, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FP64', index=6, number=6, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='SIZE_T', index=7, number=19, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='UINT8', index=8, number=20, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='INT8', index=9, number=21, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='LOD_TENSOR', index=10, number=7, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='SELECTED_ROWS', index=11, number=8, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FEED_MINIBATCH', index=12, number=9, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='FETCH_LIST', index=13, number=10, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='STEP_SCOPES', index=14, number=11, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='LOD_RANK_TABLE', index=15, number=12, options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LOD_TENSOR_ARRAY',
index=16,
number=13,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PLACE_LIST', index=17, number=14, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='READER', index=18, number=15, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='RAW', index=19, number=17, options=None, type=None),
_descriptor.EnumValueDescriptor(
name='TUPLE', index=20, number=18, options=None, type=None),
],
containing_type=None,
options=None,
serialized_start=1832,
serialized_end=2122,
)
_sym_db.RegisterEnumDescriptor(_VARTYPE_TYPE)
_VERSION = _descriptor.Descriptor(
name='Version',
full_name='paddle.framework.proto.Version',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='version',
full_name='paddle.framework.proto.Version.version',
index=0,
number=1,
type=3,
cpp_type=2,
label=1,
has_default_value=True,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=43,
serialized_end=72,
)
_OPDESC_ATTR = _descriptor.Descriptor(
name='Attr',
full_name='paddle.framework.proto.OpDesc.Attr',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle.framework.proto.OpDesc.Attr.name',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.framework.proto.OpDesc.Attr.type',
index=1,
number=2,
type=14,
cpp_type=8,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='i',
full_name='paddle.framework.proto.OpDesc.Attr.i',
index=2,
number=3,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='f',
full_name='paddle.framework.proto.OpDesc.Attr.f',
index=3,
number=4,
type=2,
cpp_type=6,
label=1,
has_default_value=False,
default_value=float(0),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='s',
full_name='paddle.framework.proto.OpDesc.Attr.s',
index=4,
number=5,
type=9,
cpp_type=9,
label=1,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ints',
full_name='paddle.framework.proto.OpDesc.Attr.ints',
index=5,
number=6,
type=5,
cpp_type=1,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='floats',
full_name='paddle.framework.proto.OpDesc.Attr.floats',
index=6,
number=7,
type=2,
cpp_type=6,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='strings',
full_name='paddle.framework.proto.OpDesc.Attr.strings',
index=7,
number=8,
type=9,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='b',
full_name='paddle.framework.proto.OpDesc.Attr.b',
index=8,
number=10,
type=8,
cpp_type=7,
label=1,
has_default_value=False,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='bools',
full_name='paddle.framework.proto.OpDesc.Attr.bools',
index=9,
number=11,
type=8,
cpp_type=7,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='block_idx',
full_name='paddle.framework.proto.OpDesc.Attr.block_idx',
index=10,
number=12,
type=5,
cpp_type=1,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='l',
full_name='paddle.framework.proto.OpDesc.Attr.l',
index=11,
number=13,
type=3,
cpp_type=2,
label=1,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='blocks_idx',
full_name='paddle.framework.proto.OpDesc.Attr.blocks_idx',
index=12,
number=14,
type=5,
cpp_type=1,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='longs',
full_name='paddle.framework.proto.OpDesc.Attr.longs',
index=13,
number=15,
type=3,
cpp_type=2,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=283,
serialized_end=522,
)
_OPDESC_VAR = _descriptor.Descriptor(
name='Var',
full_name='paddle.framework.proto.OpDesc.Var',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='parameter',
full_name='paddle.framework.proto.OpDesc.Var.parameter',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='arguments',
full_name='paddle.framework.proto.OpDesc.Var.arguments',
index=1,
number=2,
type=9,
cpp_type=9,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=524,
serialized_end=567,
)
_OPDESC = _descriptor.Descriptor(
name='OpDesc',
full_name='paddle.framework.proto.OpDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.framework.proto.OpDesc.type',
index=0,
number=3,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='inputs',
full_name='paddle.framework.proto.OpDesc.inputs',
index=1,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='outputs',
full_name='paddle.framework.proto.OpDesc.outputs',
index=2,
number=2,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='attrs',
full_name='paddle.framework.proto.OpDesc.attrs',
index=3,
number=4,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='is_target',
full_name='paddle.framework.proto.OpDesc.is_target',
index=4,
number=5,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[
_OPDESC_ATTR,
_OPDESC_VAR,
],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=75,
serialized_end=567,
)
_OPPROTO_VAR = _descriptor.Descriptor(
name='Var',
full_name='paddle.framework.proto.OpProto.Var',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle.framework.proto.OpProto.Var.name',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment',
full_name='paddle.framework.proto.OpProto.Var.comment',
index=1,
number=2,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='duplicable',
full_name='paddle.framework.proto.OpProto.Var.duplicable',
index=2,
number=3,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='intermediate',
full_name='paddle.framework.proto.OpProto.Var.intermediate',
index=3,
number=4,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dispensable',
full_name='paddle.framework.proto.OpProto.Var.dispensable',
index=4,
number=5,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=772,
serialized_end=892,
)
_OPPROTO_ATTR = _descriptor.Descriptor(
name='Attr',
full_name='paddle.framework.proto.OpProto.Attr',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle.framework.proto.OpProto.Attr.name',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.framework.proto.OpProto.Attr.type',
index=1,
number=2,
type=14,
cpp_type=8,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment',
full_name='paddle.framework.proto.OpProto.Attr.comment',
index=2,
number=3,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='generated',
full_name='paddle.framework.proto.OpProto.Attr.generated',
index=3,
number=4,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=894,
serialized_end=1005,
)
_OPPROTO = _descriptor.Descriptor(
name='OpProto',
full_name='paddle.framework.proto.OpProto',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.framework.proto.OpProto.type',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='inputs',
full_name='paddle.framework.proto.OpProto.inputs',
index=1,
number=2,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='outputs',
full_name='paddle.framework.proto.OpProto.outputs',
index=2,
number=3,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='attrs',
full_name='paddle.framework.proto.OpProto.attrs',
index=3,
number=4,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment',
full_name='paddle.framework.proto.OpProto.comment',
index=4,
number=5,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[
_OPPROTO_VAR,
_OPPROTO_ATTR,
],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=570,
serialized_end=1005,
)
_VARTYPE_TENSORDESC = _descriptor.Descriptor(
name='TensorDesc',
full_name='paddle.framework.proto.VarType.TensorDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='data_type',
full_name='paddle.framework.proto.VarType.TensorDesc.data_type',
index=0,
number=1,
type=14,
cpp_type=8,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dims',
full_name='paddle.framework.proto.VarType.TensorDesc.dims',
index=1,
number=2,
type=3,
cpp_type=2,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1393,
serialized_end=1476,
)
_VARTYPE_LODTENSORDESC = _descriptor.Descriptor(
name='LoDTensorDesc',
full_name='paddle.framework.proto.VarType.LoDTensorDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor',
full_name='paddle.framework.proto.VarType.LoDTensorDesc.tensor',
index=0,
number=1,
type=11,
cpp_type=10,
label=2,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_level',
full_name='paddle.framework.proto.VarType.LoDTensorDesc.lod_level',
index=1,
number=2,
type=5,
cpp_type=1,
label=1,
has_default_value=True,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1478,
serialized_end=1575,
)
_VARTYPE_LODTENSORARRAYDESC = _descriptor.Descriptor(
name='LoDTensorArrayDesc',
full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor',
full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc.tensor',
index=0,
number=1,
type=11,
cpp_type=10,
label=2,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_level',
full_name=
'paddle.framework.proto.VarType.LoDTensorArrayDesc.lod_level',
index=1,
number=2,
type=5,
cpp_type=1,
label=1,
has_default_value=True,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1577,
serialized_end=1679,
)
_VARTYPE_READERDESC = _descriptor.Descriptor(
name='ReaderDesc',
full_name='paddle.framework.proto.VarType.ReaderDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='lod_tensor',
full_name='paddle.framework.proto.VarType.ReaderDesc.lod_tensor',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1681,
serialized_end=1760,
)
_VARTYPE_TUPLE = _descriptor.Descriptor(
name='Tuple',
full_name='paddle.framework.proto.VarType.Tuple',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='element_type',
full_name='paddle.framework.proto.VarType.Tuple.element_type',
index=0,
number=1,
type=14,
cpp_type=8,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1762,
serialized_end=1829,
)
_VARTYPE = _descriptor.Descriptor(
name='VarType',
full_name='paddle.framework.proto.VarType',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.framework.proto.VarType.type',
index=0,
number=1,
type=14,
cpp_type=8,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='selected_rows',
full_name='paddle.framework.proto.VarType.selected_rows',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_tensor',
full_name='paddle.framework.proto.VarType.lod_tensor',
index=2,
number=3,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='tensor_array',
full_name='paddle.framework.proto.VarType.tensor_array',
index=3,
number=4,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='reader',
full_name='paddle.framework.proto.VarType.reader',
index=4,
number=5,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='tuple',
full_name='paddle.framework.proto.VarType.tuple',
index=5,
number=7,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[
_VARTYPE_TENSORDESC,
_VARTYPE_LODTENSORDESC,
_VARTYPE_LODTENSORARRAYDESC,
_VARTYPE_READERDESC,
_VARTYPE_TUPLE,
],
enum_types=[
_VARTYPE_TYPE,
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=1008,
serialized_end=2122,
)
_VARDESC = _descriptor.Descriptor(
name='VarDesc',
full_name='paddle.framework.proto.VarDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name',
full_name='paddle.framework.proto.VarDesc.name',
index=0,
number=1,
type=9,
cpp_type=9,
label=2,
has_default_value=False,
default_value=_b("").decode('utf-8'),
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type',
full_name='paddle.framework.proto.VarDesc.type',
index=1,
number=2,
type=11,
cpp_type=10,
label=2,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='persistable',
full_name='paddle.framework.proto.VarDesc.persistable',
index=2,
number=3,
type=8,
cpp_type=7,
label=1,
has_default_value=True,
default_value=False,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2124,
serialized_end=2222,
)
_BLOCKDESC = _descriptor.Descriptor(
name='BlockDesc',
full_name='paddle.framework.proto.BlockDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='idx',
full_name='paddle.framework.proto.BlockDesc.idx',
index=0,
number=1,
type=5,
cpp_type=1,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='parent_idx',
full_name='paddle.framework.proto.BlockDesc.parent_idx',
index=1,
number=2,
type=5,
cpp_type=1,
label=2,
has_default_value=False,
default_value=0,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='vars',
full_name='paddle.framework.proto.BlockDesc.vars',
index=2,
number=3,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ops',
full_name='paddle.framework.proto.BlockDesc.ops',
index=3,
number=4,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='forward_block_idx',
full_name='paddle.framework.proto.BlockDesc.forward_block_idx',
index=4,
number=5,
type=5,
cpp_type=1,
label=1,
has_default_value=True,
default_value=-1,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2225,
serialized_end=2392,
)
_PROGRAMDESC = _descriptor.Descriptor(
name='ProgramDesc',
full_name='paddle.framework.proto.ProgramDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='blocks',
full_name='paddle.framework.proto.ProgramDesc.blocks',
index=0,
number=1,
type=11,
cpp_type=10,
label=3,
has_default_value=False,
default_value=[],
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='version',
full_name='paddle.framework.proto.ProgramDesc.version',
index=1,
number=2,
type=11,
cpp_type=10,
label=1,
has_default_value=False,
default_value=None,
message_type=None,
enum_type=None,
containing_type=None,
is_extension=False,
extension_scope=None,
options=None),
],
extensions=[],
nested_types=[],
enum_types=[],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[],
serialized_start=2394,
serialized_end=2508,
)
_OPDESC_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE
_OPDESC_ATTR.containing_type = _OPDESC
_OPDESC_VAR.containing_type = _OPDESC
_OPDESC.fields_by_name['inputs'].message_type = _OPDESC_VAR
_OPDESC.fields_by_name['outputs'].message_type = _OPDESC_VAR
_OPDESC.fields_by_name['attrs'].message_type = _OPDESC_ATTR
_OPPROTO_VAR.containing_type = _OPPROTO
_OPPROTO_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE
_OPPROTO_ATTR.containing_type = _OPPROTO
_OPPROTO.fields_by_name['inputs'].message_type = _OPPROTO_VAR
_OPPROTO.fields_by_name['outputs'].message_type = _OPPROTO_VAR
_OPPROTO.fields_by_name['attrs'].message_type = _OPPROTO_ATTR
_VARTYPE_TENSORDESC.fields_by_name['data_type'].enum_type = _VARTYPE_TYPE
_VARTYPE_TENSORDESC.containing_type = _VARTYPE
_VARTYPE_LODTENSORDESC.fields_by_name[
'tensor'].message_type = _VARTYPE_TENSORDESC
_VARTYPE_LODTENSORDESC.containing_type = _VARTYPE
_VARTYPE_LODTENSORARRAYDESC.fields_by_name[
'tensor'].message_type = _VARTYPE_TENSORDESC
_VARTYPE_LODTENSORARRAYDESC.containing_type = _VARTYPE
_VARTYPE_READERDESC.fields_by_name[
'lod_tensor'].message_type = _VARTYPE_LODTENSORDESC
_VARTYPE_READERDESC.containing_type = _VARTYPE
_VARTYPE_TUPLE.fields_by_name['element_type'].enum_type = _VARTYPE_TYPE
_VARTYPE_TUPLE.containing_type = _VARTYPE
_VARTYPE.fields_by_name['type'].enum_type = _VARTYPE_TYPE
_VARTYPE.fields_by_name['selected_rows'].message_type = _VARTYPE_TENSORDESC
_VARTYPE.fields_by_name['lod_tensor'].message_type = _VARTYPE_LODTENSORDESC
_VARTYPE.fields_by_name[
'tensor_array'].message_type = _VARTYPE_LODTENSORARRAYDESC
_VARTYPE.fields_by_name['reader'].message_type = _VARTYPE_READERDESC
_VARTYPE.fields_by_name['tuple'].message_type = _VARTYPE_TUPLE
_VARTYPE_TYPE.containing_type = _VARTYPE
_VARDESC.fields_by_name['type'].message_type = _VARTYPE
_BLOCKDESC.fields_by_name['vars'].message_type = _VARDESC
_BLOCKDESC.fields_by_name['ops'].message_type = _OPDESC
_PROGRAMDESC.fields_by_name['blocks'].message_type = _BLOCKDESC
_PROGRAMDESC.fields_by_name['version'].message_type = _VERSION
DESCRIPTOR.message_types_by_name['Version'] = _VERSION
DESCRIPTOR.message_types_by_name['OpDesc'] = _OPDESC
DESCRIPTOR.message_types_by_name['OpProto'] = _OPPROTO
DESCRIPTOR.message_types_by_name['VarType'] = _VARTYPE
DESCRIPTOR.message_types_by_name['VarDesc'] = _VARDESC
DESCRIPTOR.message_types_by_name['BlockDesc'] = _BLOCKDESC
DESCRIPTOR.message_types_by_name['ProgramDesc'] = _PROGRAMDESC
DESCRIPTOR.enum_types_by_name['AttrType'] = _ATTRTYPE
Version = _reflection.GeneratedProtocolMessageType(
'Version',
(_message.Message, ),
dict(
DESCRIPTOR=_VERSION,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.Version)
))
_sym_db.RegisterMessage(Version)
OpDesc = _reflection.GeneratedProtocolMessageType(
'OpDesc',
(_message.Message, ),
dict(
Attr=_reflection.GeneratedProtocolMessageType(
'Attr',
(_message.Message, ),
dict(
DESCRIPTOR=_OPDESC_ATTR,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Attr)
)),
Var=_reflection.GeneratedProtocolMessageType(
'Var',
(_message.Message, ),
dict(
DESCRIPTOR=_OPDESC_VAR,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Var)
)),
DESCRIPTOR=_OPDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc)
))
_sym_db.RegisterMessage(OpDesc)
_sym_db.RegisterMessage(OpDesc.Attr)
_sym_db.RegisterMessage(OpDesc.Var)
OpProto = _reflection.GeneratedProtocolMessageType(
'OpProto',
(_message.Message, ),
dict(
Var=_reflection.GeneratedProtocolMessageType(
'Var',
(_message.Message, ),
dict(
DESCRIPTOR=_OPPROTO_VAR,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Var)
)),
Attr=_reflection.GeneratedProtocolMessageType(
'Attr',
(_message.Message, ),
dict(
DESCRIPTOR=_OPPROTO_ATTR,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Attr)
)),
DESCRIPTOR=_OPPROTO,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto)
))
_sym_db.RegisterMessage(OpProto)
_sym_db.RegisterMessage(OpProto.Var)
_sym_db.RegisterMessage(OpProto.Attr)
VarType = _reflection.GeneratedProtocolMessageType(
'VarType',
(_message.Message, ),
dict(
TensorDesc=_reflection.GeneratedProtocolMessageType(
'TensorDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_VARTYPE_TENSORDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.TensorDesc)
)),
LoDTensorDesc=_reflection.GeneratedProtocolMessageType(
'LoDTensorDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_VARTYPE_LODTENSORDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorDesc)
)),
LoDTensorArrayDesc=_reflection.GeneratedProtocolMessageType(
'LoDTensorArrayDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_VARTYPE_LODTENSORARRAYDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorArrayDesc)
)),
ReaderDesc=_reflection.GeneratedProtocolMessageType(
'ReaderDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_VARTYPE_READERDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.ReaderDesc)
)),
Tuple=_reflection.GeneratedProtocolMessageType(
'Tuple',
(_message.Message, ),
dict(
DESCRIPTOR=_VARTYPE_TUPLE,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.Tuple)
)),
DESCRIPTOR=_VARTYPE,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType)
))
_sym_db.RegisterMessage(VarType)
_sym_db.RegisterMessage(VarType.TensorDesc)
_sym_db.RegisterMessage(VarType.LoDTensorDesc)
_sym_db.RegisterMessage(VarType.LoDTensorArrayDesc)
_sym_db.RegisterMessage(VarType.ReaderDesc)
_sym_db.RegisterMessage(VarType.Tuple)
VarDesc = _reflection.GeneratedProtocolMessageType(
'VarDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_VARDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc)
))
_sym_db.RegisterMessage(VarDesc)
BlockDesc = _reflection.GeneratedProtocolMessageType(
'BlockDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_BLOCKDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.BlockDesc)
))
_sym_db.RegisterMessage(BlockDesc)
ProgramDesc = _reflection.GeneratedProtocolMessageType(
'ProgramDesc',
(_message.Message, ),
dict(
DESCRIPTOR=_PROGRAMDESC,
__module__='framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.ProgramDesc)
))
_sym_db.RegisterMessage(ProgramDesc)
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(),
_b('H\003'))
# @@protoc_insertion_point(module_scope)
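# Illustrative usage sketch (not part of the generated code; assumes this
# module is importable as `framework_pb2`): the message classes registered
# above can assemble and serialize a ProgramDesc by hand.
#
#   import framework_pb2
#
#   prog = framework_pb2.ProgramDesc()
#   block = prog.blocks.add()              # repeated BlockDesc
#   block.idx, block.parent_idx = 0, -1    # -1: root-block convention
#   var = block.vars.add()                 # VarDesc
#   var.name = 'image'
#   var.type.type = framework_pb2.VarType.LOD_TENSOR
#   serialized = prog.SerializeToString()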
......@@ -12,34 +12,36 @@ import logging
import numpy as np
import onnx
from collections import OrderedDict as Dict # as default dict
from collections import OrderedDict as Dict # as default dict
from onnx.helper import get_attribute_value, make_attribute
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from onnx.numpy_helper import to_array
from onnx.shape_inference import infer_shapes
logger = logging.getLogger(__name__)
__all__ = [
'print_pb_structure',
'build_value_refs',
'node_attrs', 'node_topo', 'node_iter',
'node_attrs',
'node_topo',
'node_iter',
'tensor_shape',
'graph_ops', 'graph_weights',
'graph_ops',
'graph_weights',
'inferred_model_value_info',
'optimize_model_skip_op_for_inference',
'optimize_model_strip_initializer',
'optimize_model_cast', 'optimize_model_slice',
'optimize_model_cast',
'optimize_model_slice',
]
ONNX_INT_MAX = 2 ** 63 - 1
ONNX_INT_MAX = 2**63 - 1
DEFAULT_OP_DOMAIN = 'ai.onnx'
def print_pb_structure(message,
loop_iterative=False, depth=0):
def print_pb_structure(message, loop_iterative=False, depth=0):
"""
print pb fields in their structure
"""
......@@ -47,14 +49,17 @@ def print_pb_structure(message,
if hasattr(message, 'DESCRIPTOR') and hasattr(message.DESCRIPTOR, 'fields'):
for field in message.DESCRIPTOR.fields:
print('\t' * depth + '-', field.name)
print_pb_structure(getattr(message, field.name),
loop_iterative=loop_iterative, depth=(depth + 1))
print_pb_structure(
getattr(message, field.name),
loop_iterative=loop_iterative,
depth=(depth + 1))
if loop_iterative and hasattr(message, 'MergeFrom') and hasattr(message, '__len__'):
if loop_iterative and hasattr(message, 'MergeFrom') and hasattr(
message, '__len__'):
for idx, item in enumerate(message):
print('\t' * depth + '-', idx)
print_pb_structure(item,
loop_iterative=loop_iterative, depth=(depth + 1))
print_pb_structure(
item, loop_iterative=loop_iterative, depth=(depth + 1))
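# Usage sketch (model loading is an assumption for illustration):
#
#   print_pb_structure(onnx.load('/path/to/model.onnx'), loop_iterative=True)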
def build_value_refs(nodes):
......@@ -80,7 +85,8 @@ def get_attribute_value2(attr):
if attr.type == onnx.AttributeProto.TENSOR:
dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type])
data = attr.t.raw_data
value = np.frombuffer(data, dtype=dtype, count=(len(data) // dtype.itemsize))
value = np.frombuffer(
data, dtype=dtype, count=(len(data) // dtype.itemsize))
else:
value = get_attribute_value(attr)
return value
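# A hedged example of the TENSOR branch above (assumes raw_data is populated,
# as onnx.numpy_helper.from_array does for dense float32 arrays):
#
#   tensor = onnx.numpy_helper.from_array(
#       np.asarray([1., 2.], dtype=np.float32))
#   attr = make_attribute('value', tensor)
#   get_attribute_value2(attr)  # -> array([1., 2.], dtype=float32)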
......@@ -91,7 +97,8 @@ def node_attrs(node):
convert ONNX node attributes to dict
"""
return {attr.name: get_attribute_value2(attr) for attr in node.attribute} # dict
return {attr.name: get_attribute_value2(attr)
for attr in node.attribute} # dict
def tensor_shape(tensor):
......@@ -137,7 +144,7 @@ def node_topo(nodes, topo='default'):
for next_idx in input_refs[val_name]:
node_in_degrees[next_idx] -= 1
if node_in_degrees[next_idx] == 0:
queue.insert(0, next_idx) # make it lazy
queue.insert(0, next_idx) # make it lazy
return node_topo
if topo == 'backward':
......@@ -162,14 +169,13 @@ def node_topo(nodes, topo='default'):
for next_idx in output_refs[val_name]:
node_out_degrees[next_idx] -= 1
if node_out_degrees[next_idx] == 0:
queue.insert(0, next_idx) # make it lazy
queue.insert(0, next_idx) # make it lazy
return node_topo
raise ValueError('unknown topo: {}'.format(topo))
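# Usage sketch:
#
#   order = node_topo(model.graph.node, topo='forward')
#   # 'order' lists node indices such that every producer precedes its
#   # consumers; topo='backward' yields the reverse guarantee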
def node_iter(nodes,
indices=None):
def node_iter(nodes, indices=None):
"""
generator for ONNX node graph with given indices
"""
......@@ -194,8 +200,7 @@ def node_iter(nodes,
yield name, domain, op_type, inputs, outputs, attrs
def graph_ops(graph,
topo='default'):
def graph_ops(graph, topo='default'):
"""
generator for ONNX node graph with given topology
"""
......@@ -232,24 +237,24 @@ def inferred_model_value_info(model):
value_info = Dict()
for item in graph.value_info:
value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
shape=tensor_shape(item),
external=False,
)
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
shape=tensor_shape(item),
external=False,
)
for item in graph.input:
assert item.name not in value_info
value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
shape=tensor_shape(item),
external=True,
)
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
shape=tensor_shape(item),
external=True,
)
for item in graph.output:
# assert item.name not in value_info, 'bypass-model not supported'
# assert item.name not in value_info, 'bypass-model not supported'
value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
shape=tensor_shape(item),
external=True,
)
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
shape=tensor_shape(item),
external=True,
)
return value_info
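# Usage sketch ('data' is a hypothetical graph input name):
#
#   value_info = inferred_model_value_info(model)
#   value_info['data']
#   # -> e.g. {'dtype': dtype('float32'), 'shape': [1, 3, 224, 224],
#   #          'external': True}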
......@@ -283,9 +288,7 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
return processed
def optimize_model_skip_op_for_inference(
model,
op_list=None):
def optimize_model_skip_op_for_inference(model, op_list=None):
"""
skip ops which can be bypassed for inference
"""
......@@ -297,38 +300,42 @@ def optimize_model_skip_op_for_inference(
ret = type(model)()
ret.CopyFrom(model)
ret.graph.ClearField('value_info') # WORKAROUND: onnx do not drop old value_info
ret.graph.ClearField(
'value_info') # WORKAROUND: onnx do not drop old value_info
ret_nodes = ret.graph.node
nodes_to_remove = []
for node_idx, node in enumerate(nodes):
if not(node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
continue
op_type = node.op_type
if not(op_type in op_list):
if not (op_type in op_list):
continue
if op_type in ['Dropout']:
input_name = node.input[0]
output_name = node.output[0]
elif not(len(node.input) == 1 and len(node.output) == 1):
logger.warning('currently only 1-input-1-output op supported, skip required %d: %s',
node_idx, node.op_type)
elif not (len(node.input) == 1 and len(node.output) == 1):
logger.warning(
'currently only 1-input-1-output ops can be skipped, kept %d: %s',
node_idx, node.op_type)
continue
else:
input_name = node.input[0]
output_name = node.output[0]
if output_name in input_refs:
processed = skip_node_forward(ret_nodes, output_name, input_name, input_refs)
processed = skip_node_forward(ret_nodes, output_name, input_name,
input_refs)
elif input_name in output_refs:
processed = skip_node_backward(ret_nodes, input_name, output_name, output_refs)
processed = skip_node_backward(ret_nodes, input_name, output_name,
output_refs)
else:
processed = -1
if processed > 0:
nodes_to_remove.append(node_idx)
logger.debug('skip op %d: %s -> %s -> %s',
node_idx, input_name, node.op_type, output_name)
logger.debug('skip op %d: %s -> %s -> %s', node_idx, input_name,
node.op_type, output_name)
elif processed == 0:
logger.warning('weird, no node processed')
else:
......@@ -342,8 +349,7 @@ def optimize_model_skip_op_for_inference(
return ret
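# Usage sketch: bypass inference-time no-ops. The explicit op_list below is
# an assumption for illustration; op_list=None falls back to a built-in
# default.
#
#   model = optimize_model_skip_op_for_inference(model, op_list=['Dropout'])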
def optimize_model_strip_initializer(model,
keep_input_only=True):
def optimize_model_strip_initializer(model, keep_input_only=True):
"""
strip weights for inference
"""
......@@ -354,7 +360,8 @@ def optimize_model_strip_initializer(model,
ret = type(model)()
ret.CopyFrom(model)
ret.graph.ClearField('value_info') # WORKAROUND: onnx do not drop old value_info
ret.graph.ClearField(
'value_info') # WORKAROUND: onnx do not drop old value_info
# strip initializers
ret.graph.ClearField('initializer')
......@@ -366,8 +373,7 @@ def optimize_model_strip_initializer(model,
elif not keep_input_only and name in output_refs:
ret_initializers.add().CopyFrom(initializer)
else:
logger.debug('initializer %s(%s[%d]) stripped',
name,
logger.debug('initializer %s(%s[%d]) stripped', name,
TENSOR_TYPE_TO_NP_TYPE[initializer.data_type],
len(initializer.raw_data))
......@@ -379,10 +385,10 @@ def optimize_model_strip_initializer(model,
if name in input_refs or name in out_names:
ret_inputs.add().CopyFrom(item)
else:
logger.debug('input %s(%s%s) stripped',
name,
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
tensor_shape(item))
logger.debug(
'input %s(%s%s) stripped', name,
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
tensor_shape(item))
return ret
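# Usage sketch: drop initializers and graph inputs that no node references.
#
#   model = optimize_model_strip_initializer(model)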
......@@ -397,18 +403,19 @@ def optimize_model_cast(model):
ret = type(model)()
ret.CopyFrom(model)
ret.graph.ClearField('value_info') # WORKAROUND: onnx do not drop old value_info
ret.graph.ClearField(
'value_info') # WORKAROUND: onnx do not drop old value_info
ret_nodes = ret.graph.node
nodes_to_remove = []
for node_idx, node in enumerate(nodes):
if not(node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
continue
if not(node.op_type == 'Cast'):
if not (node.op_type == 'Cast'):
continue
attrs = node_attrs(node)
output_dtype = TENSOR_TYPE_TO_NP_TYPE[attrs['to']]
input_name = node.input[0]
info = value_info.get('input_name', None) # relax for un-inferrable
info = value_info.get(input_name, None)  # relax for un-inferrable
if info is None:
continue
input_dtype = info.get('dtype', None)
......@@ -417,21 +424,23 @@ def optimize_model_cast(model):
output_name = node.output[0]
if output_name in input_refs:
processed = skip_node_forward(ret_nodes, output_name, input_name, input_refs)
processed = skip_node_forward(ret_nodes, output_name, input_name,
input_refs)
elif input_name in output_refs:
processed = skip_node_backward(ret_nodes, input_name, output_name, output_refs)
processed = skip_node_backward(ret_nodes, input_name, output_name,
output_refs)
else:
processed = -1
if processed > 0:
nodes_to_remove.append(node_idx)
logger.debug('skip %s: %s -> %s Cast op',
node.name, input_dtype, output_dtype)
logger.debug('skip %s: %s -> %s Cast op', node.name, input_dtype,
output_dtype)
elif processed == 0:
logger.warning('weird, no node processed')
else:
logger.debug('keep standalone %s: %s -> %s Cast op',
node.name, input_dtype, output_dtype)
logger.debug('keep standalone %s: %s -> %s Cast op', node.name,
input_dtype, output_dtype)
nodes_to_remove.sort(reverse=True)
for node_idx in nodes_to_remove:
......@@ -452,13 +461,14 @@ def optimize_model_slice(model):
chain = []
while True:
node = nodes[node_idx]
if not(node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
return chain
if not node.op_type == 'Slice':
return chain
chain.append(node_idx)
output_name = node.output[0]
if output_name not in input_refs or len(input_refs[output_name]) != 1:
if output_name not in input_refs or len(
input_refs[output_name]) != 1:
return chain
node_idx = list(input_refs[output_name])[0]
......@@ -468,7 +478,8 @@ def optimize_model_slice(model):
for slice_node_idx in slice_chain:
node = nodes[slice_node_idx]
attrs = node_attrs(node)
for axis, start, end in zip(attrs['axes'], attrs['starts'], attrs['ends']):
for axis, start, end in zip(attrs['axes'], attrs['starts'],
attrs['ends']):
if start == 0 and end == ONNX_INT_MAX:
continue
if axis in merged_slice:
......@@ -480,7 +491,8 @@ def optimize_model_slice(model):
ret = type(model)()
ret.CopyFrom(model)
ret.graph.ClearField('value_info') # WORKAROUND: onnx do not drop old value_info
ret.graph.ClearField(
'value_info') # WORKAROUND: onnx do not drop old value_info
ret_nodes = ret.graph.node
nodes_to_remove = []
for node_idx in range(len(nodes)):
......@@ -488,7 +500,7 @@ def optimize_model_slice(model):
if len(slice_chain) == 0:
continue
merged_slice = _merge_slice(slice_chain)
if len(merged_slice) > 0 and len(slice_chain) == 1: # no need to merge
if len(merged_slice) > 0 and len(slice_chain) == 1: # no need to merge
continue
attrs = dict(axes=[], starts=[], ends=[])
......@@ -501,42 +513,50 @@ def optimize_model_slice(model):
input_name = first_node.input[0]
output_name = last_node.output[0]
processed = -1
if output_name in input_refs: # 0, [1...]
new_input_name = first_node.output[0] if len(merged_slice) > 0 else input_name
processed = skip_node_forward(ret_nodes, output_name, new_input_name, input_refs)
if output_name in input_refs: # 0, [1...]
new_input_name = first_node.output[0] if len(
merged_slice) > 0 else input_name
processed = skip_node_forward(ret_nodes, output_name,
new_input_name, input_refs)
if processed > 0:
if len(merged_slice) > 0:
remain_idx = slice_chain[0]
remove_chain = slice_chain[1:]
slice_node = ret_nodes[remain_idx]
for attr in slice_node.attribute:
attr.CopyFrom(make_attribute(attr.name, attrs[attr.name]))
attr.CopyFrom(
make_attribute(attr.name, attrs[attr.name]))
logger.debug('merged slice chain %s -> %s%s -> %s',
input_name, remain_idx, remove_chain, output_name)
input_name, remain_idx, remove_chain,
output_name)
else:
remove_chain = slice_chain
if processed < 0 and input_name in output_refs:
new_output_name = last_node.input[0] if len(merged_slice) > 0 else output_name
processed = skip_node_backward(ret_nodes, input_name, new_output_name, output_refs)
new_output_name = last_node.input[0] if len(
merged_slice) > 0 else output_name
processed = skip_node_backward(ret_nodes, input_name,
new_output_name, output_refs)
if processed > 0:
if len(merged_slice) > 0:
remain_idx = slice_chain[-1]
remove_chain = slice_chain[:-1]
slice_node = ret_nodes[remain_idx]
for attr in slice_node.attribute:
attr.CopyFrom(make_attribute(attr.name, attrs[attr.name]))
attr.CopyFrom(
make_attribute(attr.name, attrs[attr.name]))
logger.debug('merged slice chain %s -> %s%s -> %s',
input_name, remove_chain, remain_idx, output_name)
input_name, remove_chain, remain_idx,
output_name)
else:
remove_chain = slice_chain
if processed > 0:
nodes_to_remove.extend(remove_chain)
if len(merged_slice) == 0:
logger.debug('skip slice chain %s -> %s -> %s',
input_name, slice_chain, output_name)
elif processed < 0: # NEVERFIX: not merge standalone slice chain
logger.debug('skip slice chain %s -> %s -> %s', input_name,
slice_chain, output_name)
elif processed < 0: # NEVERFIX: not merge standalone slice chain
logger.debug('keep standalone slice chain %s -> %s -> %s',
input_name, slice_chain, output_name)
......@@ -549,9 +569,10 @@ def optimize_model_slice(model):
if __name__ == '__main__':
logging.basicConfig(
format='[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
level=logging.DEBUG,
)
format=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
level=logging.DEBUG,
)
from onnx.checker import check_model
from onnx.utils import polish_model
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
ONNX to Paddle symbolic translation
ONNX to Paddle fluid symbolic translation
Created on Mon Feb 25 09:33:43 2019
......@@ -18,20 +18,23 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
_logger = _logging.getLogger(__name__)
ONNX_INT_MAX = 2 ** 63 - 1
ONNX_INT_MAX = 2**63 - 1
FLUID_INT_MAX = 2**31 - 1
DEFAULT_ONNX_OP_DOMAIN = ''
DEFAULT_PADDLE_OP_NAMESCOPE = '/'
DEFAULT_FLUID_OP_NAMESCOPE = '/'
DEFAULT_OP_MAPPING_FIELD_VALUES = _dict()
DEFAULT_OP_MAPPING_FIELD_VALUES['PADDLE_OP'] = ''
DEFAULT_OP_MAPPING_FIELD_VALUES['PADDLE_INPUT_ARGS'] = None
DEFAULT_OP_MAPPING_FIELD_VALUES['PADDLE_OUTPUT_ARGS'] = None
DEFAULT_OP_MAPPING_FIELD_VALUES['ATTR_MAPPING'] = dict() # dict(onnx_attr_from=paddle_attr_to)
DEFAULT_OP_MAPPING_FIELD_VALUES['DEFAULTS'] = dict() # dict(paddle_attr=default)
DEFAULT_OP_MAPPING_FIELD_VALUES['INPUT_PERM'] = None # sampler: [idx_onnx_arg...]
DEFAULT_OP_MAPPING_FIELD_VALUES['OUTPUT_PERM'] = None # sampler: [idx_onnx_arg...]
DEFAULT_OP_MAPPING_FIELD_VALUES['FLUID_OP'] = ''
DEFAULT_OP_MAPPING_FIELD_VALUES['FLUID_INPUT_ARGS'] = None
DEFAULT_OP_MAPPING_FIELD_VALUES['FLUID_OUTPUT_ARGS'] = None
DEFAULT_OP_MAPPING_FIELD_VALUES['ATTR_MAPPING'] = dict(
) # dict(onnx_attr_from=fluid_attr_to)
DEFAULT_OP_MAPPING_FIELD_VALUES['DEFAULTS'] = dict() # dict(fluid_attr=default)
DEFAULT_OP_MAPPING_FIELD_VALUES[
'INPUT_PERM'] = None # sampler: [idx_onnx_arg...]
DEFAULT_OP_MAPPING_FIELD_VALUES[
'OUTPUT_PERM'] = None # sampler: [idx_onnx_arg...]
DEFAULT_OP_MAPPING_FIELD_VALUES['FILL_NAME_FIELD'] = True
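# How a DEFAULT_OP_MAPPING entry below unpacks against these fields, using
# 'Selu' as the example (shorter entries inherit the trailing defaults):
#
#   'Selu': ['selu', ['X'], ['Out'], dict(gamma='scale')]
#   unpacks to FLUID_OP='selu', FLUID_INPUT_ARGS=['X'],
#   FLUID_OUTPUT_ARGS=['Out'], ATTR_MAPPING={'gamma': 'scale'}, DEFAULTS={},
#   INPUT_PERM=None, OUTPUT_PERM=None, FILL_NAME_FIELD=True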
DEFAULT_OP_MAPPING = {
......@@ -60,7 +63,7 @@ DEFAULT_OP_MAPPING = {
'Reciprocal': ['reciprocal', ['X'], ['Out']],
'Relu': ['relu', ['X'], ['Out']],
'Selu': ['selu', ['X'], ['Out'], dict(gamma='scale')],
'Shape': ['shape', ['X'], ['Out']], # FIXME: out is int64 - int32
'Shape': ['shape', ['X'], ['Out']], # FIXME: out is int64 vs int32
'Shrink': ['softshrink', ['X'], ['Out'], dict(bias='', lambd='')],
'Sigmoid': ['sigmoid', ['X'], ['Out']],
'Sin': ['sin', ['X'], ['Out']],
......@@ -74,25 +77,24 @@ DEFAULT_OP_MAPPING = {
'Transpose': ['transpose', ['X'], ['Out']], # FIXME: emit transpose2
'Unsqueeze': ['unsqueeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit unsqueeze2
## binary ops ##
# FIXME: axis=-1 in Paddle is broken; handle it in the specialization
'Add': ['elementwise_add', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],
# 'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')],
'Add': ['elementwise_add', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
# 'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')],
'And': ['logical_and', ['X', 'Y'], ['Out']],
'Div': ['elementwise_div', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],
'Div': ['elementwise_div', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Equal': ['equal', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False],
'Greater': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), [1, 0], None, False],  # Greater(X, Y) = less_than(Y, X)
'Less': ['less_than', ['X', 'Y'], ['Out'], dict(), dict(), None, None, False],
'MatMul': ['matmul', ['X', 'Y'], ['Out']], # defaults excluded for transpose_x - transpose_X
'Max': ['elementwise_max', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],
'Min': ['elementwise_min', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],
'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],
'MatMul': ['matmul', ['X', 'Y'], ['Out']], # defaults excluded for transpose_x vs transpose_X
'Max': ['elementwise_max', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Min': ['elementwise_min', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Not': ['logical_not', ['X'], ['Out']],  # unary
'OneHot': # assuming values=[0, 1], axis=-1 and drop them
['one_hot', ['Input', 'Depth'], ['Out'], dict(axis=''), dict(),
[0, 1], None, False],
'Or': ['logical_or', ['X', 'Y'], ['Out']],
'Pow': ['elementwise_pow', ['X', 'Y'], ['Out'], dict(), dict(axis=0)], # TODO: pow for scalar exponent
'Sub': ['elementwise_sub', ['X', 'Y'], ['Out'], dict(), dict(axis=0)],
'Pow': ['elementwise_pow', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], # TODO: pow for scalar exponent
'Sub': ['elementwise_sub', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Xor': ['logical_xor', ['X', 'Y'], ['Out']],
# reduce ops
'ReduceMax': ['reduce_max', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')],
......@@ -106,30 +108,35 @@ DEFAULT_OP_MAPPING = {
}
DEFAULT_IOA_CONSTRAINT = {
'ArgMax':
[(lambda i, o, a: a.get('keepdims', 1) == 1, 'only keepdims = 0 is supported'),
],
'ArgMin':
[(lambda i, o, a: a.get('keepdims', 1) == 1, 'only keepdims = 0 is supported'),
],
'Gather':
[(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported'),
],
'Shrink':
[(lambda i, o, a: a.get('bias', 0) == a.get('lambd', 0.5), 'only SoftShrink with bias = lambd is supported'),
],
# 'Softmax':
# [(lambda i, o, a: a.get('axis', 1) == -2, 'Paddle Softmax works on dim -2 only'),
# ],
'OneHot':
[(lambda i, o, a: a.get('axis', -1) == -1, 'only axis = -1 is supported'),
],
'Scatter':
[(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported'),
],
'TopK':
[(lambda i, o, a: a.get('axis', -1) == -1, 'only axis = -1 is supported'),
],
'ArgMax': [
(lambda i, o, a: a.get('keepdims', 1) == 1,
'only keepdims = 1 is supported'),
],
'ArgMin': [
(lambda i, o, a: a.get('keepdims', 1) == 1,
'only keepdims = 1 is supported'),
],
'Gather': [
(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported'),
],
'Shrink': [
(lambda i, o, a: a.get('bias', 0) == a.get('lambd', 0.5),
'only SoftShrink with bias = lambd is supported'),
],
# 'Softmax':
# [(lambda i, o, a: a.get('axis', 1) == -2, 'Paddle fluid Softmax works on dim -2 only'),
# ],
'OneHot': [
(lambda i, o, a: a.get('axis', -1) == -1,
'only axis = -1 is supported'),
],
'Scatter': [
(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 is supported'),
],
'TopK': [
(lambda i, o, a: a.get('axis', -1) == -1,
'only axis = -1 is supported'),
],
}
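# Each (predicate, message) pair above is evaluated in _default() against
# the ONNX op type via
#   assert predicate(inputs, outputs, attrs), message
# e.g. a Gather node carrying attrs == dict(axis=1) trips the assertion,
# since only axis = 0 is translatable here.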
......@@ -142,7 +149,7 @@ def _make_var_name(name):
return '_'
if name[0].isdigit():
return 'var_' + name
for s in ' *?\/-:':
for s in ' *?\\/-:':
name = name.replace(s, '_')
if name.startswith('_'):
name = 'var' + name
......@@ -188,89 +195,98 @@ def _shape_or_none(value_infos, val_name):
# return value_info.get('const_value', var_name)
def _default(prog, op_type, inputs, outputs, attrs,
*args,
name='',
**kwargs):
def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
info = DEFAULT_OP_MAPPING[op_type]
info.extend(list(DEFAULT_OP_MAPPING_FIELD_VALUES.values())[len(info):])
(paddle_op,
paddle_input_args, paddle_output_args,
attr_mapping, default_attrs,
input_perm, output_perm,
fill_name_field,
) = info
if paddle_op in DEFAULT_IOA_CONSTRAINT:
for predicate, message in DEFAULT_IOA_CONSTRAINT[paddle_op]:
assert predicate(inputs, outputs, attrs), message
(
fluid_op,
fluid_input_args,
fluid_output_args,
attr_mapping,
default_attrs,
input_perm,
output_perm,
fill_name_field,
) = info
if op_type in DEFAULT_IOA_CONSTRAINT:  # constraints are keyed by ONNX op type
for predicate, message in DEFAULT_IOA_CONSTRAINT[op_type]:
assert predicate(inputs, outputs, attrs), message
# bypass if key absent, drop if mapped key is '' or '_'
mapped_attrs = {attr_mapping.get(key, key): value for key, value in attrs.items()}
mapped_attrs = {
attr_mapping.get(key, key): value
for key, value in attrs.items()
}
if '' in mapped_attrs:
mapped_attrs.pop('')
if '_' in mapped_attrs:
mapped_attrs.pop('_')
paddle_attrs = default_attrs.copy()
paddle_attrs.update(mapped_attrs) # as new attrs
fluid_attrs = default_attrs.copy()
fluid_attrs.update(mapped_attrs) # as new attrs
val_inps = inputs if input_perm is None else map(lambda i: inputs[i], input_perm)
val_outs = outputs if output_perm is None else map(lambda i: outputs[i], output_perm)
val_inps = inputs if input_perm is None else map(lambda i: inputs[i],
input_perm)
val_outs = outputs if output_perm is None else map(lambda i: outputs[i],
output_perm)
var_inps = [_make_var_name(val) for val in val_inps]
var_outs = [_make_var_name(val) for val in val_outs]
arg_name = ', name={}'.format(repr(name)) if fill_name_field and name else ''
arg_attrs = [', {}={}'.format(key, value) for key, value in paddle_attrs.items()]
prog.Code('{} = layers.{}({}{}{})'
.format(', '.join(var_outs),
paddle_op,
', '.join(var_inps),
''.join(arg_attrs),
arg_name,
))
for val_out, var_out in zip(val_outs, var_outs):
arg_name = ', name={}'.format(
repr(name)) if fill_name_field and name else ''
arg_attrs = [
', {}={}'.format(key, value) for key, value in fluid_attrs.items()
]
prog.Code('{} = layers.{}({}{}{})'.format(
', '.join(var_outs),
fluid_op,
', '.join(var_inps),
''.join(arg_attrs),
arg_name,
))
for var_out in var_outs:
prog.VarDesc(var_out)
# dummy var_out
num_vars = len(var_outs)
num_args = len(paddle_output_args)
num_args = len(fluid_output_args)
if num_vars < num_args:
assert fill_name_field, 'name required for naming dummy output variables'
for idx_out in range(num_vars, num_args):
var_out = _make_var_name(name + '.' + paddle_output_args[idx_out].lower())
var_out = _make_var_name(name + '.' +
fluid_output_args[idx_out].lower())
var_outs.append(var_out)
prog.VarDesc(var_out)
prog.OpDesc(paddle_op,
(var_inps, *paddle_input_args),
(var_outs, *paddle_output_args),
paddle_attrs)
prog.OpDesc(fluid_op, (var_inps, *fluid_input_args),
(var_outs, *fluid_output_args), fluid_attrs)
def _assign(prog, attrs):
mapping = attrs['mapping'] # additional
paddle_op = 'assign'
mapping = attrs['mapping'] # additional
fluid_op = 'assign'
for val_dst, val_src in mapping.items():
var_dst = _make_var_name(val_dst)
var_src = _make_var_name(val_src)
prog.Code('{} = {}'.format(var_dst, var_src))
# prog.Code('{} = layers.{}({})'
# .format(var_dst,
# paddle_op,
# var_src,
# ))
# prog.Code('{} = layers.{}({})'
# .format(var_dst,
# fluid_op,
# var_src,
# ))
prog.VarDesc(var_dst)
prog.OpDesc(paddle_op,
([var_src], 'X'),
([var_dst], 'Out'),
dict(),
)
prog.OpDesc(
fluid_op,
([var_src], 'X'),
([var_dst], 'Out'),
dict(),
)
def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
assert len(pads) & 1 == 0
ndims = len(pads) // 2
symmetric = True
......@@ -283,19 +299,28 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
return pads[:ndims], None
val_padded = val_name + '_padded'
prog.Op('', 'Pad',
[val_name],
[val_padded], # val
dict(mode='constant',
value=0.,
pads=pads,
),
value_infos=value_infos,
name=val_padded,
)
prog.Op(
'',
'Pad',
[val_name],
[val_padded], # val
dict(
mode='constant',
value=0.,
pads=pads,
),
value_infos=value_infos,
name=val_padded,
)
return [0] * ndims, val_padded
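# Worked example (ONNX pads use the SSEE layout: all starts, then all ends):
#   pads=[1, 1, 1, 1] is symmetric  -> returns ([1, 1], None)
#   pads=[0, 0, 1, 1] is asymmetric -> emits an explicit Pad op and
#                                      returns ([0, 0], val_name + '_padded')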
def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
def _adaptive_pool(prog,
pool_type,
inputs,
outputs,
attrs,
value_infos,
name=''):
# I/O
val_x, = inputs
......@@ -309,14 +334,15 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
var_indices = _make_var_name(val_indices)
# interpretation
pool_size = attrs['output_size'] # required
pool_size = attrs['output_size'] # required
output_shape = _shape_or_none(value_infos, val_y)
if output_shape is not None:
assert pool_size == output_shape[2:], 'pool_size unmatches shape of Y' # NC...
assert pool_size == output_shape[
2:], 'pool_size does not match shape of Y' # NC...
poolnd = len(pool_size)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported'
paddle_op = 'adaptive_pool{}d'.format(poolnd)
fluid_op = 'adaptive_pool{}d'.format(poolnd)
name_attr = ', name={}'.format(repr(name)) if name else ''
# generation
......@@ -324,35 +350,37 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
', require_index={}'
', pool_size={}'
', pool_type={}'
'{})'
.format(var_y, ', {}'.format(var_indices) if has_indices else '',
paddle_op,
var_x,
# attrs
has_indices,
pool_size,
repr(pool_type),
name_attr,
))
paddle_op = 'pool{}d'.format(poolnd)
'{})'.format(
var_y,
', {}'.format(var_indices) if has_indices else '',
fluid_op,
var_x,
# attrs
has_indices,
pool_size,
repr(pool_type),
name_attr,
))
fluid_op = 'pool{}d'.format(poolnd)
prog.VarDesc(var_y)
if has_indices:
prog.VarDesc(var_indices)
prog.OpDesc(paddle_op,
([var_x], 'X'),
([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'),
dict(global_pooling=False,
adaptive=True,
exclusive=True,
require_index=has_indices,
pooling_type=pool_type,
ksize=pool_size,
),
)
def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
name=''):
prog.OpDesc(
fluid_op,
([var_x], 'X'),
([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'),
dict(
global_pooling=False,
adaptive=True,
exclusive=True,
require_index=has_indices,
pooling_type=pool_type,
ksize=pool_size,
),
)
def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
# I/O
val_x, = inputs
val_y, = outputs
......@@ -362,40 +390,41 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
# interpretation
input_shape = _shape_or_none(value_infos, val_x)
output_shape = _shape_or_none(value_infos, val_y)
assert input_shape is not None or output_shape is not None, 'poolnd not inferred' # NC...
assert input_shape is not None or output_shape is not None, 'poolnd not inferred' # NC...
if input_shape:
poolnd = len(input_shape) - 2 # NC...
poolnd = len(input_shape) - 2 # NC...
elif output_shape:
poolnd = len(output_shape) - 2 # NC...
poolnd = len(output_shape) - 2 # NC...
assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported'
paddle_op = 'pool{}d'.format(poolnd)
fluid_op = 'pool{}d'.format(poolnd)
name_attr = ', name={}'.format(repr(name)) if name else ''
# generation
prog.Code('{} = layers.{}({}, global_pooling=True'
', pool_type={}'
'{})'
.format(var_y,
paddle_op,
var_x,
# attrs
repr(pool_type),
name_attr,
))
'{})'.format(
var_y,
fluid_op,
var_x,
# attrs
repr(pool_type),
name_attr,
))
prog.VarDesc(var_y)
prog.OpDesc(paddle_op,
([var_x], 'X'),
([var_y], 'Out'),
dict(global_pooling=True,
adaptive=False,
pooling_type=pool_type,
),
)
def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
name=''):
prog.OpDesc(
fluid_op,
([var_x], 'X'),
([var_y], 'Out'),
dict(
global_pooling=True,
adaptive=False,
pooling_type=pool_type,
),
)
def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
# I/O
val_x, = inputs
val_y, = outputs[:1]
......@@ -407,18 +436,20 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
var_indices = _make_var_name(val_indices)
# interpretation
assert attrs.get('auto_pad', 'NOTSET') == 'NOTSET', 'only auto_pad = NOTSET supported' # optional
pool_size = attrs['kernel_shape'] # required
assert attrs.get(
'auto_pad',
'NOTSET') == 'NOTSET', 'only auto_pad = NOTSET supported' # optional
pool_size = attrs['kernel_shape'] # required
poolnd = len(pool_size)
assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported'
paddle_op = 'pool{}d'.format(poolnd)
strides = attrs.get('strides', [1] * poolnd) # optional
pads = attrs.get('pads', [0] * len(pool_size * 2)) # optional
fluid_op = 'pool{}d'.format(poolnd)
strides = attrs.get('strides', [1] * poolnd) # optional
pads = attrs.get('pads', [0] * len(pool_size * 2)) # optional
paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
if val_x_padded:
val_x = val_x_padded
ceil_mode = bool(attrs.get('ceil_mode', 0)) # optional
ceil_mode = bool(attrs.get('ceil_mode', 0)) # optional
var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else ''
......@@ -429,38 +460,41 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
', pool_stride={}'
', pool_padding={}'
', ceil_mode={}'
'{})'
.format(var_y, ', {}'.format(var_indices) if has_indices else '',
paddle_op,
var_x,
# attrs
pool_size,
repr(pool_type),
strides,
paddings,
ceil_mode,
name_attr,
))
'{})'.format(
var_y,
', {}'.format(var_indices) if has_indices else '',
fluid_op,
var_x,
# attrs
pool_size,
repr(pool_type),
strides,
paddings,
ceil_mode,
name_attr,
))
prog.VarDesc(var_y)
if has_indices:
prog.VarDesc(var_indices)
prog.OpDesc(paddle_op,
([var_x], 'X'),
([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'),
dict(global_pooling=False,
adaptive=False,
exclusive=True,
require_index=has_indices,
pooling_type=pool_type,
ksize=pool_size,
strides=strides,
pool_padding=paddings,
ceil_mode=ceil_mode,
),
)
def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
prog.OpDesc(
fluid_op,
([var_x], 'X'),
([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'),
dict(
global_pooling=False,
adaptive=False,
exclusive=True,
require_index=has_indices,
pooling_type=pool_type,
ksize=pool_size,
strides=strides,
paddings=paddings,
ceil_mode=ceil_mode,
),
)
def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
# I/O
val_x, val_rois = inputs
val_y, = outputs
......@@ -469,15 +503,15 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
var_y = _make_var_name(val_y)
# interpretation
spatial_scale=attrs['spatial_scale'] # required
pooled_height, pooled_width = attrs['pooled_shape'] # required
spatial_scale = attrs['spatial_scale'] # required
pooled_height, pooled_width = attrs['pooled_shape'] # required
od_attrs = dict(
spatial_scale=spatial_scale,
pooled_height=pooled_height,
pooled_width=pooled_width,
spatial_scale=spatial_scale,
pooled_height=pooled_height,
pooled_width=pooled_width,
)
feature_attr = ''
is_max_pool = paddle_op == 'roi_pool'
is_max_pool = fluid_op == 'roi_pool'
if 'sampling_ratio' in attrs:
sampling_ratio = attrs['sampling_ratio']
od_attrs['sampling_ratio'] = sampling_ratio
......@@ -492,77 +526,88 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
', spatial_scale={}'
', pooled_height={}'
', pooled_width={}'
'{})'
.format(var_y,
paddle_op,
val_x, var_rois,
# attrs
spatial_scale,
pooled_height,
pooled_width,
feature_attr,
))
'{})'.format(
var_y,
fluid_op,
var_x,
var_rois,
# attrs
spatial_scale,
pooled_height,
pooled_width,
feature_attr,
))
prog.VarDesc(var_y)
if is_max_pool:
var_argmax = _make_var_name(name + '.argmax') # implicit variable
var_argmax = _make_var_name(name + '.argmax') # implicit variable
prog.VarDesc(var_argmax)
prog.OpDesc(paddle_op,
([var_x, var_rois], 'X', 'Rois'),
([var_y] + ([var_argmax] if is_max_pool else []), 'Out', 'Argmax'),
od_attrs,
)
prog.OpDesc(
fluid_op,
([var_x, var_rois], 'X', 'Rois'),
([var_y] + ([var_argmax] if is_max_pool else []), 'Out', 'Argmax'),
od_attrs,
)
def _zeros_like(prog, val_ref, val_out, value_infos):
prog.Op('', 'Sub',
[val_ref, val_ref],
[val_out], # val
dict(axis=0),
value_infos,
)
prog.Op(
'',
'Sub',
[val_ref, val_ref],
[val_out], # val
dict(axis=0),
value_infos,
)
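# Note: Sub(x, x) yields zeros with the shape and dtype of the reference
# value, avoiding a dedicated fill op in the emitted program.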
def AdaptiveAveragePool(
prog, inputs, outputs, attrs, value_infos,
name='',
*args, **kwargs):
def AdaptiveAveragePool(prog,
inputs,
outputs,
attrs,
value_infos,
name='',
*args,
**kwargs):
"""
aten::adaptive_avg_poolnd
"""
return _adaptive_pool(prog, 'avg', inputs, outputs, attrs, value_infos,
name=name)
return _adaptive_pool(
prog, 'avg', inputs, outputs, attrs, value_infos, name=name)
def AdaptiveMaxPool(
prog, inputs, outputs, attrs, value_infos,
name='',
*args, **kwargs):
def AdaptiveMaxPool(prog,
inputs,
outputs,
attrs,
value_infos,
name='',
*args,
**kwargs):
"""
aten::adaptive_max_poolnd
"""
return _adaptive_pool(prog, 'max', inputs, outputs, attrs, value_infos,
name=name)
return _adaptive_pool(
prog, 'max', inputs, outputs, attrs, value_infos, name=name)
def AveragePool(
prog, inputs, outputs, attrs, value_infos,
name='',
*args, **kwargs):
def AveragePool(prog,
inputs,
outputs,
attrs,
value_infos,
name='',
*args,
**kwargs):
"""
onnx::AveragePool-10:
"""
return _pool(prog, 'avg', inputs, outputs, attrs, value_infos,
name=name)
return _pool(prog, 'avg', inputs, outputs, attrs, value_infos, name=name)
def AffineGrid(
prog, inputs, outputs, attrs,
*args,
name='',
**kwargs):
def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
aten::affine_grid
"""
......@@ -574,33 +619,39 @@ def AffineGrid(
var_grid = _make_var_name(val_grid)
# interpretation
paddle_op = 'affine_grid'
size = attrs['size'] # required
fluid_op = 'affine_grid'
size = attrs['size'] # required
name_attr = ', name={}'.format(repr(name)) if name else ''
# generation
prog.Code('{} = layers.{}({}'
', out_shape={}'
'{})'
.format(var_grid,
paddle_op,
var_theta,
# attrs
size,
name_attr,
))
'{})'.format(
var_grid,
fluid_op,
var_theta,
# attrs
size,
name_attr,
))
prog.VarDesc(var_grid)
prog.OpDesc(paddle_op,
([var_theta], 'Theta'),
([var_grid], 'Output'),
dict(output_shape=size), # f**k you API
)
prog.OpDesc(
fluid_op,
([var_theta], 'Theta'),
([var_grid], 'Output'),
dict(output_shape=size), # f**k you API
)
def BatchNormalization(
prog, inputs, outputs, attrs, value_infos,
name='', embed_params=False,
*args, **kwargs):
def BatchNormalization(prog,
inputs,
outputs,
attrs,
value_infos,
name='',
embed_params=False,
*args,
**kwargs):
"""
onnx::BatchNormalization-9:
"""
......@@ -612,9 +663,9 @@ def BatchNormalization(
var_y = _make_var_name(val_y)
# interpretation
paddle_op = 'batch_norm'
momentum = attrs.get('momentum', .9) # optional
epsilon = attrs.get('epsilon', 1e-5) # optional
fluid_op = 'batch_norm'
momentum = attrs.get('momentum', .9) # optional
epsilon = attrs.get('epsilon', 1e-5) # optional
name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
assert name != ''
......@@ -633,43 +684,45 @@ def BatchNormalization(
var_mean = _make_var_name(val_mean)
var_var = _make_var_name(val_var)
param_attr = (', param_attr={}, bias_attr={}'
', moving_mean_name={}, moving_variance_name={}'
).format(repr(var_scale), repr(var_b), repr(var_mean), repr(var_var))
var_saved_mean = '{}.saved_mean'.format(name) # dropped var
var_saved_variance = '{}.saved_variance'.format(name) # dropped var
', moving_mean_name={}, moving_variance_name={}').format(
repr(var_scale), repr(var_b), repr(var_mean),
repr(var_var))
var_saved_mean = '{}.saved_mean'.format(name) # dropped var
var_saved_variance = '{}.saved_variance'.format(name) # dropped var
# generation
prog.Code('{} = layers.{}({}, is_test=True, data_layout="NCHW"'
', momentum={}'
', epsilon={}'
'{}{})'
.format(var_y,
paddle_op,
var_x,
# attrs
momentum,
epsilon,
param_attr, name_attr,
))
'{}{})'.format(
var_y,
fluid_op,
var_x,
# attrs
momentum,
epsilon,
param_attr,
name_attr,
))
prog.VarDesc(var_y)
prog.VarDesc(var_saved_mean)
prog.VarDesc(var_saved_variance)
prog.OpDesc(paddle_op,
([var_x, var_scale, var_b, var_mean, var_var],
'X', 'Scale', 'Bias', 'Mean', 'Variance'),
([var_y, var_mean, var_saved_mean, var_saved_variance, var_var],
'Y', 'MeanOut', 'SavedMean', 'SavedVariance', 'VarianceOut'),
dict(is_test=1,
data_layout='NCHW',
use_global_stats=False,
momentum=momentum,
epsilon=epsilon),
)
def Cast(
prog, inputs, outputs, attrs, value_infos,
*args, **kwargs):
prog.OpDesc(
fluid_op,
([var_x, var_scale, var_b, var_mean, var_var], 'X', 'Scale', 'Bias',
'Mean', 'Variance'),
([var_y, var_mean, var_saved_mean, var_saved_variance, var_var], 'Y',
'MeanOut', 'SavedMean', 'SavedVariance', 'VarianceOut'),
dict(
is_test=1,
data_layout='NCHW',
use_global_stats=False,
momentum=momentum,
epsilon=epsilon),
)
def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
"""
onnx::Cast-9:
"""
......@@ -683,38 +736,36 @@ def Cast(
# interpretation
dtype = attrs['to']
if not isinstance(dtype, np.dtype):
dtype = TENSOR_TYPE_TO_NP_TYPE[dtype] # required
dtype = TENSOR_TYPE_TO_NP_TYPE[dtype] # required
output_dtype = _dtype_or_none(value_infos, val_output)
if output_dtype:
assert dtype == output_dtype, "dtype of 'to' does not match output dtype"
paddle_op = 'cast'
fluid_op = 'cast'
# generation
prog.Code('{} = layers.{}({}'
', dtype={}'
')'
.format(var_output,
paddle_op,
var_input,
# attrs
repr(dtype.name),
))
')'.format(
var_output,
fluid_op,
var_input,
# attrs
repr(dtype.name),
))
prog.VarDesc(var_output)
prog.OpDesc(paddle_op,
([var_input], 'X'),
([var_output], 'Out'),
dict(in_dtype=prog.Dtype(_dtype(value_infos, val_input)), # holy, required
out_dtype=prog.Dtype(dtype),
)
)
def Concat(
prog, inputs, outputs, attrs,
*args,
name='',
**kwargs):
prog.OpDesc(
fluid_op,
([var_input], 'X'),
([var_output], 'Out'),
dict(
in_dtype=prog.Dtype(_dtype(value_infos,
val_input)), # holy, required
out_dtype=prog.Dtype(dtype),
))
def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
onnx::Concat-4:
"""
......@@ -725,32 +776,31 @@ def Concat(
var_concat_result = _make_var_name(val_concat_result)
# interpretation
paddle_op = 'concat'
axis = attrs['axis'] # required
fluid_op = 'concat'
axis = attrs['axis'] # required
name_attr = ', name={}'.format(repr(name)) if name else ''
# generation
prog.Code('{} = layers.{}({}'
', axis={}'
'{})'
.format(var_concat_result,
paddle_op,
'[' + ', '.join(var_inps) + ']',
# attrs
axis,
name_attr,
))
'{})'.format(
var_concat_result,
fluid_op,
'[' + ', '.join(var_inps) + ']',
# attrs
axis,
name_attr,
))
prog.VarDesc(var_concat_result)
prog.OpDesc(paddle_op,
(var_inps, *(['X'] * len(var_inps))),
([var_concat_result], 'Out'),
dict(axis=axis),
)
prog.OpDesc(
fluid_op,
(var_inps, *(['X'] * len(var_inps))),
([var_concat_result], 'Out'),
dict(axis=axis),
)
def Constant(
prog, inputs, outputs, attrs, value_infos,
*args, **kwargs):
def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
"""
onnx::Constant-9:
"""
......@@ -761,41 +811,45 @@ def Constant(
var_output = _make_var_name(val_output)
# interpretation
value = attrs['value'] # required
value = attrs['value'] # required
dtype = np.dtype(value.dtype)
output_dtype = _dtype_or_none(value_infos, val_output)
if output_dtype:
assert dtype == output_dtype, 'tensor dtype does not match storage dtype'
# dtype = np.dtype('float32') # force to float32
shape = attrs.get('shape', None) # additional, maybe var_name
shape = attrs.get('shape', None) # additional, maybe var_name
if shape is None:
shape = _shape_or_none(value_infos, val_output)
if shape is None:
shape = list(value.shape)
_logger.warning('shape of %s not inferred, using value as 1-D tensor may lead to fails', val_output)
_logger.warning(
'shape of %s not inferred, using value as 1-D tensor may lead to failures',
val_output)
# generation
if value.size == 1: # scalar
paddle_op = 'fill_constant'
prog.Code('{} = layers.{}(shape={}, dtype={}, value={})'
.format(var_output,
paddle_op,
# attrs
shape, repr(dtype.name), value[0], # shape can be list or var_name
))
if value.size == 1: # scalar
fluid_op = 'fill_constant'
prog.Code('{} = layers.{}(shape={}, dtype={}, value={})'.format(
var_output,
fluid_op,
# attrs
shape,
repr(dtype.name),
value[0], # shape can be list or var_name
))
value_infos[val_output]['const_value'] = value[0]
prog.VarDesc(var_output)
else: # list parameter -> const_value
prog.Code('{} = {}'
.format(var_output,
value.tolist(),
))
else: # list parameter -> const_value
prog.Code('{} = {}'.format(
var_output,
value.tolist(),
))
value_infos[val_output]['const_value'] = value.tolist()
def ConstantOfShape(
prog, inputs, outputs, attrs, value_infos,
*args, **kwargs):
def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
"""
onnx::ConstantOfShape-9:
"""
......@@ -810,15 +864,20 @@ def ConstantOfShape(
shape = value_infos[val_input]['get_weight']()
dtype = attrs['value'].dtype
attrs = attrs.copy()
attrs.update(dict(shape=shape, dtype=dtype)) # pass var_name
attrs.update(dict(shape=shape, dtype=dtype)) # pass var_name
Constant(prog, [], outputs, attrs, value_infos)
def Conv(
prog, inputs, outputs, attrs, value_infos,
name='', embed_params=False,
*args, **kwargs):
def Conv(prog,
inputs,
outputs,
attrs,
value_infos,
name='',
embed_params=False,
*args,
**kwargs):
"""
onnx::Conv-1:
"""
......@@ -833,21 +892,24 @@ def Conv(
val_b, = inputs[2:]
# interpretation
assert attrs.get('auto_pad', 'NOTSET') == 'NOTSET', 'only auto_pad == NOTSET supported' # optional
kernel_shape = _shape(value_infos, val_w)[2:] # OI...
assert kernel_shape == attrs['kernel_shape'], 'kernel_shape in attr unmatches value_info' # HW
assert attrs.get(
'auto_pad',
'NOTSET') == 'NOTSET', 'only auto_pad == NOTSET supported' # optional
kernel_shape = _shape(value_infos, val_w)[2:] # OI...
assert kernel_shape == attrs[
'kernel_shape'], 'kernel_shape in attrs does not match value_info' # HW
convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d and conv3d supported'
num_out_channels = _shape(value_infos, val_w)[0] # OI...
num_out_channels = _shape(value_infos, val_w)[0] # OI...
paddle_op = 'conv{}d'.format(convnd)
strides = attrs.get('strides', [1] * convnd) # optional
pads = attrs.get('pads', [0] * convnd * 2) # optional
fluid_op = 'conv{}d'.format(convnd)
strides = attrs.get('strides', [1] * convnd) # optional
pads = attrs.get('pads', [0] * convnd * 2) # optional
paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
if val_x_padded:
val_x = val_x_padded
dilations = attrs.get('dilations', [1] * convnd) # optional
num_groups = attrs.get('group', 1) # optional
dilations = attrs.get('dilations', [1] * convnd) # optional
num_groups = attrs.get('group', 1) # optional
var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
......@@ -864,7 +926,8 @@ def Conv(
var_w = _make_var_name(val_w)
var_b = _make_var_name(val_b) if has_bias else False
param_attr = ', param_attr={}, bias_attr={}'.format(
repr(var_w), repr(var_b) if var_b else False)
repr(var_w),
repr(var_b) if var_b else False)
# generation
prog.Code('{} = layers.{}({}'
......@@ -874,46 +937,55 @@ def Conv(
', padding={}'
', dilation={}'
', groups={}'
'{}{})'
.format(var_y,
paddle_op,
var_x,
# attrs
num_out_channels,
kernel_shape,
strides,
paddings,
dilations,
num_groups,
param_attr, name_attr,
))
var_conv = _make_var_name(name + '.conv') # hidden variable
prog.OpDesc(paddle_op,
([var_x, var_w], 'Input', 'Filter'), # , 'Bias', 'ResidualData'
([var_conv if has_bias else var_y], 'Output'),
dict(strides=strides,
paddings=paddings,
dilations=dilations,
groups=num_groups,
))
'{}{})'.format(
var_y,
fluid_op,
var_x,
# attrs
num_out_channels,
kernel_shape,
strides,
paddings,
dilations,
num_groups,
param_attr,
name_attr,
))
var_conv = _make_var_name(name + '.conv') # hidden variable
prog.OpDesc(
fluid_op,
([var_x, var_w], 'Input', 'Filter'), # , 'Bias', 'ResidualData'
([var_conv if has_bias else var_y], 'Output'),
dict(
strides=strides,
paddings=paddings,
dilations=dilations,
groups=num_groups,
))
if has_bias:
prog.VarDesc(var_conv)
prog.IntermediateOp(
'', 'Add',
[var_conv, var_b],
[var_y], # var
dict(axis=1),
value_infos=value_infos,
name=(name + '.bias'),
)
'',
'Add',
[var_conv, var_b],
[var_y], # var
dict(axis=1),
value_infos=value_infos,
name=(name + '.bias'),
)
else:
prog.VarDesc(var_y)
def ConvTranspose(
prog, inputs, outputs, attrs, value_infos,
name='', embed_params=False,
*args, **kwargs):
def ConvTranspose(prog,
inputs,
outputs,
attrs,
value_infos,
name='',
embed_params=False,
*args,
**kwargs):
"""
onnx::ConvTranspose-1:
"""
......@@ -928,22 +1000,27 @@ def ConvTranspose(
val_b, = inputs[2:]
# interpretation
assert attrs.get('auto_pad', 'NOTSET') == 'NOTSET', 'only auto_pad == NOTSET supported' # optional
assert sum(attrs.get('output_padding', [])) == 0, 'only zero output_padding supported' # optional ?
kernel_shape = _shape(value_infos, val_w)[2:] # IO...
assert kernel_shape == attrs['kernel_shape'], 'kernel_shape in attr unmatches value_info' # HW
assert attrs.get(
'auto_pad',
'NOTSET') == 'NOTSET', 'only auto_pad == NOTSET supported' # optional
assert sum(
attrs.get('output_padding',
[])) == 0, 'only zero output_padding supported' # optional ?
kernel_shape = _shape(value_infos, val_w)[2:] # IO...
assert kernel_shape == attrs[
'kernel_shape'], 'kernel_shape in attrs does not match value_info' # HW
convnd = len(kernel_shape)
assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
num_out_channels = _shape(value_infos, val_w)[1] # IO...
num_out_channels = _shape(value_infos, val_w)[1] # IO...
paddle_op = 'conv{}d_transpose'.format(convnd)
strides = attrs.get('strides', [1] * convnd) # optional
pads = attrs.get('pads', [0] * convnd * 2) # optional
fluid_op = 'conv{}d_transpose'.format(convnd)
strides = attrs.get('strides', [1] * convnd) # optional
pads = attrs.get('pads', [0] * convnd * 2) # optional
paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
if val_x_padded:
val_x = val_x_padded
dilations = attrs.get('dilations', [1] * convnd) # optional
num_groups = attrs.get('group', 1) # optional
dilations = attrs.get('dilations', [1] * convnd) # optional
num_groups = attrs.get('group', 1) # optional
var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
......@@ -960,50 +1037,55 @@ def ConvTranspose(
var_w = _make_var_name(val_w)
var_b = _make_var_name(val_b) if has_bias else False
param_attr = ', param_attr={}, bias_attr={}'.format(
repr(var_w), repr(var_b) if var_b else False)
repr(var_w),
repr(var_b) if var_b else False)
# generation
prog.Code('{} = layers.{}({}'
', num_filters={}'
# ', output_size={}'
# ', output_size={}'
', filter_size={}'
', padding={}'
', stride={}'
', dilation={}'
', groups={}'
'{}{})'.format(
var_y,
fluid_op,
var_x,
# attrs
num_out_channels,
kernel_shape,
paddings,
strides,
dilations,
num_groups,
param_attr,
name_attr,
))
var_conv = _make_var_name(name + '.conv') # hidden variable
prog.OpDesc(
fluid_op,
([var_x, var_w], 'Input', 'Filter'), # , 'Bias', 'ResidualData'
([var_conv if has_bias else var_y], 'Output'),
dict(
strides=strides,
paddings=paddings,
dilations=dilations,
# output_size=output_size,
groups=num_groups,
))
if has_bias:
prog.VarDesc(var_conv)
prog.IntermediateOp(
'', 'Add',
[var_conv, var_b],
[var_y], # var
dict(axis=1),
value_infos=value_infos,
name=(name + '.bias'),
)
'',
'Add',
[var_conv, var_b],
[var_y], # var
dict(axis=1),
value_infos=value_infos,
name=(name + '.bias'),
)
else:
prog.VarDesc(var_y)
......@@ -1025,97 +1107,115 @@ def ConvTranspose(
# )
def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
"""
onnx::Gemm-9:
"""
    # fluid's fc layer does not support transposed weights, so Gemm
    # (Y = alpha * A * B + beta * C) is decomposed into matmul + elementwise add
val_a, val_b, val_c = inputs
val_y, = outputs
alpha = attrs.get('alpha', 1.) # optional
beta = attrs.get('beta', 1.) # optional
trans_a = bool(attrs.get('transA', 0)) # optional
trans_b = bool(attrs.get('transB', 0)) # optional
val_mm = name + '_mm' # explicit variable
prog.Op(
'',
'MatMul',
[val_a, val_b],
[val_mm], # val
dict(
transpose_x=trans_a,
transpose_y=trans_b,
alpha=alpha,
),
value_infos=value_infos,
name=val_mm,
)
prog.op_descs[-1].attrs.extend(
prog.OpDescAttrs(dict(
transpose_X=trans_a,
transpose_Y=trans_b,
        )))  # NOTE: OpDesc attrs use transpose_X/transpose_Y, unlike the layer kwargs transpose_x/transpose_y
if beta != 0:
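        # beta == 1 maps to a plain elementwise Add; any other non-zero beta
        # first scales C by a constant tensor, then adds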
if beta == 1.: # exactly
prog.Op(
'',
'Add',
[val_mm, val_c],
[val_y], # val
dict(axis=1),
value_infos=value_infos,
name=(name + '_beta'),
)
else:
val_beta = name + '_beta' # explicit variable
val_vm = name + '_vm' # explicit variable
vm_dtype = _dtype_or_none(value_infos, val_c)
if vm_dtype is None:
vm_dtype = np.dtype('float32')
beta = np.dtype(vm_dtype).type(beta)
prog.Op(
'',
'Constant',
[],
[val_beta], # val
dict(value=beta),
value_infos=value_infos,
name=val_beta,
)
prog.Op(
'',
'Mul',
[val_c, val_beta],
[val_vm], # val
dict(),
value_infos=value_infos,
name=(name + '_scale'),
)
prog.Op(
'',
'Add',
[val_mm, val_vm],
[val_y], # val
dict(axis=1),
name=(name + '_bias'),
)
def GlobalAveragePool(prog,
inputs,
outputs,
attrs,
value_infos,
name='',
*args,
**kwargs):
"""
onnx::GlobalAveragePool-1:
"""
return _global_pool(
prog, 'avg', inputs, outputs, attrs, value_infos, name=name)
def GlobalMaxPool(prog,
inputs,
outputs,
attrs,
value_infos,
name='',
*args,
**kwargs):
"""
onnx::GlobalMaxPool-1:
"""
return _global_pool(
prog, 'max', inputs, outputs, attrs, value_infos, name=name)
#def LRN(
......@@ -1132,7 +1232,7 @@ def GlobalMaxPool(
# var_y = _make_var_name(val_y)
#
# # interpretation
# fluid_op = 'lrn'
# size = attrs['size'] # required
# alpha = attrs.get('alpha', 0.0001) # optional
# beta = attrs.get('beta', 0.75) # optional
......@@ -1147,7 +1247,7 @@ def GlobalMaxPool(
# ', beta={}'
# '{})'
# .format(var_y,
# fluid_op,
# var_x,
# # attrs
# size,
......@@ -1159,7 +1259,7 @@ def GlobalMaxPool(
# var_mid = name + '.mid' # hidden variable
# prog.VarDesc(var_y)
# prog.VarDesc(var_mid)
# prog.OpDesc(fluid_op,
# ([var_x], 'X'),
# ([var_y, var_mid], 'Out', 'MidOut'),
# dict(n=size,
......@@ -1170,21 +1270,17 @@ def GlobalMaxPool(
# )
def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args,
**kwargs):
"""
onnx::MaxPool-10:
"""
return _pool(prog, 'max', inputs, outputs, attrs, value_infos, name=name)
def MaxRoiPool(prog, inputs, outputs, attrs, value_infos, name, *args,
**kwargs):
"""
onnx::MaxRoiPool-1:
"""
......@@ -1192,9 +1288,7 @@ def MaxRoiPool(
_roi_pool(prog, 'roi_pool', inputs, outputs, attrs, value_infos, name)
def RoiAlign(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
"""
caffe2::RoiAlign
"""
......@@ -1202,10 +1296,7 @@ def RoiAlign(
_roi_pool(prog, 'roi_align', inputs, outputs, attrs, value_infos, name)
def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
"""
onnx::Pad-2:
"""
......@@ -1217,28 +1308,31 @@ def Pad(
var_output = _make_var_name(val_output)
# interpretation
pads = attrs['pads'] # required
mode = attrs.get('mode', 'constant') # optional
value = attrs.get('value', 0.) # optional
data_shape = _shape_or_none(value_infos, val_data)
output_shape = _shape_or_none(value_infos, val_output)
assume_pad2d = False
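    # heuristic: fluid's pad2d handles 4-D NCHW input and non-constant modes;
    # otherwise fall back to the generic pad op, which is constant-mode only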
if len(pads) == 4:
assume_pad2d |= mode != 'constant'
if data_shape:
assume_pad2d |= data_shape and len(data_shape) == 4 # NCHW
if output_shape:
assume_pad2d |= output_shape and len(output_shape) == 4 # NCHW
od_attrs = dict(pad_value=value)
if assume_pad2d:
fluid_op = 'pad2d'
pad2d_attr = ', mode={}, data_format="NCHW"'.format(repr(mode))
od_attrs['mode'] = mode
od_attrs['data_format'] = "NCHW"
else:
assert mode == 'constant', 'mode {} is supported only in pad2d'.format(
mode)
fluid_op = 'pad'
pad2d_attr = ''
paddings = np.array(pads).reshape(
(-1, 2)).transpose().flatten().tolist() # SSEE -> SESE
od_attrs['paddings'] = paddings
name_attr = ', name={}'.format(repr(name)) if name else ''
......@@ -1246,27 +1340,34 @@ def Pad(
prog.Code('{} = layers.{}({}'
', paddings={}'
', pad_value={}'
'{}{})'.format(
var_output,
fluid_op,
var_data,
# attrs
paddings,
value,
pad2d_attr,
name_attr,
))
prog.VarDesc(var_output)
prog.OpDesc(
fluid_op,
([var_data], 'X'),
([var_output], 'Out'),
od_attrs,
)
def PRelu(prog,
inputs,
outputs,
attrs,
value_infos,
name='',
embed_params=False,
*args,
**kwargs):
"""
onnx::PRelu-9:
"""
......@@ -1278,7 +1379,7 @@ def PRelu(
var_y = _make_var_name(val_y)
# interpretation
fluid_op = 'prelu'
name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params:
assert name != ''
......@@ -1291,24 +1392,24 @@ def PRelu(
# generation
prog.Code('{} = layers.{}({}, mode="all"'
'{}{})'.format(
var_y,
fluid_op,
var_x,
# attrs
param_attr,
name_attr,
))
prog.VarDesc(var_y)
prog.OpDesc(
fluid_op,
([var_x], 'X'),
([var_y], 'Out'),
dict(mode='all'),
)
def PsRoiPool(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
"""
caffe2::PsRoiPool
"""
......@@ -1316,9 +1417,7 @@ def PsRoiPool(
_roi_pool(prog, 'psroi_pool', inputs, outputs, attrs, value_infos, name)
def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
"""
onnx::Reshape-5:
"""
......@@ -1330,69 +1429,71 @@ def Reshape(
var_reshaped = _make_var_name(val_reshaped)
# interpretation
fluid_op = 'reshape'
is_const_shape = 'const_value' in value_infos[val_shape]
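    # constant shapes are embedded into the op attrs; runtime shape tensors
    # are cast to int32 and wired in as actual_shape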
var_shape = _make_var_name(val_shape) # for code
if is_const_shape:
shape = value_infos[val_shape]['const_value'] # for desc
else:
shape = value_infos[val_shape]['get_weight']() # for desc
name_attr = ', name={}'.format(repr(name)) if name else ''
# generation
if is_const_shape:
prog.Code('{} = layers.{}({}'
', shape={}'
'{})'.format(
var_reshaped,
fluid_op,
var_data,
# attrs
var_shape,
name_attr,
))
else:
var_shape_int32 = var_shape + '_int32'
prog.Op(
'',
'Cast',
[var_shape],
[var_shape_int32], # var
dict(to=np.dtype('int32')),
value_infos=value_infos,
name=(name + '_cast'),
)
prog.Code('{} = layers.{}({}'
', shape={}'
', actual_shape={}'
'{})'.format(
var_reshaped,
fluid_op,
var_data,
# attrs
shape,
var_shape_int32,
name_attr,
))
fluid_op = 'reshape2'
var_xshape = _make_var_name(name + '.xshape')
prog.VarDesc(var_reshaped)
prog.VarDesc(var_xshape)
if is_const_shape:
prog.OpDesc(
fluid_op,
([var_data], 'X'),
([var_reshaped, var_xshape], 'Out', 'XShape'),
dict(shape=shape),
)
else:
prog.OpDesc(
fluid_op,
([var_data, var_shape_int32], 'X', 'Shape'),
([var_reshaped, var_xshape], 'Out', 'XShape'),
dict(shape=shape),
)
def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
"""
onnx::Slice-1:9
"""
......@@ -1404,17 +1505,17 @@ def Slice(
var_output = _make_var_name(val_output)
# interpretation
fluid_op = 'slice'
axes = attrs['axes'] # required
starts = attrs['starts'] # required
ends = attrs['ends'] # required
shape = _shape_or_none(value_infos, val_data)
if shape:
# ndims = len(shape)
# for idx, value in enumerate(axes):
# if value > ONNX_INT_MAX // 2:
# axes[idx] = ndims + value - ONNX_INT_MAX - 1
# FIXME: Paddle 1.3 doc: 'for slicing at the end of a dimension of unknown size, passing INT_MAX is recommended' does not seem to work ?
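        # ONNX may encode negative starts/ends as INT_MAX-offset sentinels;
        # the loops below map such values back to plain negative indices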
for idx, value in enumerate(starts):
if value > ONNX_INT_MAX // 2:
value = value - ONNX_INT_MAX - 1
......@@ -1429,29 +1530,29 @@ def Slice(
', axes={}'
', starts={}'
', ends={}'
')'.format(
var_output,
fluid_op,
var_data,
# attrs
axes,
starts,
ends,
))
prog.VarDesc(var_output)
prog.OpDesc(
fluid_op,
([var_data], 'X'),
([var_output], 'Out'),
dict(
axes=axes,
starts=starts,
ends=ends,
),
)
def Sum(prog, inputs, outputs, *args, **kwargs):
"""
onnx::Sum-8:
"""
......@@ -1462,27 +1563,25 @@ def Sum(
var_sum = _make_var_name(val_sum)
# interpretation
fluid_op = 'sums'
# generation
prog.Code('{} = layers.{}({})'.format(
var_sum,
fluid_op,
'[' + ', '.join(var_inps) + ']',
# attrs
))
prog.VarDesc(var_sum)
prog.OpDesc(
fluid_op,
(var_inps, *(['X'] * len(var_inps))),
([var_sum], 'Out'),
dict(),
)
def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
"""
    onnx::Tile-6:
"""
......@@ -1494,33 +1593,34 @@ def Tile(
var_output = _make_var_name(val_output)
# interpretation
fluid_op = 'expand'
is_const_repeats = 'const_value' in value_infos[val_repeats]
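    # expand_times needs a concrete list: take it from the compile-time
    # constant if available, else from the bound weight tensor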
if is_const_repeats:
code_repeats = _make_var_name(val_repeats) # for code
repeats = value_infos[val_repeats]['const_value'] # for desc
else:
        repeats = value_infos[val_repeats]['get_weight']()  # for desc; read from the repeats input
        code_repeats = repeats  # for code
name_attr = ', name={}'.format(repr(name)) if name else ''
# generation
prog.Code('{} = layers.{}({}'
', expand_times={}'
'{})'.format(
var_output,
fluid_op,
var_input,
# attrs
code_repeats,
name_attr,
))
prog.VarDesc(var_output)
prog.OpDesc(
fluid_op,
([var_input], 'X'),
([var_output], 'Out'),
dict(expand_times=repeats),
)
#def Shape(
......@@ -1537,29 +1637,25 @@ def Tile(
# var_shape = _make_var_name(val_shape)
#
# # interpretation
# fluid_op = 'shape'
## value_infos[val_shape]['remove_batch'] = False
#
# # generation
# prog.Code('{} = layers.{}({})'
# .format(var_shape,
# fluid_op,
# var_data,
# # attrs
# ))
# prog.VarDesc(var_shape) # , _value_info_or_none(value_infos, val_shape))
# prog.OpDesc(fluid_op,
# ([var_data], 'X'),
# ([var_shape], 'Out'),
# dict(),
# )
def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
onnx::Split-2:
"""
......@@ -1570,290 +1666,376 @@ def Split(
var_input = _make_var_name(val_input)
# interpretation
fluid_op = 'split'
split = attrs['split'] # required
axis = attrs.get('axis', 0) # optional
name_attr = ', name={}'.format(repr(name)) if name else ''
# generation
prog.Code('{} = layers.{}({}, {}'
', dim={}'
'{})'.format(
', '.join(var_outs),
fluid_op,
var_input,
split,
# attrs
axis,
name_attr,
))
for var_out in var_outs:
prog.VarDesc(var_out)
    prog.OpDesc(
        fluid_op,
        ([var_input], 'X'),
        (var_outs, *(['Out'] * len(var_outs))),  # one 'Out' slot per output
        dict(
            axis=axis,
            sections=split,
        ),
    )
if __name__ == '__main__':
_logging.basicConfig(
format=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s',
level=_logging.DEBUG,
)
logger = _logging.getLogger('symbolic_test')
from writer import Program
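    # smoke tests: build a minimal Program per symbolic op and log the
    # generated code and descs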
prog = Program()
AdaptiveAveragePool(
prog,
['X'],
['Y'],
dict(output_size=[3, 3]),
dict(Y=dict(shape=(2, 3, 3, 3), dtype=np.float32)),
name='AdaptiveAveragePool2d',
)
logger.info('AdaptiveAveragePool2d program:\n%s', prog)
prog = Program()
AdaptiveAveragePool(
prog,
['X'],
['Y'],
dict(output_size=[3, 3, 3]),
dict(Y=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32)),
name='AdaptiveAveragePool3d',
)
logger.info('AdaptiveAveragePool3d program:\n%s', prog)
prog = Program()
AffineGrid(
prog,
['Theta'],
['Grid'],
dict(size=[2, 2, 8, 8]),
dict(Grid=dict(shape=(2, 8, 8, 2), dtype=np.float32)),
)
logger.info('AffineGrid program:\n%s', prog)
prog = Program()
BatchNormalization(
prog,
['X', 'scale', 'B', 'mean', 'var'],
['Y'],
dict(
epsilon=1e-5,
momentum=.9,
),
dict(
scale=dict(shape=(3, ), dtype=np.float32),
B=dict(shape=(3, ), dtype=np.float32),
mean=dict(shape=(3, ), dtype=np.float32),
var=dict(shape=(3, ), dtype=np.float32),
Y=dict(shape=(2, 3), dtype=np.float32),
),
name='BatchNormalization',
embed_params=True,
)
logger.info('BatchNormalization program:\n%s', prog)
prog = Program()
Cast(
prog,
['input'],
['output'],
dict(to=2), # TensorProto.UINT8
dict(
input=dict(shape=(2, 3), dtype=np.float32),
output=dict(shape=(2, 3), dtype=np.uint8)),
)
logger.info('Cast program:\n%s', prog)
prog = Program()
_default(
prog,
'Clip',
['input'],
['output'],
dict(min=-1., max=1.),
dict(output=dict(shape=(2, 3), dtype=np.float32)),
)
logger.info('Clip program:\n%s', prog)
prog = Program()
Conv(
prog,
['X', 'W'],
['Y'],
dict(
auto_pad='NOTSET',
dilations=[1, 1],
group=1,
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
dict(
W=dict(shape=(2, 3, 3, 3), dtype=np.float32),
Y=dict(shape=(2, 2, 4, 6), dtype=np.float32),
),
name='ConvNoBias2d',
embed_params=True,
)
logger.info('ConvNoBias2d program:\n%s', prog)
prog = Program()
Conv(
prog,
['X', 'W', 'B'],
['Y'],
dict(
auto_pad='NOTSET',
dilations=[1, 1],
group=1,
kernel_shape=[3, 3],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
dict(
W=dict(shape=(2, 3, 3, 3), dtype=np.float32),
B=dict(shape=(2), dtype=np.float32),
Y=dict(shape=(2, 2, 4, 6), dtype=np.float32),
),
name='Conv2d',
embed_params=True,
)
logger.info('Conv2d program:\n%s', prog)
prog = Program()
ConvTranspose(
prog,
['X', 'W', 'B'],
['Y'],
dict(
auto_pad='NOTSET',
dilations=[1, 1],
group=1,
kernel_shape=[3, 3],
# output_padding=[1, 1, 1, 1],
# output_shape=[6, 8],
pads=[1, 1, 1, 1],
strides=[1, 1],
),
dict(
W=dict(shape=(2, 3, 3, 3), dtype=np.float32),
B=dict(shape=(2), dtype=np.float32),
Y=dict(shape=(2, 2, 6, 8), dtype=np.float32),
),
name='ConvTransposed2d',
embed_params=True,
)
logger.info('ConvTransposed2d program:\n%s', prog)
prog = Program()
Conv(
prog,
['X', 'W'],
['Y'],
dict(
auto_pad='NOTSET',
dilations=[1, 1, 1],
group=1,
kernel_shape=[3, 3, 3],
pads=[1, 1, 1, 1, 1, 1],
strides=[1, 1, 1],
),
dict(
W=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32),
Y=dict(shape=(2, 2, 4, 6, 8), dtype=np.float32),
),
name='ConvNoBias3d',
embed_params=True,
)
logger.info('ConvNoBias3d program:\n%s', prog)
prog = Program()
Conv(
prog,
['X', 'W', 'B'],
['Y'],
dict(
auto_pad='NOTSET',
dilations=[1, 1, 1],
group=1,
kernel_shape=[3, 3, 3],
pads=[1, 1, 1, 1, 1, 1],
strides=[1, 1, 1],
),
dict(
W=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32),
B=dict(shape=(2), dtype=np.float32),
Y=dict(shape=(2, 2, 4, 6, 8), dtype=np.float32),
),
name='Conv3d',
embed_params=True,
)
logger.info('Conv3d program:\n%s', prog)
prog = Program()
ConvTranspose(
prog,
['X', 'W', 'B'],
['Y'],
dict(
auto_pad='NOTSET',
dilations=[1, 1, 1],
group=1,
kernel_shape=[3, 3, 3],
# output_padding=[1, 1, 1, 1],
# output_shape=[6, 8],
pads=[1, 1, 1, 1, 1, 1],
strides=[1, 1, 1],
),
dict(
W=dict(shape=(2, 3, 3, 3, 3), dtype=np.float32),
B=dict(shape=(2), dtype=np.float32),
Y=dict(shape=(2, 2, 6, 8, 9), dtype=np.float32),
),
name='ConvTransposed3d',
embed_params=True,
)
logger.info('ConvTransposed3d program:\n%s', prog)
prog = Program()
_default(
prog,
'Equal',
['A', 'B'],
['C'],
dict(),
dict(C=dict(shape=(2, 3), dtype=np.bool)),
)
logger.info('Equal program:\n%s', prog)
prog = Program()
Gemm(
prog,
['A', 'B', 'C'],
['Y'],
dict(
alpha=1.,
beta=1.,
transA=0,
transB=1,
),
dict(
B=dict(shape=(8, 3), dtype=np.float32),
Y=dict(shape=(2, 8), dtype=np.float32),
),
name='Gemm',
)
logger.info('Gemm program:\n%s', prog)
prog = Program()
_default(
prog,
'Less',
['A', 'B'],
['C'],
dict(),
dict(C=dict(shape=(2, 3), dtype=np.bool)),
)
logger.info('Less program:\n%s', prog)
prog = Program()
_default(
prog,
'MatMul', ['A', 'B'], ['Y'],
dict(),
dict(Y=dict(shape=(2, 8), dtype=np.float32)),
name='MatMul')
logger.info('MatMul program:\n%s', prog)
prog = Program()
_default(
prog,
'OneHot',
['indices', 'depth', 'values'],
['output'],
dict(axis=-1),
dict(output=dict(shape=(2, 8), dtype=np.float32)),
)
logger.info('OneHot program:\n%s', prog)
prog = Program()
Pad(
prog,
['data'],
['output'],
dict(
mode='constant',
pads=[0, 1],
value=0.,
),
dict(
data=dict(shape=(2, 7), dtype=np.float32),
output=dict(shape=(2, 8), dtype=np.float32),
),
name='Pad',
)
logger.info('Pad program:\n%s', prog)
prog = Program()
Pad(
prog,
['data'],
['output'],
dict(
mode='reflect',
pads=[0, 1, 2, 3],
value=0.,
),
dict(
data=dict(shape=(2, 3, 3, 3), dtype=np.float32),
output=dict(shape=(2, 3, 5, 7), dtype=np.float32),
),
name='Pad2d',
)
logger.info('Pad2d program:\n%s', prog)
prog = Program()
PRelu(
prog,
['X', 'slope'],
['Y'],
dict(),
dict(Y=dict(shape=(2, 3), dtype=np.float32)),
name='PRelu',
)
logger.info('PRelu program:\n%s', prog)
prog = Program()
Tile(
prog, ['input', 'repeats'], ['output'],
dict(),
dict(
repeats=dict(const_value=[1, 2]),
output=dict(shape=(2, 2, 4), dtype=np.float32)),
name='Tile')
logger.info('Tile program:\n%s', prog)
......@@ -24,8 +24,7 @@ def _ensure_tuple(obj):
return (obj, )
def _flatten_list(obj, out=None):
assert isinstance(obj, list)
if out is None:
out = type(obj)()
......@@ -37,8 +36,7 @@ def _flatten_list(obj,
return out
def export_data(state_dict, prefix=''):
"""
export binary data with meta text for raw C++ inference engines
"""
......@@ -65,10 +63,14 @@ def export_data(state_dict,
fp.close()
def export_onnx_with_validation(model,
inputs,
export_basepath,
input_names=None,
output_names=None,
use_npz=True,
*args,
**kwargs):
"""
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
"""
......@@ -95,12 +97,16 @@ def export_onnx_with_validation(model, inputs, export_basepath,
ret[key] = value
return ret
torch_inputs = _ensure_tuple(inputs) # WORKAROUND: for torch.onnx
outputs = torch.onnx.export(
model,
torch_inputs,
export_basepath + '.onnx',
input_names=_flatten_list(input_names),
output_names=_flatten_list(output_names),
*args,
**kwargs)
if outputs is None: # WORKAROUND: for torch.onnx
outputs = model(*inputs)
torch_outputs = _ensure_tuple(outputs)
......
......@@ -13,8 +13,7 @@ import os
import sys
def _flatten_dict(obj, out=None):
assert isinstance(obj, dict)
if out is None:
out = type(obj)()
......@@ -34,12 +33,13 @@ def _ensure_list(obj):
return [obj]
def validate(fluid_model_filename,
golden_data_filename,
model_func_name='inference',
precision=1e-4,
save_inference_model=False):
"""
    infer the converted Paddle fluid model, validate with given golden data
"""
import numpy as np
......@@ -52,17 +52,17 @@ def validate(paddle_model_filename, golden_data_filename,
exe.run(fluid.default_startup_program())
# load model
fluid_model_dir, basename = os.path.split(fluid_model_filename)
if basename == '__model__': # is desc model
logger.debug('using desc file %s', basename)
prog, _, var_outs = fluid.io.load_inference_model(fluid_model_dir, exe)
out_names = var_outs # HINT: pass var if fetch ops already created
logger.info('model load passed')
elif basename.endswith('.py'): # is python code
logger.debug('using python code file %s', basename)
module_name, _ = os.path.splitext(basename)
sys_path = sys.path.copy()
sys.path.append(fluid_model_dir)
try:
module = importlib.import_module(module_name)
func = getattr(module, model_func_name)
......@@ -71,18 +71,21 @@ def validate(paddle_model_filename, golden_data_filename,
module = importlib.import_module(module_name)
func = getattr(module, model_func_name)
sys.path = sys_path
logger.debug('from %s imported %s: %s', module_name, model_func_name,
func)
var_outs = func()
var_outs = _ensure_list(var_outs)
out_names = [var.name for var in var_outs
] # HINT: pass string to create fetch ops
logger.info('import passed')
prog = fluid.default_main_program()
fluid.io.load_persistables(
executor=exe, dirname=fluid_model_dir, main_program=prog)
logger.info('weight load passed')
else:
raise ValueError('unsupported Paddle fluid model')
# load data
logger.info('using golden data %s', golden_data_filename)
......@@ -100,10 +103,15 @@ def validate(paddle_model_filename, golden_data_filename,
# DEBUG: reload test for python code
if basename.endswith('.py') and save_inference_model:
fluid.io.save_inference_model(
fluid_model_dir,
input_data.keys(),
var_outs,
exe,
main_program=prog,
export_for_deployment=True)
logger.info('model re-save passed')
fluid.io.load_inference_model(fluid_model_dir, exe)
logger.info('model re-load passed')
# execute
......@@ -124,49 +132,54 @@ def validate(paddle_model_filename, golden_data_filename,
else:
logger.info('accuracy not passed')
# globals().update(locals())
return passed
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(
description='onnx2fluid.validate',
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
'model',
nargs=1,
help='path to model.py or __model__',
)
parser.add_argument(
'--debug',
'-d',
action='store_true',
help='enable debug logging and checking',
)
parser.add_argument(
'--test_data',
'-t',
type=str,
help='I/O golden data for validation, e.g. test.npy, test.npz',
)
parser.add_argument(
'--precision',
'-p',
type=int,
default=4,
help='assertion decimal for validation',
)
args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_level = logging.DEBUG if args.debug else logging.INFO
logging.basicConfig(format=logging_format, level=logging_level)
debug = args.debug
fluid_model_filename = args.model[0]
golden_data_filename = args.test_data
precision = args.precision
validate(
fluid_model_filename,
golden_data_filename,
precision=precision,
save_inference_model=debug)
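    # example invocation (module and data paths are illustrative):
    #   python -m onnx2fluid.validate /tmp/export/model.py -t sample.npz -p 4 -d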
......@@ -34,15 +34,13 @@ except ImportError:
    logger.warning('importing paddle.fluid.proto.framework_pb2 failed, '
                   'using fallback framework_pb2')
__all__ = [
'Program',
'Writer',
]
def _irepr(obj, to='_'):
"""inline repr"""
s = repr(obj)
......@@ -53,8 +51,7 @@ def _irepr(obj,
return s
def _flatten_list(obj, out=None):
if out is None:
out = type(obj)()
for item in obj:
......@@ -72,7 +69,7 @@ def make_attr_name(name):
if name == '':
raise ValueError('name should not be empty')
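    # strip characters that would be illegal in a Python attribute name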
for s in ' *?\\/-:': #
name = name.replace(s, '_')
if not name.startswith('_'):
name = '_' + name
......@@ -85,15 +82,15 @@ class Program(object):
"""
DTYPE_TO_FRAMEWORK_DTYPE = {
'bool': framework_pb2.VarType.BOOL,
'int8': framework_pb2.VarType.INT8,
'uint8': framework_pb2.VarType.UINT8,
'int16': framework_pb2.VarType.INT16,
'int32': framework_pb2.VarType.INT32,
'int64': framework_pb2.VarType.INT64,
'float16': framework_pb2.VarType.FP16,
'float32': framework_pb2.VarType.FP32,
'float64': framework_pb2.VarType.FP64
}
@staticmethod
......@@ -116,7 +113,7 @@ class Program(object):
od_var = framework_pb2.OpDesc.Var()
od_var.parameter = key
if idx < len(vals):
od_var.arguments.append(vals[idx]) #
od_vars.append(od_var)
return od_vars
......@@ -130,10 +127,10 @@ class Program(object):
for key, value in attrs.items():
od_attr = framework_pb2.OpDesc.Attr()
od_attr.name = key
if isinstance(value, bool): # bool.mro() = [bool, int, object]
od_attr.type = framework_pb2.BOOLEAN
od_attr.b = value
elif isinstance(value, int): # only cast to int32
od_attr.type = framework_pb2.INT
od_attr.i = value
elif isinstance(value, float):
......@@ -143,10 +140,10 @@ class Program(object):
od_attr.type = framework_pb2.STRING
od_attr.s = value
elif isinstance(value, list) and len(value) > 0:
            if isinstance(value[0], bool):  # check the element type; bool.mro() = [bool, int, object]
od_attr.type = framework_pb2.BOOLEANS
od_attr.bools.extend(value)
elif isinstance(value[0], int): # only cast to int32 list
od_attr.type = framework_pb2.INTS
od_attr.ints.extend(value)
elif isinstance(value[0], float):
......@@ -168,11 +165,8 @@ class Program(object):
return ('Program(code mutable: {}) with:\n'
'codes: {}\n'
'op_descs: {}\n'
'var_descs: {}\n').format(self.code_mutable, self.codes,
self.op_descs, self.var_descs)
def __repr__(self):
return self.__str__()
......@@ -185,8 +179,11 @@ class Program(object):
if self.code_mutable:
self.codes.append(code)
def OpDesc(self,
name,
input_val_keys=None,
output_val_keys=None,
attrs=None):
"""
add OpDesc
"""
......@@ -202,10 +199,15 @@ class Program(object):
self.op_descs.append(desc)
return desc
def VarDesc(self,
name,
persistable=False,
value_info=None,
remove_batch=None,
dummy_dtype='float32'):
"""
add VarDesc,
dummy_dtype: WORKAROUND for Netron viewer
"""
var_desc = framework_pb2.VarDesc()
......@@ -213,14 +215,19 @@ class Program(object):
var_desc.persistable = persistable
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
# REMOVEIT: WORKAROUND: Netron: null.tensor error
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(dummy_dtype) # required
if value_info and 'dtype' in value_info:
tensor_desc.data_type = self.Dtype(value_info['dtype']) # required
if 'shape' in value_info:
tensor_desc.dims.extend(value_info['shape'])
if len(value_info['shape']) > 0: # skip scalars
if remove_batch is None:
remove_batch = value_info.get('remove_batch',
not persistable)
if remove_batch:
tensor_desc.dims[0] = -1
......@@ -231,7 +238,7 @@ class Program(object):
convert an ONNX op and add it to program
"""
if domain != '': # TODO: symbolic file routing by domain
raise ValueError('only default domain supported')
if op_type in symbolic.DEFAULT_OP_MAPPING:
......@@ -240,8 +247,8 @@ class Program(object):
fn = getattr(symbolic, op_type)
fn(self, *args, **kwargs)
else:
raise ValueError('conversion for {}::{} not supported'.format(
domain, op_type))
def IntermediateOp(self, domain, op_type, *args, **kwargs):
"""
......@@ -267,14 +274,15 @@ class Writer(object):
CODE_INDENT = ' ' * 4
@staticmethod
def header_code(func_name, info=''):
"""
Python header codes
"""
codes = list()
codes.append('"""')
codes.append('This code is generated by onnx2fluid.')
codes.append('{}'.format(info))
codes.append('"""')
codes.append('')
codes.append('from __future__ import division')
......@@ -287,16 +295,25 @@ class Writer(object):
return codes
@staticmethod
def emit_op(prog, name, domain, op_type, inputs, outputs, attrs,
value_infos, *args, **kwargs):
"""
emit an ONNX op into program
"""
prog.Code('# {}, {}::{}: {} -> {}, {}'.format(name, domain, op_type,
inputs, outputs,
_irepr(attrs, to=', ')))
prog.Op(
domain,
op_type,
inputs,
outputs,
attrs,
value_infos=value_infos,
name=name,
*args,
**kwargs)
@staticmethod
def emit_param(prog, name, value_info):
......@@ -313,18 +330,18 @@ class Writer(object):
var_name = make_var_name(name)
attr_name = make_attr_name(name)
prog.Code('# parameter: {}'.format(name))
prog.Code('{} = ParamAttr(name={})' # , trainable=True
.format(attr_name, repr(var_name)))
prog.Code(
'{} = layers.create_parameter(shape={}, dtype={}, name={}, attr={}'
', default_initializer=initializer.Constant(0))' #, is_bias={}
.format(var_name, value_info['shape'],
repr(value_info['dtype'].name), repr(name),
attr_name)) #, value_info.get('is_bias', False)))
prog.VarDesc(var_name, persistable=True, value_info=value_info)
@staticmethod
def emit_inputs(prog, names, value_infos, remove_batch=None):
"""
emit ONNX inputs into program
"""
......@@ -334,27 +351,33 @@ class Writer(object):
value_info = value_infos[name]
shape = value_info['shape']
if remove_batch is None:
remove_batch = value_info.get('remove_batch',
True) # HINT: True by default ?
if remove_batch:
shape = shape[1:]
prog.Code('# input: {}'.format(name))
prog.Code((
'{} = layers.data(name={}, shape={}, dtype={}, '
'append_batch_size={})' # , stop_gradient=True
).format(
var_name,
repr(name),
shape,
repr(value_info['dtype'].name),
remove_batch,
))
prog.OpDesc(
'feed',
(['feed'], 'X'),
([var_name], 'Out'),
dict(col=idx),
)
prog.VarDesc(
var_name, value_info=value_info, remove_batch=remove_batch)
@staticmethod
def emit_outputs(prog, names): #, value_infos
"""
emit ONNX outputs into program
"""
......@@ -364,11 +387,12 @@ class Writer(object):
var_name = make_var_name(name)
code += var_name + ', '
prog.OpDesc(
'fetch',
([var_name], 'X'),
(['fetch'], 'Out'),
dict(col=idx),
)
# var is emitted over ops
prog.Code(code)
......@@ -396,9 +420,9 @@ class Writer(object):
tensor_desc.dims.extend(weight.shape)
fp = open(filename, 'wb')
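        # fluid tensor file layout: version (int32), LoD level (int64),
        # tensor version (int32), desc byte size (int32), desc bytes, raw data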
np.array([0], dtype=np.int32).tofile(fp) # version
np.array([0], dtype=np.int64).tofile(fp) # LOD level
np.array([0], dtype=np.int32).tofile(fp) # tensor version
np.array([tensor_desc.ByteSize()], dtype=np.int32).tofile(fp)
fp.write(tensor_desc.SerializeToString())
weight.tofile(fp)
......@@ -463,4 +487,4 @@ class Writer(object):
fp = open(filename, 'wb')
fp.write(prog_desc.SerializeToString())
fp.close()
logger.debug('saved descs to %s', filename)
-e .
onnx>=1.4.0
paddlepaddle
......@@ -2,14 +2,14 @@
# https://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
[metadata]
# project name; used as the package name when publishing and installing
name = onnx2fluid
# author name and email address
author = Macrobull
# author_email = .Github@github.com
# project version; only versions above 1.0 are treated as stable releases
version = 0.1.0
# one-line project summary so users grasp the project at a glance; Chinese is not supported
description = Inference model conversion from ONNX/PyTorch to Paddle fluid
# long description content and format, including readme and changelog, usually in md or rst
long_description = file: README.md, CHANGELOG.md
long_description_content_type = text/markdown
......@@ -25,7 +25,7 @@ classifier =
Programming Language :: Python :: 3.5
# keywords for indexing, so users can find the project by search
keywords =
onnx paddlepaddle
[options]
# package names; find: enables auto-discovery, configurable in options.packages.find
......@@ -44,21 +44,21 @@ install_requires =
# mock
# unit test directory
#test_suite = onnx2fluid.tests
# automatically include version-controlled data files
include_package_data = True
# pure-Python project; the zipped source package can be executed directly
zip_safe = False
# the following config turns the given functions into command-line tools that users can run directly
[options.entry_points]
console_scripts =
onnx2fluid = onnx2fluid.cmdline:main
# the following config adds non-py files such as conf or data to the package; they are installed into site-packages alongside it
# files only, directories are not supported, but wildcards may be used
#[options.package_data]
#onnx2fluid =
# conf/*
# data/*
......
......@@ -15,4 +15,3 @@ Date: 2019/02/22 10:25:46
import setuptools
setuptools.setup()
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: framework.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='framework.proto',
package='paddle.framework.proto',
syntax='proto2',
serialized_pb=_b('\n\x0f\x66ramework.proto\x12\x16paddle.framework.proto\"\x1d\n\x07Version\x12\x12\n\x07version\x18\x01 \x01(\x03:\x01\x30\"\xec\x03\n\x06OpDesc\x12\x0c\n\x04type\x18\x03 \x02(\t\x12\x32\n\x06inputs\x18\x01 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x33\n\x07outputs\x18\x02 \x03(\x0b\x32\".paddle.framework.proto.OpDesc.Var\x12\x32\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32#.paddle.framework.proto.OpDesc.Attr\x12\x18\n\tis_target\x18\x05 \x01(\x08:\x05\x66\x61lse\x1a\xef\x01\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\t\n\x01i\x18\x03 \x01(\x05\x12\t\n\x01\x66\x18\x04 \x01(\x02\x12\t\n\x01s\x18\x05 \x01(\t\x12\x0c\n\x04ints\x18\x06 \x03(\x05\x12\x0e\n\x06\x66loats\x18\x07 \x03(\x02\x12\x0f\n\x07strings\x18\x08 \x03(\t\x12\t\n\x01\x62\x18\n \x01(\x08\x12\r\n\x05\x62ools\x18\x0b \x03(\x08\x12\x11\n\tblock_idx\x18\x0c \x01(\x05\x12\t\n\x01l\x18\r \x01(\x03\x12\x12\n\nblocks_idx\x18\x0e \x03(\x05\x12\r\n\x05longs\x18\x0f \x03(\x03\x1a+\n\x03Var\x12\x11\n\tparameter\x18\x01 \x02(\t\x12\x11\n\targuments\x18\x02 \x03(\t\"\xb3\x03\n\x07OpProto\x12\x0c\n\x04type\x18\x01 \x02(\t\x12\x33\n\x06inputs\x18\x02 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x34\n\x07outputs\x18\x03 \x03(\x0b\x32#.paddle.framework.proto.OpProto.Var\x12\x33\n\x05\x61ttrs\x18\x04 \x03(\x0b\x32$.paddle.framework.proto.OpProto.Attr\x12\x0f\n\x07\x63omment\x18\x05 \x02(\t\x1ax\n\x03Var\x12\x0c\n\x04name\x18\x01 \x02(\t\x12\x0f\n\x07\x63omment\x18\x02 \x02(\t\x12\x19\n\nduplicable\x18\x03 \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0cintermediate\x18\x04 \x01(\x08:\x05\x66\x61lse\x12\x1a\n\x0b\x64ispensable\x18\x05 \x01(\x08:\x05\x66\x61lse\x1ao\n\x04\x41ttr\x12\x0c\n\x04name\x18\x01 \x02(\t\x12.\n\x04type\x18\x02 \x02(\x0e\x32 .paddle.framework.proto.AttrType\x12\x0f\n\x07\x63omment\x18\x03 \x02(\t\x12\x18\n\tgenerated\x18\x04 \x01(\x08:\x05\x66\x61lse\"\xda\x08\n\x07VarType\x12\x32\n\x04type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x41\n\rselected_rows\x18\x02 \x01(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x41\n\nlod_tensor\x18\x03 \x01(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x12H\n\x0ctensor_array\x18\x04 \x01(\x0b\x32\x32.paddle.framework.proto.VarType.LoDTensorArrayDesc\x12:\n\x06reader\x18\x05 \x01(\x0b\x32*.paddle.framework.proto.VarType.ReaderDesc\x12\x34\n\x05tuple\x18\x07 \x01(\x0b\x32%.paddle.framework.proto.VarType.Tuple\x1aS\n\nTensorDesc\x12\x37\n\tdata_type\x18\x01 \x02(\x0e\x32$.paddle.framework.proto.VarType.Type\x12\x0c\n\x04\x64ims\x18\x02 \x03(\x03\x1a\x61\n\rLoDTensorDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1a\x66\n\x12LoDTensorArrayDesc\x12:\n\x06tensor\x18\x01 \x02(\x0b\x32*.paddle.framework.proto.VarType.TensorDesc\x12\x14\n\tlod_level\x18\x02 \x01(\x05:\x01\x30\x1aO\n\nReaderDesc\x12\x41\n\nlod_tensor\x18\x01 \x03(\x0b\x32-.paddle.framework.proto.VarType.LoDTensorDesc\x1a\x43\n\x05Tuple\x12:\n\x0c\x65lement_type\x18\x01 
\x03(\x0e\x32$.paddle.framework.proto.VarType.Type\"\xa2\x02\n\x04Type\x12\x08\n\x04\x42OOL\x10\x00\x12\t\n\x05INT16\x10\x01\x12\t\n\x05INT32\x10\x02\x12\t\n\x05INT64\x10\x03\x12\x08\n\x04\x46P16\x10\x04\x12\x08\n\x04\x46P32\x10\x05\x12\x08\n\x04\x46P64\x10\x06\x12\n\n\x06SIZE_T\x10\x13\x12\t\n\x05UINT8\x10\x14\x12\x08\n\x04INT8\x10\x15\x12\x0e\n\nLOD_TENSOR\x10\x07\x12\x11\n\rSELECTED_ROWS\x10\x08\x12\x12\n\x0e\x46\x45\x45\x44_MINIBATCH\x10\t\x12\x0e\n\nFETCH_LIST\x10\n\x12\x0f\n\x0bSTEP_SCOPES\x10\x0b\x12\x12\n\x0eLOD_RANK_TABLE\x10\x0c\x12\x14\n\x10LOD_TENSOR_ARRAY\x10\r\x12\x0e\n\nPLACE_LIST\x10\x0e\x12\n\n\x06READER\x10\x0f\x12\x07\n\x03RAW\x10\x11\x12\t\n\x05TUPLE\x10\x12\"b\n\x07VarDesc\x12\x0c\n\x04name\x18\x01 \x02(\t\x12-\n\x04type\x18\x02 \x02(\x0b\x32\x1f.paddle.framework.proto.VarType\x12\x1a\n\x0bpersistable\x18\x03 \x01(\x08:\x05\x66\x61lse\"\xa7\x01\n\tBlockDesc\x12\x0b\n\x03idx\x18\x01 \x02(\x05\x12\x12\n\nparent_idx\x18\x02 \x02(\x05\x12-\n\x04vars\x18\x03 \x03(\x0b\x32\x1f.paddle.framework.proto.VarDesc\x12+\n\x03ops\x18\x04 \x03(\x0b\x32\x1e.paddle.framework.proto.OpDesc\x12\x1d\n\x11\x66orward_block_idx\x18\x05 \x01(\x05:\x02-1\"r\n\x0bProgramDesc\x12\x31\n\x06\x62locks\x18\x01 \x03(\x0b\x32!.paddle.framework.proto.BlockDesc\x12\x30\n\x07version\x18\x02 \x01(\x0b\x32\x1f.paddle.framework.proto.Version*\x94\x01\n\x08\x41ttrType\x12\x07\n\x03INT\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\n\n\x06STRING\x10\x02\x12\x08\n\x04INTS\x10\x03\x12\n\n\x06\x46LOATS\x10\x04\x12\x0b\n\x07STRINGS\x10\x05\x12\x0b\n\x07\x42OOLEAN\x10\x06\x12\x0c\n\x08\x42OOLEANS\x10\x07\x12\t\n\x05\x42LOCK\x10\x08\x12\x08\n\x04LONG\x10\t\x12\n\n\x06\x42LOCKS\x10\n\x12\t\n\x05LONGS\x10\x0b\x42\x02H\x03')
)
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
_ATTRTYPE = _descriptor.EnumDescriptor(
name='AttrType',
full_name='paddle.framework.proto.AttrType',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='INT', index=0, number=0,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FLOAT', index=1, number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='STRING', index=2, number=2,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INTS', index=3, number=3,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FLOATS', index=4, number=4,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='STRINGS', index=5, number=5,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='BOOLEAN', index=6, number=6,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='BOOLEANS', index=7, number=7,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='BLOCK', index=8, number=8,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LONG', index=9, number=9,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='BLOCKS', index=10, number=10,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LONGS', index=11, number=11,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=2511,
serialized_end=2659,
)
_sym_db.RegisterEnumDescriptor(_ATTRTYPE)
AttrType = enum_type_wrapper.EnumTypeWrapper(_ATTRTYPE)
INT = 0
FLOAT = 1
STRING = 2
INTS = 3
FLOATS = 4
STRINGS = 5
BOOLEAN = 6
BOOLEANS = 7
BLOCK = 8
LONG = 9
BLOCKS = 10
LONGS = 11
_VARTYPE_TYPE = _descriptor.EnumDescriptor(
name='Type',
full_name='paddle.framework.proto.VarType.Type',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='BOOL', index=0, number=0,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT16', index=1, number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT32', index=2, number=2,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT64', index=3, number=3,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP16', index=4, number=4,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP32', index=5, number=5,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FP64', index=6, number=6,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='SIZE_T', index=7, number=19,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='UINT8', index=8, number=20,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='INT8', index=9, number=21,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LOD_TENSOR', index=10, number=7,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='SELECTED_ROWS', index=11, number=8,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FEED_MINIBATCH', index=12, number=9,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='FETCH_LIST', index=13, number=10,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='STEP_SCOPES', index=14, number=11,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LOD_RANK_TABLE', index=15, number=12,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='LOD_TENSOR_ARRAY', index=16, number=13,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='PLACE_LIST', index=17, number=14,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='READER', index=18, number=15,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='RAW', index=19, number=17,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='TUPLE', index=20, number=18,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=1832,
serialized_end=2122,
)
_sym_db.RegisterEnumDescriptor(_VARTYPE_TYPE)
_VERSION = _descriptor.Descriptor(
name='Version',
full_name='paddle.framework.proto.Version',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='version', full_name='paddle.framework.proto.Version.version', index=0,
number=1, type=3, cpp_type=2, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=43,
serialized_end=72,
)
_OPDESC_ATTR = _descriptor.Descriptor(
name='Attr',
full_name='paddle.framework.proto.OpDesc.Attr',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='paddle.framework.proto.OpDesc.Attr.name', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.OpDesc.Attr.type', index=1,
number=2, type=14, cpp_type=8, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='i', full_name='paddle.framework.proto.OpDesc.Attr.i', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='f', full_name='paddle.framework.proto.OpDesc.Attr.f', index=3,
number=4, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='s', full_name='paddle.framework.proto.OpDesc.Attr.s', index=4,
number=5, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ints', full_name='paddle.framework.proto.OpDesc.Attr.ints', index=5,
number=6, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='floats', full_name='paddle.framework.proto.OpDesc.Attr.floats', index=6,
number=7, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='strings', full_name='paddle.framework.proto.OpDesc.Attr.strings', index=7,
number=8, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='b', full_name='paddle.framework.proto.OpDesc.Attr.b', index=8,
number=10, type=8, cpp_type=7, label=1,
has_default_value=False, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='bools', full_name='paddle.framework.proto.OpDesc.Attr.bools', index=9,
number=11, type=8, cpp_type=7, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='block_idx', full_name='paddle.framework.proto.OpDesc.Attr.block_idx', index=10,
number=12, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='l', full_name='paddle.framework.proto.OpDesc.Attr.l', index=11,
number=13, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='blocks_idx', full_name='paddle.framework.proto.OpDesc.Attr.blocks_idx', index=12,
number=14, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='longs', full_name='paddle.framework.proto.OpDesc.Attr.longs', index=13,
number=15, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=283,
serialized_end=522,
)
_OPDESC_VAR = _descriptor.Descriptor(
name='Var',
full_name='paddle.framework.proto.OpDesc.Var',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='parameter', full_name='paddle.framework.proto.OpDesc.Var.parameter', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='arguments', full_name='paddle.framework.proto.OpDesc.Var.arguments', index=1,
number=2, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=524,
serialized_end=567,
)
_OPDESC = _descriptor.Descriptor(
name='OpDesc',
full_name='paddle.framework.proto.OpDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.OpDesc.type', index=0,
number=3, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='inputs', full_name='paddle.framework.proto.OpDesc.inputs', index=1,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='outputs', full_name='paddle.framework.proto.OpDesc.outputs', index=2,
number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='attrs', full_name='paddle.framework.proto.OpDesc.attrs', index=3,
number=4, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='is_target', full_name='paddle.framework.proto.OpDesc.is_target', index=4,
number=5, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[_OPDESC_ATTR, _OPDESC_VAR, ],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=75,
serialized_end=567,
)
_OPPROTO_VAR = _descriptor.Descriptor(
name='Var',
full_name='paddle.framework.proto.OpProto.Var',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='paddle.framework.proto.OpProto.Var.name', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment', full_name='paddle.framework.proto.OpProto.Var.comment', index=1,
number=2, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='duplicable', full_name='paddle.framework.proto.OpProto.Var.duplicable', index=2,
number=3, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='intermediate', full_name='paddle.framework.proto.OpProto.Var.intermediate', index=3,
number=4, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dispensable', full_name='paddle.framework.proto.OpProto.Var.dispensable', index=4,
number=5, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=772,
serialized_end=892,
)
_OPPROTO_ATTR = _descriptor.Descriptor(
name='Attr',
full_name='paddle.framework.proto.OpProto.Attr',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='paddle.framework.proto.OpProto.Attr.name', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.OpProto.Attr.type', index=1,
number=2, type=14, cpp_type=8, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment', full_name='paddle.framework.proto.OpProto.Attr.comment', index=2,
number=3, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='generated', full_name='paddle.framework.proto.OpProto.Attr.generated', index=3,
number=4, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=894,
serialized_end=1005,
)
_OPPROTO = _descriptor.Descriptor(
name='OpProto',
full_name='paddle.framework.proto.OpProto',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.OpProto.type', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='inputs', full_name='paddle.framework.proto.OpProto.inputs', index=1,
number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='outputs', full_name='paddle.framework.proto.OpProto.outputs', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='attrs', full_name='paddle.framework.proto.OpProto.attrs', index=3,
number=4, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='comment', full_name='paddle.framework.proto.OpProto.comment', index=4,
number=5, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[_OPPROTO_VAR, _OPPROTO_ATTR, ],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=570,
serialized_end=1005,
)
_VARTYPE_TENSORDESC = _descriptor.Descriptor(
name='TensorDesc',
full_name='paddle.framework.proto.VarType.TensorDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='data_type', full_name='paddle.framework.proto.VarType.TensorDesc.data_type', index=0,
number=1, type=14, cpp_type=8, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='dims', full_name='paddle.framework.proto.VarType.TensorDesc.dims', index=1,
number=2, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1393,
serialized_end=1476,
)
_VARTYPE_LODTENSORDESC = _descriptor.Descriptor(
name='LoDTensorDesc',
full_name='paddle.framework.proto.VarType.LoDTensorDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor', full_name='paddle.framework.proto.VarType.LoDTensorDesc.tensor', index=0,
number=1, type=11, cpp_type=10, label=2,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_level', full_name='paddle.framework.proto.VarType.LoDTensorDesc.lod_level', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1478,
serialized_end=1575,
)
_VARTYPE_LODTENSORARRAYDESC = _descriptor.Descriptor(
name='LoDTensorArrayDesc',
full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='tensor', full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc.tensor', index=0,
number=1, type=11, cpp_type=10, label=2,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_level', full_name='paddle.framework.proto.VarType.LoDTensorArrayDesc.lod_level', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1577,
serialized_end=1679,
)
_VARTYPE_READERDESC = _descriptor.Descriptor(
name='ReaderDesc',
full_name='paddle.framework.proto.VarType.ReaderDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='lod_tensor', full_name='paddle.framework.proto.VarType.ReaderDesc.lod_tensor', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1681,
serialized_end=1760,
)
_VARTYPE_TUPLE = _descriptor.Descriptor(
name='Tuple',
full_name='paddle.framework.proto.VarType.Tuple',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='element_type', full_name='paddle.framework.proto.VarType.Tuple.element_type', index=0,
number=1, type=14, cpp_type=8, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1762,
serialized_end=1829,
)
_VARTYPE = _descriptor.Descriptor(
name='VarType',
full_name='paddle.framework.proto.VarType',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.VarType.type', index=0,
number=1, type=14, cpp_type=8, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='selected_rows', full_name='paddle.framework.proto.VarType.selected_rows', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='lod_tensor', full_name='paddle.framework.proto.VarType.lod_tensor', index=2,
number=3, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='tensor_array', full_name='paddle.framework.proto.VarType.tensor_array', index=3,
number=4, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='reader', full_name='paddle.framework.proto.VarType.reader', index=4,
number=5, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='tuple', full_name='paddle.framework.proto.VarType.tuple', index=5,
number=7, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[_VARTYPE_TENSORDESC, _VARTYPE_LODTENSORDESC, _VARTYPE_LODTENSORARRAYDESC, _VARTYPE_READERDESC, _VARTYPE_TUPLE, ],
enum_types=[
_VARTYPE_TYPE,
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=1008,
serialized_end=2122,
)
_VARDESC = _descriptor.Descriptor(
name='VarDesc',
full_name='paddle.framework.proto.VarDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='paddle.framework.proto.VarDesc.name', index=0,
number=1, type=9, cpp_type=9, label=2,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type', full_name='paddle.framework.proto.VarDesc.type', index=1,
number=2, type=11, cpp_type=10, label=2,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='persistable', full_name='paddle.framework.proto.VarDesc.persistable', index=2,
number=3, type=8, cpp_type=7, label=1,
has_default_value=True, default_value=False,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=2124,
serialized_end=2222,
)
_BLOCKDESC = _descriptor.Descriptor(
name='BlockDesc',
full_name='paddle.framework.proto.BlockDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='idx', full_name='paddle.framework.proto.BlockDesc.idx', index=0,
number=1, type=5, cpp_type=1, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='parent_idx', full_name='paddle.framework.proto.BlockDesc.parent_idx', index=1,
number=2, type=5, cpp_type=1, label=2,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='vars', full_name='paddle.framework.proto.BlockDesc.vars', index=2,
number=3, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ops', full_name='paddle.framework.proto.BlockDesc.ops', index=3,
number=4, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='forward_block_idx', full_name='paddle.framework.proto.BlockDesc.forward_block_idx', index=4,
number=5, type=5, cpp_type=1, label=1,
has_default_value=True, default_value=-1,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=2225,
serialized_end=2392,
)
_PROGRAMDESC = _descriptor.Descriptor(
name='ProgramDesc',
full_name='paddle.framework.proto.ProgramDesc',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='blocks', full_name='paddle.framework.proto.ProgramDesc.blocks', index=0,
number=1, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='version', full_name='paddle.framework.proto.ProgramDesc.version', index=1,
number=2, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=2394,
serialized_end=2508,
)
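# Resolve cross-references between descriptors: enum/message field types,
# nested containing types, and the enum's owner can only be linked after
# every descriptor object above has been constructed.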
_OPDESC_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE
_OPDESC_ATTR.containing_type = _OPDESC
_OPDESC_VAR.containing_type = _OPDESC
_OPDESC.fields_by_name['inputs'].message_type = _OPDESC_VAR
_OPDESC.fields_by_name['outputs'].message_type = _OPDESC_VAR
_OPDESC.fields_by_name['attrs'].message_type = _OPDESC_ATTR
_OPPROTO_VAR.containing_type = _OPPROTO
_OPPROTO_ATTR.fields_by_name['type'].enum_type = _ATTRTYPE
_OPPROTO_ATTR.containing_type = _OPPROTO
_OPPROTO.fields_by_name['inputs'].message_type = _OPPROTO_VAR
_OPPROTO.fields_by_name['outputs'].message_type = _OPPROTO_VAR
_OPPROTO.fields_by_name['attrs'].message_type = _OPPROTO_ATTR
_VARTYPE_TENSORDESC.fields_by_name['data_type'].enum_type = _VARTYPE_TYPE
_VARTYPE_TENSORDESC.containing_type = _VARTYPE
_VARTYPE_LODTENSORDESC.fields_by_name['tensor'].message_type = _VARTYPE_TENSORDESC
_VARTYPE_LODTENSORDESC.containing_type = _VARTYPE
_VARTYPE_LODTENSORARRAYDESC.fields_by_name['tensor'].message_type = _VARTYPE_TENSORDESC
_VARTYPE_LODTENSORARRAYDESC.containing_type = _VARTYPE
_VARTYPE_READERDESC.fields_by_name['lod_tensor'].message_type = _VARTYPE_LODTENSORDESC
_VARTYPE_READERDESC.containing_type = _VARTYPE
_VARTYPE_TUPLE.fields_by_name['element_type'].enum_type = _VARTYPE_TYPE
_VARTYPE_TUPLE.containing_type = _VARTYPE
_VARTYPE.fields_by_name['type'].enum_type = _VARTYPE_TYPE
_VARTYPE.fields_by_name['selected_rows'].message_type = _VARTYPE_TENSORDESC
_VARTYPE.fields_by_name['lod_tensor'].message_type = _VARTYPE_LODTENSORDESC
_VARTYPE.fields_by_name['tensor_array'].message_type = _VARTYPE_LODTENSORARRAYDESC
_VARTYPE.fields_by_name['reader'].message_type = _VARTYPE_READERDESC
_VARTYPE.fields_by_name['tuple'].message_type = _VARTYPE_TUPLE
_VARTYPE_TYPE.containing_type = _VARTYPE
_VARDESC.fields_by_name['type'].message_type = _VARTYPE
_BLOCKDESC.fields_by_name['vars'].message_type = _VARDESC
_BLOCKDESC.fields_by_name['ops'].message_type = _OPDESC
_PROGRAMDESC.fields_by_name['blocks'].message_type = _BLOCKDESC
_PROGRAMDESC.fields_by_name['version'].message_type = _VERSION
DESCRIPTOR.message_types_by_name['Version'] = _VERSION
DESCRIPTOR.message_types_by_name['OpDesc'] = _OPDESC
DESCRIPTOR.message_types_by_name['OpProto'] = _OPPROTO
DESCRIPTOR.message_types_by_name['VarType'] = _VARTYPE
DESCRIPTOR.message_types_by_name['VarDesc'] = _VARDESC
DESCRIPTOR.message_types_by_name['BlockDesc'] = _BLOCKDESC
DESCRIPTOR.message_types_by_name['ProgramDesc'] = _PROGRAMDESC
DESCRIPTOR.enum_types_by_name['AttrType'] = _ATTRTYPE
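# Build the concrete message classes from the descriptors via the
# GeneratedProtocolMessageType metaclass and register each class (including
# nested ones) with the default symbol database.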
Version = _reflection.GeneratedProtocolMessageType('Version', (_message.Message,), dict(
DESCRIPTOR = _VERSION,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.Version)
))
_sym_db.RegisterMessage(Version)
OpDesc = _reflection.GeneratedProtocolMessageType('OpDesc', (_message.Message,), dict(
Attr = _reflection.GeneratedProtocolMessageType('Attr', (_message.Message,), dict(
DESCRIPTOR = _OPDESC_ATTR,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Attr)
))
,
Var = _reflection.GeneratedProtocolMessageType('Var', (_message.Message,), dict(
DESCRIPTOR = _OPDESC_VAR,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Var)
))
,
DESCRIPTOR = _OPDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc)
))
_sym_db.RegisterMessage(OpDesc)
_sym_db.RegisterMessage(OpDesc.Attr)
_sym_db.RegisterMessage(OpDesc.Var)
OpProto = _reflection.GeneratedProtocolMessageType('OpProto', (_message.Message,), dict(
Var = _reflection.GeneratedProtocolMessageType('Var', (_message.Message,), dict(
DESCRIPTOR = _OPPROTO_VAR,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Var)
))
,
Attr = _reflection.GeneratedProtocolMessageType('Attr', (_message.Message,), dict(
DESCRIPTOR = _OPPROTO_ATTR,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Attr)
))
,
DESCRIPTOR = _OPPROTO,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto)
))
_sym_db.RegisterMessage(OpProto)
_sym_db.RegisterMessage(OpProto.Var)
_sym_db.RegisterMessage(OpProto.Attr)
VarType = _reflection.GeneratedProtocolMessageType('VarType', (_message.Message,), dict(
TensorDesc = _reflection.GeneratedProtocolMessageType('TensorDesc', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_TENSORDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.TensorDesc)
))
,
LoDTensorDesc = _reflection.GeneratedProtocolMessageType('LoDTensorDesc', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_LODTENSORDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorDesc)
))
,
LoDTensorArrayDesc = _reflection.GeneratedProtocolMessageType('LoDTensorArrayDesc', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_LODTENSORARRAYDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorArrayDesc)
))
,
ReaderDesc = _reflection.GeneratedProtocolMessageType('ReaderDesc', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_READERDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.ReaderDesc)
))
,
Tuple = _reflection.GeneratedProtocolMessageType('Tuple', (_message.Message,), dict(
DESCRIPTOR = _VARTYPE_TUPLE,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.Tuple)
))
,
DESCRIPTOR = _VARTYPE,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType)
))
_sym_db.RegisterMessage(VarType)
_sym_db.RegisterMessage(VarType.TensorDesc)
_sym_db.RegisterMessage(VarType.LoDTensorDesc)
_sym_db.RegisterMessage(VarType.LoDTensorArrayDesc)
_sym_db.RegisterMessage(VarType.ReaderDesc)
_sym_db.RegisterMessage(VarType.Tuple)
VarDesc = _reflection.GeneratedProtocolMessageType('VarDesc', (_message.Message,), dict(
DESCRIPTOR = _VARDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc)
))
_sym_db.RegisterMessage(VarDesc)
BlockDesc = _reflection.GeneratedProtocolMessageType('BlockDesc', (_message.Message,), dict(
DESCRIPTOR = _BLOCKDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.BlockDesc)
))
_sym_db.RegisterMessage(BlockDesc)
ProgramDesc = _reflection.GeneratedProtocolMessageType('ProgramDesc', (_message.Message,), dict(
DESCRIPTOR = _PROGRAMDESC,
__module__ = 'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.ProgramDesc)
))
_sym_db.RegisterMessage(ProgramDesc)
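# File-level option carried over from framework.proto; the serialized
# FileOptions payload b'H\x03' decodes to field 9 (optimize_for) with
# value 3, i.e. optimize_for = LITE_RUNTIME.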
DESCRIPTOR.has_options = True
DESCRIPTOR._options = _descriptor._ParseOptions(descriptor_pb2.FileOptions(), _b('H\003'))
# @@protoc_insertion_point(module_scope)
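

# The snippet below is not part of the protoc output; it is a minimal,
# hedged sketch of how the generated classes above can be used to build
# and round-trip a trivial (non-executable) ProgramDesc. All class and
# field names come from the descriptors defined in this module; the
# variable/op contents are illustrative only.
if __name__ == '__main__':
    prog = ProgramDesc()
    block = prog.blocks.add()
    block.idx = 0
    block.parent_idx = -1  # the root block has no parent
    # declare one FP32 LoD tensor variable with a dynamic batch dimension
    var = block.vars.add()
    var.name = 'x'
    var.type.type = VarType.LOD_TENSOR
    var.type.lod_tensor.tensor.data_type = VarType.FP32
    var.type.lod_tensor.tensor.dims.extend([-1, 3])
    # attach a dummy op that references the variable as its output
    op = block.ops.add()
    op.type = 'feed'
    op_out = op.outputs.add()
    op_out.parameter = 'Out'
    op_out.arguments.append('x')
    # serialize and parse back; message equality verifies the round trip
    data = prog.SerializeToString()
    clone = ProgramDesc()
    clone.ParseFromString(data)
    assert clone == prog
    print(clone)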