From 99e2ea67607c160161955cc5b0891fc4b6a9754b Mon Sep 17 00:00:00 2001 From: WJJ1995 Date: Mon, 17 Jan 2022 19:24:16 +0800 Subject: [PATCH] Solved the same data is used for multiple OPs. (#728) * add scatter mapper * solve same param is used for multiple OPs * Add unique_name support * Simplified code * Add PR link * fixed bug for CI --- .../op_mapper/onnx2paddle/opset9/opset.py | 68 ++++++++++++++----- .../pytorch/torch2paddle/nn_init.py | 2 +- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 6fb18a0..3dec5b5 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -45,7 +45,8 @@ def _const_weight_or_none(node, necessary=False): def _rename_or_remove_weight(weights, origin_name, target_name=None, - is_remove=True): + is_remove=True, + rename_mapper=None): ''' Rename parameters by Paddle's naming rule of parameters. @@ -56,13 +57,16 @@ def _rename_or_remove_weight(weights, {target_name:weights[origin_name]} to weights, and target_name must follow paddle's naming rule of parameters. Default: None. is_remove: if is_remove is True, remove origin key-value pair. Default: True. + rename_mapper: Solved the same data is used for multiple OPs, key is old_name, value is new_name. Returns: None ''' + if rename_mapper is not None and origin_name in rename_mapper: + origin_name = rename_mapper[origin_name] + is_remove = False if origin_name not in weights: raise KeyError('{} not a key in {}'.format(origin_name, weights.keys())) if is_remove: - # TODO There may be problems when the same data is used as an argument to multiple OPs. # remove weight data = weights.pop(origin_name) else: @@ -70,6 +74,7 @@ def _rename_or_remove_weight(weights, if target_name is not None: # rename weight weights[target_name] = data + rename_mapper[origin_name] = target_name def _is_static_shape(shape): @@ -182,6 +187,9 @@ class OpSet9(): self.weights = dict() self.nn_name2id = dict() self.done_weight_list = list() + # solve for same data is used as an argument to multiple OPs. + # PR link(wangjunjie06): https://github.com/PaddlePaddle/X2Paddle/pull/728 + self.rename_mapper = dict() @print_mapping_info def directly_map(self, node, *args, **kwargs): @@ -1680,13 +1688,27 @@ class OpSet9(): epsilon = node.get_attr('epsilon', 1e-5) c = val_x.out_shapes[0][1] - _rename_or_remove_weight(self.weights, val_scale.name, - op_name + '.weight') - _rename_or_remove_weight(self.weights, val_b.name, op_name + '.bias') - _rename_or_remove_weight(self.weights, val_var.name, - op_name + '._variance') - _rename_or_remove_weight(self.weights, val_mean.name, - op_name + '._mean') + # solved the same data is used as an argument to multiple OPs. + _rename_or_remove_weight( + self.weights, + val_scale.name, + op_name + '.weight', + rename_mapper=self.rename_mapper) + _rename_or_remove_weight( + self.weights, + val_b.name, + op_name + '.bias', + rename_mapper=self.rename_mapper) + _rename_or_remove_weight( + self.weights, + val_var.name, + op_name + '._variance', + rename_mapper=self.rename_mapper) + _rename_or_remove_weight( + self.weights, + val_mean.name, + op_name + '._mean', + rename_mapper=self.rename_mapper) # Attribute: spatial is used in BatchNormalization-1,6,7 spatial = bool(node.get_attr('spatial')) @@ -2228,14 +2250,22 @@ class OpSet9(): remove_weight = True if val_w.name in self.done_weight_list else False if remove_weight: self.done_weight_list.append(val_w.name) - _rename_or_remove_weight(self.weights, val_w.name, op_name + '.weight', - remove_weight) + _rename_or_remove_weight( + self.weights, + val_w.name, + op_name + '.weight', + remove_weight, + rename_mapper=self.rename_mapper) if has_bias: remove_bias = True if val_b.name in self.done_weight_list else False if remove_bias: - self.done_weight_list.append(val_b_name) - _rename_or_remove_weight(self.weights, val_b.name, - op_name + '.bias', remove_bias) + self.done_weight_list.append(val_b.name) + _rename_or_remove_weight( + self.weights, + val_b.name, + op_name + '.bias', + remove_bias, + rename_mapper=self.rename_mapper) else: layer_attrs["bias_attr"] = False if reduce(lambda x, y: x * y, @@ -2355,10 +2385,14 @@ class OpSet9(): _rename_or_remove_weight( self.weights, val_w.name, - op_name + '.weight', ) + op_name + '.weight', + rename_mapper=self.rename_mapper) if val_b is not None: - _rename_or_remove_weight(self.weights, val_b.name, - op_name + '.bias') + _rename_or_remove_weight( + self.weights, + val_b.name, + op_name + '.bias', + rename_mapper=self.rename_mapper) else: layer_attrs["bias_attr"] = False self.paddle_graph.add_layer( diff --git a/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py b/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py index 289ce19..842eeed 100644 --- a/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py +++ b/x2paddle/project_convertor/pytorch/torch2paddle/nn_init.py @@ -15,7 +15,7 @@ import math from functools import reduce import paddle -from paddle.fluid import framework +from paddle.fluid import framework, unique_name from paddle.fluid.core import VarDesc from paddle.fluid.initializer import XavierInitializer, MSRAInitializer from paddle.fluid.data_feeder import check_variable_and_dtype -- GitLab