From 1c79e06757f7b810e0c9f30c1db63bba960489f7 Mon Sep 17 00:00:00 2001 From: Channingss Date: Fri, 3 Jul 2020 09:38:01 +0000 Subject: [PATCH] set asymmetric as defalut mode of Resize --- x2paddle/convert.py | 12 +++++++++--- x2paddle/op_mapper/paddle_op_mapper.py | 7 ++++--- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 70c6db8..458f910 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -75,6 +75,12 @@ def arg_parser(): action="store_true", default=False, help="define input shape for tf model") + parser.add_argument( + "--onnx_opset", + "-oo", + type=int, + default=10, + help="when paddle2onnx set onnx opset version to export") parser.add_argument( "--params_merge", "-pm", @@ -186,12 +192,12 @@ def onnx2paddle(model_path, save_dir, params_merge=False): print("Paddle model and code generated.") -def paddle2onnx(model_path, save_dir): +def paddle2onnx(model_path, save_dir, opset): from x2paddle.decoder.paddle_decoder import PaddleDecoder from x2paddle.op_mapper.paddle_op_mapper import PaddleOpMapper model = PaddleDecoder(model_path, '__model__', '__params__') mapper = PaddleOpMapper() - mapper.convert(model.program, save_dir) + mapper.convert(model.program, save_dir, opset) def main(): @@ -258,7 +264,7 @@ def main(): elif args.framework == "paddle2onnx": assert args.model is not None, "--model should be defined while translating paddle model to onnx" - paddle2onnx(args.model, args.save_dir) + paddle2onnx(args.model, args.save_dir, args.onnx_opset) else: raise Exception( diff --git a/x2paddle/op_mapper/paddle_op_mapper.py b/x2paddle/op_mapper/paddle_op_mapper.py index 0ba7ad6..329629b 100644 --- a/x2paddle/op_mapper/paddle_op_mapper.py +++ b/x2paddle/op_mapper/paddle_op_mapper.py @@ -37,12 +37,11 @@ class PaddleOpMapper(object): self.name_counter = dict() - def convert(self, program, save_dir): + def convert(self, program, save_dir, opset=10): weight_nodes = self.convert_weights(program) op_nodes = list() input_nodes = list() output_nodes = list() - unsupported_ops = set() print("Translating PaddlePaddle to ONNX...\n") @@ -81,7 +80,9 @@ class PaddleOpMapper(object): initializer=[], inputs=input_nodes, outputs=output_nodes) - model = helper.make_model(graph, producer_name='X2Paddle') + opset_imports = [helper.make_opsetid("", opset)] + model = helper.make_model( + graph, producer_name='X2Paddle', opset_imports=opset_imports) onnx.checker.check_model(model) if not os.path.isdir(save_dir): -- GitLab