diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 57b9c5464b462c945ad9efec9f698fd960d892a0..717bd5fc43138bac1d826b458281f6e97017112c 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -202,7 +202,7 @@ def paddle2onnx(model_path, save_dir, opset_version=10): model.program, save_dir, scope=fluid.global_scope(), - opset_number=opset_number) + opset_version=opset_version) def main(): @@ -269,7 +269,7 @@ def main(): elif args.framework == "paddle2onnx": assert args.model is not None, "--model should be defined while translating paddle model to onnx" - paddle2onnx(args.model, args.save_dir, args.onnx_opset) + paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset) else: raise Exception( diff --git a/x2paddle/op_mapper/paddle2onnx/paddle_op_mapper.py b/x2paddle/op_mapper/paddle2onnx/paddle_op_mapper.py index eb2765afacb18e7aa0b5fa4cfd1649d3f4fd1288..f167dfdd73b05aae2036c2ab4001c7c6a838d267 100644 --- a/x2paddle/op_mapper/paddle2onnx/paddle_op_mapper.py +++ b/x2paddle/op_mapper/paddle2onnx/paddle_op_mapper.py @@ -33,8 +33,8 @@ class PaddleOpMapper(object): self.name_counter = dict() self.op_set = None - def convert(self, program, save_dir, scope=None, opset_number=10): - self.op_set = self.create_opset(opset_number) + def convert(self, program, save_dir, scope=None, opset_version=10): + self.op_set = self.create_opset(opset_version) weight_nodes = self.op_set.convert_weights(program, scope=scope) op_nodes = list() input_nodes = list() @@ -77,7 +77,7 @@ class PaddleOpMapper(object): initializer=[], inputs=input_nodes, outputs=output_nodes) - opset_imports = [helper.make_opsetid("", opset_number)] + opset_imports = [helper.make_opsetid("", opset_version)] model = helper.make_model( graph, producer_name='X2Paddle', opset_imports=opset_imports) onnx.checker.check_model(model) @@ -89,20 +89,20 @@ class PaddleOpMapper(object): print("\nTranslated model saved in {}".format( os.path.join(save_dir, 'x2paddle_model.onnx'))) - def create_opset(self, opset_number): + def create_opset(self, opset_version=10): run_opset = self.default_opset opset = '' - if opset_number in self.support_opsets: - run_opset = opset_number + if opset_version in self.support_opsets: + run_opset = opset_version else: - for support_opset_number in self.support_opsets: - if support_opset_number < opset_number: - run_opset = support_opset_number + for support_opset_version in self.support_opsets: + if support_opset_version < opset_version: + run_opset = support_opset_version else: break print( 'Now, onnx2paddle support convert onnx model opset_verison {},' 'opset_verison of your onnx model is {}, automatically treated as op_set: {}.' - .format(self.support_opsets, opset_number, run_opset)) + .format(self.support_opsets, opset_version, run_opset)) opset = 'OpSet' + str(run_opset) return eval(opset)()