diff --git a/x2paddle/convert.py b/x2paddle/convert.py index fe0b09e9fdedc7648c97fba316b38b0a6fe69ea1..717bd5fc43138bac1d826b458281f6e97017112c 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -195,9 +195,14 @@ def onnx2paddle(model_path, save_dir, params_merge=False): def paddle2onnx(model_path, save_dir, opset_version=10): from x2paddle.decoder.paddle_decoder import PaddleDecoder from x2paddle.op_mapper.paddle2onnx.paddle_op_mapper import PaddleOpMapper + import paddle.fluid as fluid model = PaddleDecoder(model_path, '__model__', '__params__') mapper = PaddleOpMapper() - mapper.convert(model.program, save_dir, opset_number=opset_version) + mapper.convert( + model.program, + save_dir, + scope=fluid.global_scope(), + opset_version=opset_version) def main(): @@ -264,7 +269,7 @@ def main(): elif args.framework == "paddle2onnx": assert args.model is not None, "--model should be defined while translating paddle model to onnx" - paddle2onnx(args.model, args.save_dir, args.onnx_opset) + paddle2onnx(args.model, args.save_dir, opset_version=args.onnx_opset) else: raise Exception( diff --git a/x2paddle/op_mapper/paddle2onnx/opset9/opset.py b/x2paddle/op_mapper/paddle2onnx/opset9/opset.py index 850e3d0430d8bcb05c209560765b29d61a022edf..0ab8e69ccc1577cd03579a0e13f65a35899d7f78 100644 --- a/x2paddle/op_mapper/paddle2onnx/opset9/opset.py +++ b/x2paddle/op_mapper/paddle2onnx/opset9/opset.py @@ -59,7 +59,7 @@ class OpSet9(object): 'Constant', inputs=[], outputs=[name], value=tensor) return node - def convert_weights(self, program): + def convert_weights(self, program, scope=None): var_names = program.global_block().vars nodes = list() for name in var_names: @@ -68,7 +68,7 @@ class OpSet9(object): continue if not var.persistable: continue - weight = np.array(fluid.global_scope().find_var(name).get_tensor()) + weight = np.array(scope.find_var(name).get_tensor()) tensor = helper.make_tensor( name=name, dims=var.shape, diff --git a/x2paddle/op_mapper/paddle2onnx/paddle_op_mapper.py b/x2paddle/op_mapper/paddle2onnx/paddle_op_mapper.py index 1ce2ec5e7093d6d5302e673c4400fe0a87d66583..f167dfdd73b05aae2036c2ab4001c7c6a838d267 100644 --- a/x2paddle/op_mapper/paddle2onnx/paddle_op_mapper.py +++ b/x2paddle/op_mapper/paddle2onnx/paddle_op_mapper.py @@ -33,9 +33,9 @@ class PaddleOpMapper(object): self.name_counter = dict() self.op_set = None - def convert(self, program, save_dir, opset_number=10): - self.op_set = self.create_opset(opset_number) - weight_nodes = self.op_set.convert_weights(program) + def convert(self, program, save_dir, scope=None, opset_version=10): + self.op_set = self.create_opset(opset_version) + weight_nodes = self.op_set.convert_weights(program, scope=scope) op_nodes = list() input_nodes = list() output_nodes = list() @@ -77,7 +77,7 @@ class PaddleOpMapper(object): initializer=[], inputs=input_nodes, outputs=output_nodes) - opset_imports = [helper.make_opsetid("", opset_number)] + opset_imports = [helper.make_opsetid("", opset_version)] model = helper.make_model( graph, producer_name='X2Paddle', opset_imports=opset_imports) onnx.checker.check_model(model) @@ -89,20 +89,20 @@ class PaddleOpMapper(object): print("\nTranslated model saved in {}".format( os.path.join(save_dir, 'x2paddle_model.onnx'))) - def create_opset(self, opset_number): + def create_opset(self, opset_version=10): run_opset = self.default_opset opset = '' - if opset_number in self.support_opsets: - run_opset = opset_number + if opset_version in self.support_opsets: + run_opset = opset_version else: - for support_opset_number in self.support_opsets: - if support_opset_number < opset_number: - run_opset = support_opset_number + for support_opset_version in self.support_opsets: + if support_opset_version < opset_version: + run_opset = support_opset_version else: break print( 'Now, onnx2paddle support convert onnx model opset_verison {},' 'opset_verison of your onnx model is {}, automatically treated as op_set: {}.' - .format(self.support_opsets, opset_number, run_opset)) + .format(self.support_opsets, opset_version, run_opset)) opset = 'OpSet' + str(run_opset) return eval(opset)()