diff --git a/paddleslim/common/convert_model.py b/paddleslim/common/convert_model.py
index 1b501269a6fbeb406badc48f07a834eafe0aaf47..00e0182017c4d1e60b2a6c28e8db1d109be28666 100644
--- a/paddleslim/common/convert_model.py
+++ b/paddleslim/common/convert_model.py
@@ -65,6 +65,9 @@ def load_onnx_model(model_path, disable_feedback=False):
         )
         sys.exit(1)
 
+    # support distributed convert model
+    model_idx = paddle.distributed.get_rank(
+    ) if paddle.distributed.get_world_size() > 1 else 0
     try:
         _logger.info("Now translating model from onnx to paddle.")
         model = ONNXDecoder(model_path)
@@ -73,14 +76,21 @@ def load_onnx_model(model_path, disable_feedback=False):
         graph_opt = GraphOptimizer(source_frame="onnx")
         graph_opt.optimize(mapper.paddle_graph)
         _logger.info("Model optimized.")
-        onnx2paddle_out_dir = os.path.join(inference_model_path,
-                                           'onnx2paddle')
+        onnx2paddle_out_dir = os.path.join(
+            inference_model_path, 'onnx2paddle_{}'.format(model_idx))
         mapper.paddle_graph.gen_model(onnx2paddle_out_dir)
         _logger.info("Successfully exported Paddle static graph model!")
         if not disable_feedback:
             ConverterCheck(
                 task="ONNX", time_info=time_info,
                 convert_state="Success").start()
+    except:
+        _logger.info(
+            "[ERROR] x2paddle threw an exception, you can ask for help at: https://github.com/PaddlePaddle/X2Paddle/issues"
+        )
+        sys.exit(1)
+
+    if paddle.distributed.get_rank() == 0:
         shutil.move(
             os.path.join(onnx2paddle_out_dir, 'inference_model',
                          'model.pdmodel'),
@@ -89,14 +99,17 @@ def load_onnx_model(model_path, disable_feedback=False):
             os.path.join(onnx2paddle_out_dir, 'inference_model',
                          'model.pdiparams'),
             os.path.join(inference_model_path, 'model.pdiparams'))
-    except:
-        _logger.info(
-            "[ERROR] x2paddle threw an exception, you can ask for help at: https://github.com/PaddlePaddle/X2Paddle/issues"
-        )
-        sys.exit(1)
+        load_model_path = inference_model_path
+    else:
+        load_model_path = os.path.join(onnx2paddle_out_dir,
+                                       'inference_model')
 
     paddle.enable_static()
     val_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
-        os.path.join(inference_model_path, 'model'), exe)
-    _logger.info('Loaded model from: {}'.format(inference_model_path))
+        os.path.join(load_model_path, 'model'), exe)
+    _logger.info('Loaded model from: {}'.format(load_model_path))
+    # Clean up the file storage directory
+    shutil.rmtree(
+        os.path.join(inference_model_path, 'onnx2paddle_{}'.format(
+            model_idx)))
     return val_program, feed_target_names, fetch_targets
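
For context, a minimal usage sketch of the rank-aware behaviour this patch introduces. The model file name and the launch command below are assumptions, not part of the diff; only `load_onnx_model`, the per-rank `onnx2paddle_<rank>` directory naming, the rank-0 move, and the final cleanup come from the change above.

```python
# Hypothetical illustration, not part of the patch. Launch with e.g.:
#   python -m paddle.distributed.launch --gpus "0,1" convert_demo.py
# so that every process sees its own rank and the shared world size.
import paddle
from paddleslim.common.convert_model import load_onnx_model

# 'model.onnx' is a placeholder path, not taken from the diff.
val_program, feed_names, fetch_targets = load_onnx_model('model.onnx')

# With this patch, rank r writes its intermediate x2paddle export to
# <inference_model_path>/onnx2paddle_r, so concurrent ranks no longer race
# on a single shared 'onnx2paddle' directory. Rank 0 moves the final
# model.pdmodel / model.pdiparams up to inference_model_path and loads from
# there; the other ranks load straight from their own
# onnx2paddle_r/inference_model copy, and each rank removes its
# onnx2paddle_r directory afterwards.
print(paddle.distributed.get_rank(), feed_names)
```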