未验证 提交 bdae1084 编写于 作者: J Jason 提交者: GitHub

Merge pull request #18 from MacroBull/master

Bugfixes
......@@ -11,15 +11,32 @@ import numpy as np
from collections import OrderedDict as Dict
def _make_var_name(name):
"""
make a valid variable name in Python code
"""
if name == '':
return '_'
if name[0].isdigit():
return 'var_' + name
for s in ' *?\\/-:':
name = name.replace(s, '_')
if name.startswith('_'):
name = 'var' + name
return name
fn = sys.argv[1]
input_names = sys.argv[2].split(':')
output_name = sys.argv[3].split(':')
data = np.load(fn)
data = np.load(fn, encoding='bytes')
input_data = data['inputs']
output_data = data['outputs']
inputs = Dict(zip(input_names, [input_data]))
outputs = Dict(zip(output_name, [output_data]))
inputs = Dict(zip(map(_make_var_name, input_names), [input_data]))
outputs = Dict(zip(map(_make_var_name, output_name), [output_data]))
np.savez(fn, inputs=inputs, outputs=outputs) # overwrite
......@@ -16,6 +16,23 @@ import onnx.numpy_helper as numpy_helper
from collections import OrderedDict as Dict
from glob import glob
def _make_var_name(name):
"""
make a valid variable name in Python code
"""
if name == '':
return '_'
if name[0].isdigit():
return 'var_' + name
for s in ' *?\\/-:':
name = name.replace(s, '_')
if name.startswith('_'):
name = 'var' + name
return name
data_dir = os.path.dirname(sys.argv[1])
input_names = sys.argv[2].split(':')
output_name = sys.argv[3].split(':')
......@@ -36,7 +53,7 @@ for fn in glob(os.path.join(data_dir, 'output_*.pb')):
tensor.ParseFromString(f.read())
outputs.append(numpy_helper.to_array(tensor))
inputs = Dict(zip(input_names, inputs))
outputs = Dict(zip(output_name, outputs))
inputs = Dict(zip(map(_make_var_name, input_names), inputs))
outputs = Dict(zip(map(_make_var_name, output_name), outputs))
np.savez(data_dir, inputs=inputs, outputs=outputs)
#! /usr/bin/env sh
get_url="aria2c -c -s8 -x8"
# setopt SH_WORD_SPLIT # if zsh
base_url="https://s3.amazonaws.com/download.onnx/models/opset_9/"
flags="-e -o /tmp/export/"
convert_flags="-e -o /tmp/export/"
validate_flags1="/tmp/export/model.py"
validate_flags2="/tmp/export/__model__"
# alias http_get="wget -c" # if no aria2
alias http_get="aria2c -c -s8 -x8"
# alias python="python3" # if ...
bvlc_alexnet()
{
......@@ -10,21 +17,24 @@ bvlc_alexnet()
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
http_get "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
for npz in $bn_tar/*.npz
python -m onnx2fluid $convert_flags "$fn_model"
for npz in "$bn_tar"/*.npz
do
echo "converting $npz ..."
python convert_data_npz_0.py "$npz" "data_0" "prob_1"
python -m onnx2fluid $flags "$fn_model" -t $npz
python convert_data_npz_0.py "$npz" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
done
for pb_dir in $bn_tar/*/
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir ..."
python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python convert_data_pb_0.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
}
......@@ -34,15 +44,17 @@ bvlc_googlenet()
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
http_get "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
for pb_dir in $bn_tar/*/
python -m onnx2fluid $convert_flags "$fn_model"
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python convert_data_pb_0.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
}
......@@ -52,15 +64,17 @@ bvlc_reference_caffenet()
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
http_get "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
for pb_dir in $bn_tar/*/
python -m onnx2fluid $convert_flags "$fn_model"
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python convert_data_pb_0.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
}
......@@ -70,15 +84,17 @@ bvlc_reference_rcnn_ilsvrc13()
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
http_get "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
for pb_dir in $bn_tar/*/
python -m onnx2fluid $convert_flags "$fn_model"
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "fc_rcnn_1"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python convert_data_pb_0.py "$pb_dir" data_0 fc-rcnn_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
}
......@@ -88,22 +104,24 @@ inception_v1()
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
http_get "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
for npz in $bn_tar/*.npz
python -m onnx2fluid $convert_flags "$fn_model"
for npz in "$bn_tar"/*.npz
do
echo "converting $npz ..."
python convert_data_npz_0.py "$npz" "data_0" "prob_1"
python -m onnx2fluid $flags "$fn_model" -t $npz
python convert_data_npz_0.py "$npz" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
done
for pb_dir in $bn_tar/*/
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
echo "converting $pb_dir ..."
python convert_data_pb_0.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
}
......@@ -113,22 +131,24 @@ inception_v2()
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
http_get "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
for npz in $bn_tar/*.npz
python -m onnx2fluid $convert_flags "$fn_model"
for npz in "$bn_tar"/*.npz
do
echo "converting $npz ..."
python convert_data_npz_0.py "$npz" "data_0" "prob_1"
python -m onnx2fluid $flags "$fn_model" -t $npz
python convert_data_npz_0.py "$npz" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
done
for pb_dir in $bn_tar/*/
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
echo "converting $pb_dir ..."
python convert_data_pb_0.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
}
......@@ -138,22 +158,24 @@ resnet50()
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
http_get "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
for npz in $bn_tar/*.npz
python -m onnx2fluid $convert_flags "$fn_model"
for npz in "$bn_tar"/*.npz
do
echo "converting $npz ..."
python convert_data_npz_0.py "$npz" "gpu_0/data_0" "gpu_0/softmaxout_1"
python -m onnx2fluid $flags "$fn_model" -t $npz
python convert_data_npz_0.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
done
for pb_dir in $bn_tar/*/
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "gpu_0/data_0" "gpu_0/softmaxout_1"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
echo "converting $pb_dir ..."
python convert_data_pb_0.py "$pb_dir" gpu_0/data_0 gpu_0/softmaxout_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
}
......@@ -163,15 +185,17 @@ shufflenet()
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
http_get "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
for pb_dir in $bn_tar/*/
python -m onnx2fluid $convert_flags "$fn_model"
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "gpu_0/data_0" "gpu_0/softmaxout_1"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
echo "converting $pb_dir ..."
python convert_data_pb_0.py "$pb_dir" gpu_0/data_0 gpu_0/softmaxout_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
}
......@@ -181,15 +205,17 @@ squeezenet()
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
http_get "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
for pb_dir in $bn_tar/*/
python -m onnx2fluid $convert_flags "$fn_model"
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "softmaxout_1"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python convert_data_pb_0.py "$pb_dir" data_0 softmaxout_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
}
......@@ -199,15 +225,17 @@ tiny_yolov2()
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "https://onnxzoo.blob.core.windows.net/models/opset_8/tiny_yolov2/$fn_tar"
http_get "https://onnxzoo.blob.core.windows.net/models/opset_8/tiny_yolov2/$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
for pb_dir in $bn_tar/*/
python -m onnx2fluid $convert_flags "$fn_model" -xy
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "image" "grid"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz -x
python convert_data_pb_0.py "$pb_dir" image grid
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
}
......@@ -217,15 +245,17 @@ vgg19()
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
http_get "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
for pb_dir in $bn_tar/*/
python -m onnx2fluid $convert_flags "$fn_model"
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "data_0" "prob_1"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python convert_data_pb_0.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
}
......@@ -235,15 +265,17 @@ zfnet512()
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
$get_url "$base_url$fn_tar"
http_get "$base_url$fn_tar"
echo "extracting ..."
tar xf "$fn_tar"
for pb_dir in $bn_tar/*/
python -m onnx2fluid $convert_flags "$fn_model"
for pb_dir in "$bn_tar"/*/
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" "gpu_0/data_0" "gpu_0/softmax_1"
python -m onnx2fluid $flags "$fn_model" -t $(dirname "$pb_dir/x").npz
python convert_data_pb_0.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
}
......
......@@ -69,11 +69,27 @@ parser.add_argument(
dest='pedantic',
help='process non-standard ONNX ops, this may lead to fails',
)
parser.add_argument(
'--skip-version-conversion',
'-y',
action='store_true',
default=False,
help='skip ONNX op version conversion, workaround for RumtimeErrors',
)
parser.add_argument(
'--archive',
'-z',
nargs='?',
type=str,
default=None,
const='',
help='compress outputs to ZIP file if conversion successed',
)
parser.add_argument(
'--precision',
'-p',
type=int,
default=4,
default=3,
help='assertion decimal for validation',
)
args = parser.parse_args()
......
......@@ -16,10 +16,10 @@ from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
# import logging, shutil, zipfile
import logging
import shutil
import zipfile
import logging, shutil, zipfile
#import logging
#import shutil
#import zipfile
__all__ = [
'main',
......@@ -49,12 +49,14 @@ def main(**kwargs):
basepath, _ = shutil.os.path.splitext(filename)
save_dir = kwargs.get('output_dir', '')
# model.onnx -> model/
save_dir = shutil.os.path.dirname(save_dir) if save_dir else basepath
save_dir = (save_dir.rstrip('/') if save_dir else basepath) + '/'
model_basename = DEFAULT_MODEL_MODULE + '.py'
model_func_name = DEFAULT_MODEL_FUNC
embed_params = kwargs.get('embed_params', False)
onnx_opset_version = DEFAULT_ONNX_OPSET_VERSION
onnx_opset_pedantic = kwargs.get('pedantic', True)
onnx_skip_version_conversion = kwargs.get('skip_version_conversion', False)
archive = kwargs.get('archive', None)
# convert
convert(
......@@ -65,6 +67,7 @@ def main(**kwargs):
embed_params=embed_params,
onnx_opset_version=onnx_opset_version,
onnx_opset_pedantic=onnx_opset_pedantic,
onnx_skip_version_conversion=onnx_skip_version_conversion,
debug=debug)
# validate
......@@ -81,13 +84,13 @@ def main(**kwargs):
# in fact fluid can not fully clear the context
# continuous validation may be inaccurate
precision = 10**-kwargs.get('precision', 4)
decimal = kwargs.get('precision', 3)
logger.info('starting validation on desc ...')
passed &= validate(
shutil.os.path.join(save_dir, '__model__'),
golden_data_filename,
precision=precision,
decimal=decimal,
)
logger.info('starting validation on code ...')
......@@ -95,7 +98,7 @@ def main(**kwargs):
shutil.os.path.join(save_dir, model_basename),
golden_data_filename,
model_func_name=model_func_name,
precision=precision,
decimal=decimal,
save_inference_model=debug, # this overwrite desc file for test
)
......@@ -104,13 +107,21 @@ def main(**kwargs):
return
# create zip file
fn_zip = save_dir.rstrip('/') + '.zip'
logger.info('compressing file to %s ...', fn_zip)
fz = zipfile.ZipFile(fn_zip, 'w', compression=zipfile.ZIP_LZMA)
for fn in shutil.os.listdir(save_dir):
fz.write(shutil.os.path.join(save_dir, fn), arcname=fn)
fz.close()
logger.info('compressing done')
if archive is not None:
if archive == '':
archive = save_dir.rstrip('/') + '.zip'
logger.info('compressing file to %s ...', archive)
shutil.sys.stderr.write('\n')
shutil.sys.stderr.flush()
file_list = shutil.os.listdir(save_dir)
fz = zipfile.ZipFile(archive, 'w', compression=zipfile.ZIP_LZMA)
for idx, fn in enumerate(file_list):
shutil.sys.stderr.write('\033[F\033[2K')
logger.info('file {}/{}: {}'.format(idx + 1, len(file_list), fn))
shutil.sys.stderr.flush()
fz.write(shutil.os.path.join(save_dir, fn), arcname=fn)
fz.close()
logger.info('compressing done')
if __name__ == '__main__':
......@@ -120,17 +131,19 @@ if __name__ == '__main__':
level=logging.DEBUG,
)
# main(model=['../examples/t5.onnx'],
# output_dir='/tmp/export/',
# embed_params=False,
# pedantic=False,
# test_data='../examples/t5.npz',
# debug=True)
main(
model=['../examples/t1.onnx'],
output_dir='/tmp/export/',
embed_params=False,
pedantic=False,
test_data='../examples/t1.npz',
debug=True)
main(
model=['../examples/inception_v2/model.onnx'],
output_dir='/tmp/export/',
embed_params=True,
pedantic=False,
skip_version_conversion=False,
test_data='../examples/inception_v2/test_data_set_2.npz',
debug=True)
......@@ -8,9 +8,9 @@ Created on Mon Feb 25 09:50:35 2019
from __future__ import division
# import logging, shutil
import logging
import shutil
import logging, shutil
#import logging
#import shutil
__all__ = [
'convert',
......@@ -24,6 +24,7 @@ def convert(onnx_model_filename,
embed_params=False,
onnx_opset_version=9,
onnx_opset_pedantic=True,
onnx_skip_version_conversion=False,
debug=False):
"""
convert an ONNX model to Paddle fluid Python code and desc pb
......@@ -60,12 +61,13 @@ def convert(onnx_model_filename,
try:
logger.info('checking model ...')
check_model(onnx_model)
logger.debug('using opset version: %d', onnx_opset_version)
if onnx_opset_pedantic: # WORKAROUND: RuntimeError: No Adapter For OP
onnx_model = convert_version(onnx_model, onnx_opset_version)
else: # TODO: add new argument for this option
if onnx_skip_version_conversion: # WORKAROUND: RuntimeError: No Adapter For OP
logger.debug('assumed opset version: %d', onnx_opset_version)
logger.warning(
'opset conversion skipped for onnx_opset_pedantic is OFF')
else:
logger.debug('using opset version: %d', onnx_opset_version)
onnx_model = convert_version(onnx_model, onnx_opset_version)
onnx_model = polish_model(onnx_model)
except ValidationError as e:
if onnx_opset_pedantic:
......@@ -152,16 +154,15 @@ def convert(onnx_model_filename,
logger.info(
'weight %s is shared between ops, more disk space will be consumed',
name)
logger.debug(
'saving weight %s with size of %d, in %d bytes, as %s ...',
name, weight.size, weight.nbytes, var_names)
logger.debug('saving weight %s(%s[%d], %dB) as %s ...', name,
weight.dtype, weight.size, weight.nbytes, var_names)
for var_name in var_names: # multiple references
fluid_writer.write_weight(
weight, shutil.os.path.join(save_dir, var_name))
else:
logger.debug(
'saving weight %s with size of %d, in %d bytes, to %s ...',
name, weight.size, weight.nbytes, make_var_name(name))
logger.debug('saving weight %s(%s[%d], %dB) to %s ...', name,
weight.dtype, weight.size, weight.nbytes,
make_var_name(name))
fluid_writer.write_weight(
weight, shutil.os.path.join(save_dir, make_var_name(name)))
fluid_writer.emit_param(fluid_program, name, value_info)
......@@ -262,6 +263,13 @@ if __name__ == '__main__':
dest='pedantic',
help='process non-standard ONNX ops, this may lead to fails',
)
parser.add_argument(
'--skip-version-conversion',
'-y',
action='store_true',
default=False,
help='skip ONNX op version conversion, workaround for RumtimeErrors',
)
args = parser.parse_args()
logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
......@@ -273,10 +281,12 @@ if __name__ == '__main__':
save_dir = args.output_dir
embed_params = args.embed_params
pedantic = args.pedantic
skip_version_conversion = args.skip_version_conversion
convert(
model_filename,
save_dir,
embed_params=embed_params,
onnx_opset_pedantic=pedantic,
onnx_skip_version_conversion=skip_version_conversion,
debug=debug)
......@@ -26,6 +26,7 @@ __all__ = [
'node_attrs',
'node_topo',
'node_iter',
'tensor_dtype',
'tensor_shape',
'graph_ops',
'graph_weights',
......@@ -92,13 +93,12 @@ def get_attribute_value2(attr):
return value
def node_attrs(node):
def tensor_dtype(tensor):
"""
convert ONNX node attributes to dict
get ONNX tensor in np.dtype
"""
return {attr.name: get_attribute_value2(attr)
for attr in node.attribute} # dict
return TENSOR_TYPE_TO_NP_TYPE[tensor.type.tensor_type.elem_type]
def tensor_shape(tensor):
......@@ -109,6 +109,15 @@ def tensor_shape(tensor):
return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim]
def node_attrs(node):
"""
convert ONNX node attributes to dict
"""
return {attr.name: get_attribute_value2(attr)
for attr in node.attribute} # dict
def node_topo(nodes, topo='default'):
"""
build indices with given topology to an ONNX node graph
......@@ -237,21 +246,21 @@ def inferred_model_value_info(model):
value_info = Dict()
for item in graph.value_info:
value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
dtype=tensor_dtype(item),
shape=tensor_shape(item),
external=False,
)
for item in graph.input:
assert item.name not in value_info
value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
dtype=tensor_dtype(item),
shape=tensor_shape(item),
external=True,
)
for item in graph.output:
# assert item.name not in value_info, 'bypass-model not supported'
value_info[item.name] = dict(
dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
dtype=tensor_dtype(item),
shape=tensor_shape(item),
external=True,
)
......@@ -373,9 +382,9 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
elif not keep_input_only and name in output_refs:
ret_initializers.add().CopyFrom(initializer)
else:
logger.debug('initializer %s(%s[%d]) stripped', name,
TENSOR_TYPE_TO_NP_TYPE[initializer.data_type],
len(initializer.raw_data))
dtype = TENSOR_TYPE_TO_NP_TYPE[initializer.data_type]
logger.debug('initializer %s(%s[%d]) stripped', name, dtype,
len(initializer.raw_data) // dtype.itemsize)
# strip inputs
ret.graph.ClearField('input')
......@@ -385,10 +394,8 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
if name in input_refs or name in out_names:
ret_inputs.add().CopyFrom(item)
else:
logger.debug(
'input %s(%s%s) stripped', name,
TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
tensor_shape(item))
logger.debug('input %s(%s%s) stripped', name, tensor_dtype(item),
tensor_shape(item))
return ret
......
此差异已折叠。
......@@ -6,11 +6,12 @@ Created on Fri Mar 22 12:17:19 2019
@author: Macrobull
"""
# import importlib, logging, os, sys
import importlib
import logging
import os
import sys
import importlib, logging, os, sys
#import importlib
#import logging
#import os
#import sys
def _flatten_dict(obj, out=None):
......@@ -36,7 +37,7 @@ def _ensure_list(obj):
def validate(fluid_model_filename,
golden_data_filename,
model_func_name='inference',
precision=1e-4,
decimal=3,
save_inference_model=False):
"""
inferece the converted Paddle fluid model, validate with given golden data
......@@ -90,16 +91,17 @@ def validate(fluid_model_filename,
# load data
logger.info('using golden data %s', golden_data_filename)
if golden_data_filename.endswith('.npz'):
test_data = np.load(golden_data_filename)
test_data = np.load(golden_data_filename, encoding='bytes')
input_data = test_data['inputs'].tolist()
output_data = test_data['outputs'].tolist()
else:
test_data = np.load(golden_data_filename).tolist()
input_data = input_data['inputs']
output_data = output_data['outputs']
test_data = np.load(golden_data_filename, encoding='bytes').tolist()
input_data = test_data['inputs']
output_data = test_data['outputs']
input_data = _flatten_dict(input_data)
output_data = _flatten_dict(output_data)
logger.info('found %d I/O golden data, starting test ...', len(test_data))
logger.info('found %d I/O golden data, starting test ...',
len(input_data) + len(output_data))
# DEBUG: reload test for python code
if basename.endswith('.py') and save_inference_model:
......@@ -123,7 +125,7 @@ def validate(fluid_model_filename,
for (name, truth), output in zip(output_data.items(), outputs):
logger.info('testing output {} ...'.format(name))
try:
np.testing.assert_almost_equal(output, truth, decimal=precision)
np.testing.assert_almost_equal(output, truth, decimal=decimal)
except AssertionError as e:
passed = False
logger.error('failed: %s\n', e)
......@@ -164,7 +166,7 @@ if __name__ == '__main__':
'--precision',
'-p',
type=int,
default=4,
default=3,
help='assertion decimal for validation',
)
args = parser.parse_args()
......@@ -176,10 +178,10 @@ if __name__ == '__main__':
debug = args.debug
fluid_model_filename = args.model[0]
golden_data_filename = args.test_data
precision = args.precision
decimal = args.precision
validate(
fluid_model_filename,
golden_data_filename,
precision=precision,
decimal=decimal,
save_inference_model=debug)
......@@ -8,9 +8,9 @@ Created on Sun Feb 24 20:44:43 2019
from __future__ import division
# import logging, os
import logging
import os
import logging, os
#import logging
#import os
import numpy as np
logger = logging.getLogger(__name__)
......@@ -215,10 +215,6 @@ class Program(object):
var_desc.persistable = persistable
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
# REMOVEIT: WORKAROUND: Netron: null.tensor error
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(dummy_dtype) # required
if value_info and 'dtype' in value_info:
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(value_info['dtype']) # required
......@@ -230,6 +226,9 @@ class Program(object):
not persistable)
if remove_batch:
tensor_desc.dims[0] = -1
else: # REMOVEIT: WORKAROUND: Netron: null.tensor error
tensor_desc = var_desc.type.lod_tensor.tensor
tensor_desc.data_type = self.Dtype(dummy_dtype) # required
self.var_descs.append(var_desc)
......@@ -329,7 +328,7 @@ class Writer(object):
else:
var_name = make_var_name(name)
attr_name = make_attr_name(name)
prog.Code('# parameter: {}'.format(name))
prog.Code('# parameter {}: {}'.format(name, var_name))
prog.Code('{} = ParamAttr(name={})' # , trainable=True
.format(attr_name, repr(var_name)))
prog.Code(
......@@ -356,13 +355,13 @@ class Writer(object):
if remove_batch:
shape = shape[1:]
prog.Code('# input: {}'.format(name))
prog.Code('# input {}: {}'.format(name, var_name))
prog.Code((
'{} = layers.data(name={}, shape={}, dtype={}, '
'append_batch_size={})' # , stop_gradient=True
).format(
var_name,
repr(name),
repr(var_name),
shape,
repr(value_info['dtype'].name),
remove_batch,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册