From ed6400e14f3b5c209ae80410237be9b345ca933e Mon Sep 17 00:00:00 2001 From: Channingss Date: Thu, 14 Jan 2021 10:58:54 +0000 Subject: [PATCH] add _rename_or_remove_weight --- .../dygraph/onnx2paddle/opset9/opset.py | 31 ++++++++++--------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py index 9a9fe0b..785f075 100644 --- a/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py @@ -42,14 +42,17 @@ def _const_weight_or_none(node, necessary=False): return None -def _rename_or_remove_weight(weights, origin_name, target_name=None): +def _rename_or_remove_weight(weights, origin_name, target_name=None, is_remove=True): if origin_name not in weights: raise KeyError('{} not a key in {}'.format(origin_name, weights)) - if target_name is None: + if is_remove: # remove weight - weights.pop(origin_name) - # rename weight - weights[target_name] = weights.pop(origin_name) + data = weights.pop(origin_name) + else: + data = weights[origin_name] + if target_name is not None: + # rename weight + weights[target_name] = data def _is_static_shape(shape): negtive_dims = 0 @@ -1700,19 +1703,17 @@ class OpSet9(): "dilation": dilations, "groups": num_groups, } - val_w_name = val_w.name - while val_w_name in self.done_weight_list: - val_w_name += "__repeat" - self.done_weight_list.append(val_w_name) + remove_weight = True if val_w.name in self.done_weight_list else False + if remove_weight: + self.done_weight_list.append(val_w.name) #self.weights[op_name + '.weight'] = self.weights[val_w.name] - _rename_or_remove_weight(self.weights, val_w.name, op_name+'.weight') + _rename_or_remove_weight(self.weights, val_w.name, op_name+'.weight', remove_weight) if has_bias: - val_b_name = val_b.name - while val_b_name in self.done_weight_list: - val_b_name += "__repeat" - self.done_weight_list.append(val_b_name) + remove_bias = True if val_b.name in self.done_weight_list else False + if remove_bias: + self.done_weight_list.append(val_b_name) #self.weights[op_name + '.bias'] = self.weights[val_b.name] - _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias') + _rename_or_remove_weight(self.weights, val_b.name, op_name+'.bias', remove_bias) else: layer_attrs["bias_attr"] = False input_shape = val_x.out_shapes[0] -- GitLab