提交 816ac6e2 编写于 作者: M Macrobull

re-generate desc proto with python code when debug on

上级 a538420a
...@@ -33,7 +33,7 @@ def main(**kwargs): ...@@ -33,7 +33,7 @@ def main(**kwargs):
from .conversion import convert from .conversion import convert
logger = logging.getLogger('onnx2fluid') logger = logging.getLogger('onnx2fluid')
# debug = kwargs.get('debug', False) debug = kwargs.get('debug', False)
# prepare arguments # prepare arguments
filename = kwargs.pop('model')[0] filename = kwargs.pop('model')[0]
...@@ -49,15 +49,14 @@ def main(**kwargs): ...@@ -49,15 +49,14 @@ def main(**kwargs):
onnx_skip_version_conversion = kwargs.pop('skip_version_conversion', False) onnx_skip_version_conversion = kwargs.pop('skip_version_conversion', False)
# convert # convert
convert( convert(filename,
filename, save_dir,
save_dir, model_basename=model_basename,
model_basename=model_basename, model_func_name=model_func_name,
model_func_name=model_func_name, onnx_opset_version=onnx_opset_version,
onnx_opset_version=onnx_opset_version, onnx_opset_pedantic=onnx_opset_pedantic,
onnx_opset_pedantic=onnx_opset_pedantic, onnx_skip_version_conversion=onnx_skip_version_conversion,
onnx_skip_version_conversion=onnx_skip_version_conversion, **kwargs)
**kwargs)
# validate # validate
passed = True passed = True
...@@ -66,15 +65,16 @@ def main(**kwargs): ...@@ -66,15 +65,16 @@ def main(**kwargs):
from .validation import validate from .validation import validate
logger.info('starting validation on desc ...') logger.info('starting validation on desc ...')
passed &= validate( passed &= validate(shutil.os.path.join(save_dir, '__model__'),
shutil.os.path.join(save_dir, '__model__'), golden_data_filename, golden_data_filename, **kwargs)
**kwargs)
logger.info('starting validation on code ...') logger.info('starting validation on code ...')
passed &= validate( passed &= validate(
shutil.os.path.join(save_dir, model_basename), shutil.os.path.join(save_dir, model_basename),
golden_data_filename, golden_data_filename,
model_func_name=model_func_name, model_func_name=model_func_name,
save_inference_model=
debug, # re-generate desc proto with python code when debug on
**kwargs) **kwargs)
if not passed: if not passed:
...@@ -111,19 +111,17 @@ if __name__ == '__main__': ...@@ -111,19 +111,17 @@ if __name__ == '__main__':
from onnx2fluid.cmdline import main from onnx2fluid.cmdline import main
main( main(model=['../examples/t1.onnx'],
model=['../examples/t1.onnx'], output_dir='/tmp/export/',
output_dir='/tmp/export/', embed_params=False,
embed_params=False, pedantic=False,
pedantic=False, test_data='../examples/t1.npz',
test_data='../examples/t1.npz', debug=True)
debug=True)
main(model=['../examples/inception_v2/model.onnx'],
main( output_dir='/tmp/export/',
model=['../examples/inception_v2/model.onnx'], embed_params=True,
output_dir='/tmp/export/', pedantic=False,
embed_params=True, skip_version_conversion=False,
pedantic=False, test_data='../examples/inception_v2/test_data_set_2.npz',
skip_version_conversion=False, debug=True)
test_data='../examples/inception_v2/test_data_set_2.npz',
debug=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册