未验证 提交 28b179f8 编写于 作者: G Guanghua Yu 提交者: GitHub

fix distribute load onnx model (#1338)

上级 4434a362
......@@ -65,6 +65,9 @@ def load_onnx_model(model_path, disable_feedback=False):
)
sys.exit(1)
# support distributed convert model
model_idx = paddle.distributed.get_rank(
) if paddle.distributed.get_world_size() > 1 else 0
try:
_logger.info("Now translating model from onnx to paddle.")
model = ONNXDecoder(model_path)
......@@ -73,14 +76,21 @@ def load_onnx_model(model_path, disable_feedback=False):
graph_opt = GraphOptimizer(source_frame="onnx")
graph_opt.optimize(mapper.paddle_graph)
_logger.info("Model optimized.")
onnx2paddle_out_dir = os.path.join(inference_model_path,
'onnx2paddle')
onnx2paddle_out_dir = os.path.join(
inference_model_path, 'onnx2paddle_{}'.format(model_idx))
mapper.paddle_graph.gen_model(onnx2paddle_out_dir)
_logger.info("Successfully exported Paddle static graph model!")
if not disable_feedback:
ConverterCheck(
task="ONNX", time_info=time_info,
convert_state="Success").start()
except:
_logger.info(
"[ERROR] x2paddle threw an exception, you can ask for help at: https://github.com/PaddlePaddle/X2Paddle/issues"
)
sys.exit(1)
if paddle.distributed.get_rank() == 0:
shutil.move(
os.path.join(onnx2paddle_out_dir, 'inference_model',
'model.pdmodel'),
......@@ -89,14 +99,17 @@ def load_onnx_model(model_path, disable_feedback=False):
os.path.join(onnx2paddle_out_dir, 'inference_model',
'model.pdiparams'),
os.path.join(inference_model_path, 'model.pdiparams'))
except:
_logger.info(
"[ERROR] x2paddle threw an exception, you can ask for help at: https://github.com/PaddlePaddle/X2Paddle/issues"
)
sys.exit(1)
load_model_path = inference_model_path
else:
load_model_path = os.path.join(onnx2paddle_out_dir,
'inference_model')
paddle.enable_static()
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
os.path.join(inference_model_path, 'model'), exe)
_logger.info('Loaded model from: {}'.format(inference_model_path))
os.path.join(load_model_path, 'model'), exe)
_logger.info('Loaded model from: {}'.format(load_model_path))
# Clean up the file storage directory
shutil.rmtree(
os.path.join(inference_model_path, 'onnx2paddle_{}'.format(
model_idx)))
return val_program, feed_target_names, fetch_targets
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册