diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 70c6db8063e338dd68bbd17e5a22d68231dbb7df..458f9108ac800ad1918a454356215c9e0f63c818 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -75,6 +75,12 @@ def arg_parser(): action="store_true", default=False, help="define input shape for tf model") + parser.add_argument( + "--onnx_opset", + "-oo", + type=int, + default=10, + help="when paddle2onnx set onnx opset version to export") parser.add_argument( "--params_merge", "-pm", @@ -186,12 +192,12 @@ def onnx2paddle(model_path, save_dir, params_merge=False): print("Paddle model and code generated.") -def paddle2onnx(model_path, save_dir): +def paddle2onnx(model_path, save_dir, opset): from x2paddle.decoder.paddle_decoder import PaddleDecoder from x2paddle.op_mapper.paddle_op_mapper import PaddleOpMapper model = PaddleDecoder(model_path, '__model__', '__params__') mapper = PaddleOpMapper() - mapper.convert(model.program, save_dir) + mapper.convert(model.program, save_dir, opset) def main(): @@ -258,7 +264,7 @@ def main(): elif args.framework == "paddle2onnx": assert args.model is not None, "--model should be defined while translating paddle model to onnx" - paddle2onnx(args.model, args.save_dir) + paddle2onnx(args.model, args.save_dir, args.onnx_opset) else: raise Exception( diff --git a/x2paddle/op_mapper/paddle_op_mapper.py b/x2paddle/op_mapper/paddle_op_mapper.py index 0ba7ad682528b4062dea381964835271f0177432..329629bb4ac3e3daf62c08faabe4ebaa51cf5c88 100644 --- a/x2paddle/op_mapper/paddle_op_mapper.py +++ b/x2paddle/op_mapper/paddle_op_mapper.py @@ -37,12 +37,11 @@ class PaddleOpMapper(object): self.name_counter = dict() - def convert(self, program, save_dir): + def convert(self, program, save_dir, opset=10): weight_nodes = self.convert_weights(program) op_nodes = list() input_nodes = list() output_nodes = list() - unsupported_ops = set() print("Translating PaddlePaddle to ONNX...\n") @@ -81,7 +80,9 @@ class PaddleOpMapper(object): initializer=[], inputs=input_nodes, outputs=output_nodes) - model = helper.make_model(graph, producer_name='X2Paddle') + opset_imports = [helper.make_opsetid("", opset)] + model = helper.make_model( + graph, producer_name='X2Paddle', opset_imports=opset_imports) onnx.checker.check_model(model) if not os.path.isdir(save_dir):