提交 816ac6e2 编写于 作者: M Macrobull

re-generate desc proto with python code when debug on

上级 a538420a
...@@ -33,7 +33,7 @@ def main(**kwargs): ...@@ -33,7 +33,7 @@ def main(**kwargs):
from .conversion import convert from .conversion import convert
logger = logging.getLogger('onnx2fluid') logger = logging.getLogger('onnx2fluid')
# debug = kwargs.get('debug', False) debug = kwargs.get('debug', False)
# prepare arguments # prepare arguments
filename = kwargs.pop('model')[0] filename = kwargs.pop('model')[0]
...@@ -49,8 +49,7 @@ def main(**kwargs): ...@@ -49,8 +49,7 @@ def main(**kwargs):
onnx_skip_version_conversion = kwargs.pop('skip_version_conversion', False) onnx_skip_version_conversion = kwargs.pop('skip_version_conversion', False)
# convert # convert
convert( convert(filename,
filename,
save_dir, save_dir,
model_basename=model_basename, model_basename=model_basename,
model_func_name=model_func_name, model_func_name=model_func_name,
...@@ -66,15 +65,16 @@ def main(**kwargs): ...@@ -66,15 +65,16 @@ def main(**kwargs):
from .validation import validate from .validation import validate
logger.info('starting validation on desc ...') logger.info('starting validation on desc ...')
passed &= validate( passed &= validate(shutil.os.path.join(save_dir, '__model__'),
shutil.os.path.join(save_dir, '__model__'), golden_data_filename, golden_data_filename, **kwargs)
**kwargs)
logger.info('starting validation on code ...') logger.info('starting validation on code ...')
passed &= validate( passed &= validate(
shutil.os.path.join(save_dir, model_basename), shutil.os.path.join(save_dir, model_basename),
golden_data_filename, golden_data_filename,
model_func_name=model_func_name, model_func_name=model_func_name,
save_inference_model=
debug, # re-generate desc proto with python code when debug on
**kwargs) **kwargs)
if not passed: if not passed:
...@@ -111,16 +111,14 @@ if __name__ == '__main__': ...@@ -111,16 +111,14 @@ if __name__ == '__main__':
from onnx2fluid.cmdline import main from onnx2fluid.cmdline import main
main( main(model=['../examples/t1.onnx'],
model=['../examples/t1.onnx'],
output_dir='/tmp/export/', output_dir='/tmp/export/',
embed_params=False, embed_params=False,
pedantic=False, pedantic=False,
test_data='../examples/t1.npz', test_data='../examples/t1.npz',
debug=True) debug=True)
main( main(model=['../examples/inception_v2/model.onnx'],
model=['../examples/inception_v2/model.onnx'],
output_dir='/tmp/export/', output_dir='/tmp/export/',
embed_params=True, embed_params=True,
pedantic=False, pedantic=False,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册