From acbb1cac795ae6112785ee77fbc1f03b8f528ce6 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Mon, 17 Jan 2022 17:08:05 +0800 Subject: [PATCH] Simplified code --- .../op_mapper/onnx2paddle/opset9/opset.py | 88 ++++++++++--------- 1 file changed, 47 insertions(+), 41 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index bcde0a1..9c02f34 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -45,7 +45,8 @@ def _const_weight_or_none(node, necessary=False): def _rename_or_remove_weight(weights, origin_name, target_name=None, - is_remove=True): + is_remove=True, + rename_mapper=None): ''' Rename parameters by Paddle's naming rule of parameters. @@ -56,9 +57,13 @@ def _rename_or_remove_weight(weights, {target_name:weights[origin_name]} to weights, and target_name must follow paddle's naming rule of parameters. Default: None. is_remove: if is_remove is True, remove origin key-value pair. Default: True. + rename_mapper: Solved the same data is used for multiple OPs, key is old_name, value is new_name. Returns: None ''' + if origin_name in rename_mapper: + origin_name = rename_mapper[origin_name] + is_remove = False if origin_name not in weights: raise KeyError('{} not a key in {}'.format(origin_name, weights.keys())) if is_remove: @@ -69,6 +74,7 @@ def _rename_or_remove_weight(weights, if target_name is not None: # rename weight weights[target_name] = data + rename_mapper[origin_name] = target_name def _is_static_shape(shape): @@ -1682,38 +1688,26 @@ class OpSet9(): c = val_x.out_shapes[0][1] # solved the same data is used as an argument to multiple OPs. - if val_scale.name in self.rename_mapper: - new_name = self.rename_mapper[val_scale.name] - _rename_or_remove_weight(self.weights, new_name, - op_name + '.weight', False) - else: - _rename_or_remove_weight(self.weights, val_scale.name, - op_name + '.weight') - self.rename_mapper[val_scale.name] = op_name + '.weight' - if val_b.name in self.rename_mapper: - new_name = self.rename_mapper[val_b.name] - _rename_or_remove_weight(self.weights, new_name, op_name + '.bias', - False) - else: - _rename_or_remove_weight(self.weights, val_b.name, - op_name + '.bias') - self.rename_mapper[val_b.name] = op_name + '.bias' - if val_var.name in self.rename_mapper: - new_name = self.rename_mapper[val_var.name] - _rename_or_remove_weight(self.weights, new_name, - op_name + '._variance', False) - else: - _rename_or_remove_weight(self.weights, val_var.name, - op_name + '._variance') - self.rename_mapper[val_var.name] = op_name + '._variance' - if val_mean.name in self.rename_mapper: - new_name = self.rename_mapper[val_mean.name] - _rename_or_remove_weight(self.weights, new_name, op_name + '._mean', - False) - else: - _rename_or_remove_weight(self.weights, val_mean.name, - op_name + '._mean') - self.rename_mapper[val_mean.name] = op_name + '._mean' + _rename_or_remove_weight( + self.weights, + val_scale.name, + op_name + '.weight', + rename_mapper=self.rename_mapper) + _rename_or_remove_weight( + self.weights, + val_b.name, + op_name + '.bias', + rename_mapper=self.rename_mapper) + _rename_or_remove_weight( + self.weights, + val_var.name, + op_name + '._variance', + rename_mapper=self.rename_mapper) + _rename_or_remove_weight( + self.weights, + val_mean.name, + op_name + '._mean', + rename_mapper=self.rename_mapper) # Attribute: spatial is used in BatchNormalization-1,6,7 spatial = bool(node.get_attr('spatial')) @@ -2255,14 +2249,22 @@ class OpSet9(): remove_weight = True if val_w.name in self.done_weight_list else False if remove_weight: self.done_weight_list.append(val_w.name) - _rename_or_remove_weight(self.weights, val_w.name, op_name + '.weight', - remove_weight) + _rename_or_remove_weight( + self.weights, + val_w.name, + op_name + '.weight', + remove_weight, + rename_mapper=self.rename_mapper) if has_bias: remove_bias = True if val_b.name in self.done_weight_list else False if remove_bias: - self.done_weight_list.append(val_b_name) - _rename_or_remove_weight(self.weights, val_b.name, - op_name + '.bias', remove_bias) + self.done_weight_list.append(val_b.name) + _rename_or_remove_weight( + self.weights, + val_b.name, + op_name + '.bias', + remove_bias, + rename_mapper=self.rename_mapper) else: layer_attrs["bias_attr"] = False if reduce(lambda x, y: x * y, @@ -2382,10 +2384,14 @@ class OpSet9(): _rename_or_remove_weight( self.weights, val_w.name, - op_name + '.weight', ) + op_name + '.weight', + rename_mapper=self.rename_mapper) if val_b is not None: - _rename_or_remove_weight(self.weights, val_b.name, - op_name + '.bias') + _rename_or_remove_weight( + self.weights, + val_b.name, + op_name + '.bias', + rename_mapper=self.rename_mapper) else: layer_attrs["bias_attr"] = False self.paddle_graph.add_layer( -- GitLab