未验证 提交 28b179f8 编写于 作者: G Guanghua Yu 提交者: GitHub

fix distribute load onnx model (#1338)

上级 4434a362
...@@ -65,6 +65,9 @@ def load_onnx_model(model_path, disable_feedback=False): ...@@ -65,6 +65,9 @@ def load_onnx_model(model_path, disable_feedback=False):
) )
sys.exit(1) sys.exit(1)
# support distributed convert model
model_idx = paddle.distributed.get_rank(
) if paddle.distributed.get_world_size() > 1 else 0
try: try:
_logger.info("Now translating model from onnx to paddle.") _logger.info("Now translating model from onnx to paddle.")
model = ONNXDecoder(model_path) model = ONNXDecoder(model_path)
...@@ -73,14 +76,21 @@ def load_onnx_model(model_path, disable_feedback=False): ...@@ -73,14 +76,21 @@ def load_onnx_model(model_path, disable_feedback=False):
graph_opt = GraphOptimizer(source_frame="onnx") graph_opt = GraphOptimizer(source_frame="onnx")
graph_opt.optimize(mapper.paddle_graph) graph_opt.optimize(mapper.paddle_graph)
_logger.info("Model optimized.") _logger.info("Model optimized.")
onnx2paddle_out_dir = os.path.join(inference_model_path, onnx2paddle_out_dir = os.path.join(
'onnx2paddle') inference_model_path, 'onnx2paddle_{}'.format(model_idx))
mapper.paddle_graph.gen_model(onnx2paddle_out_dir) mapper.paddle_graph.gen_model(onnx2paddle_out_dir)
_logger.info("Successfully exported Paddle static graph model!") _logger.info("Successfully exported Paddle static graph model!")
if not disable_feedback: if not disable_feedback:
ConverterCheck( ConverterCheck(
task="ONNX", time_info=time_info, task="ONNX", time_info=time_info,
convert_state="Success").start() convert_state="Success").start()
except:
_logger.info(
"[ERROR] x2paddle threw an exception, you can ask for help at: https://github.com/PaddlePaddle/X2Paddle/issues"
)
sys.exit(1)
if paddle.distributed.get_rank() == 0:
shutil.move( shutil.move(
os.path.join(onnx2paddle_out_dir, 'inference_model', os.path.join(onnx2paddle_out_dir, 'inference_model',
'model.pdmodel'), 'model.pdmodel'),
...@@ -89,14 +99,17 @@ def load_onnx_model(model_path, disable_feedback=False): ...@@ -89,14 +99,17 @@ def load_onnx_model(model_path, disable_feedback=False):
os.path.join(onnx2paddle_out_dir, 'inference_model', os.path.join(onnx2paddle_out_dir, 'inference_model',
'model.pdiparams'), 'model.pdiparams'),
os.path.join(inference_model_path, 'model.pdiparams')) os.path.join(inference_model_path, 'model.pdiparams'))
except: load_model_path = inference_model_path
_logger.info( else:
"[ERROR] x2paddle threw an exception, you can ask for help at: https://github.com/PaddlePaddle/X2Paddle/issues" load_model_path = os.path.join(onnx2paddle_out_dir,
) 'inference_model')
sys.exit(1)
paddle.enable_static() paddle.enable_static()
val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model( val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
os.path.join(inference_model_path, 'model'), exe) os.path.join(load_model_path, 'model'), exe)
_logger.info('Loaded model from: {}'.format(inference_model_path)) _logger.info('Loaded model from: {}'.format(load_model_path))
# Clean up the file storage directory
shutil.rmtree(
os.path.join(inference_model_path, 'onnx2paddle_{}'.format(
model_idx)))
return val_program, feed_target_names, fetch_targets return val_program, feed_target_names, fetch_targets
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册