diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 6fb18a03ecd1f224d6ac86f2036c3900f20365c2..bcde0a103cbd431b29ad42dc60865ac3a4072a1b 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -62,7 +62,6 @@ def _rename_or_remove_weight(weights, if origin_name not in weights: raise KeyError('{} not a key in {}'.format(origin_name, weights.keys())) if is_remove: - # TODO There may be problems when the same data is used as an argument to multiple OPs. # remove weight data = weights.pop(origin_name) else: @@ -182,6 +181,8 @@ class OpSet9(): self.weights = dict() self.nn_name2id = dict() self.done_weight_list = list() + # solve for same data is used as an argument to multiple OPs. + self.rename_mapper = dict() @print_mapping_info def directly_map(self, node, *args, **kwargs): @@ -1680,13 +1681,39 @@ class OpSet9(): epsilon = node.get_attr('epsilon', 1e-5) c = val_x.out_shapes[0][1] - _rename_or_remove_weight(self.weights, val_scale.name, - op_name + '.weight') - _rename_or_remove_weight(self.weights, val_b.name, op_name + '.bias') - _rename_or_remove_weight(self.weights, val_var.name, - op_name + '._variance') - _rename_or_remove_weight(self.weights, val_mean.name, - op_name + '._mean') + # solved the same data is used as an argument to multiple OPs. + if val_scale.name in self.rename_mapper: + new_name = self.rename_mapper[val_scale.name] + _rename_or_remove_weight(self.weights, new_name, + op_name + '.weight', False) + else: + _rename_or_remove_weight(self.weights, val_scale.name, + op_name + '.weight') + self.rename_mapper[val_scale.name] = op_name + '.weight' + if val_b.name in self.rename_mapper: + new_name = self.rename_mapper[val_b.name] + _rename_or_remove_weight(self.weights, new_name, op_name + '.bias', + False) + else: + _rename_or_remove_weight(self.weights, val_b.name, + op_name + '.bias') + self.rename_mapper[val_b.name] = op_name + '.bias' + if val_var.name in self.rename_mapper: + new_name = self.rename_mapper[val_var.name] + _rename_or_remove_weight(self.weights, new_name, + op_name + '._variance', False) + else: + _rename_or_remove_weight(self.weights, val_var.name, + op_name + '._variance') + self.rename_mapper[val_var.name] = op_name + '._variance' + if val_mean.name in self.rename_mapper: + new_name = self.rename_mapper[val_mean.name] + _rename_or_remove_weight(self.weights, new_name, op_name + '._mean', + False) + else: + _rename_or_remove_weight(self.weights, val_mean.name, + op_name + '._mean') + self.rename_mapper[val_mean.name] = op_name + '._mean' # Attribute: spatial is used in BatchNormalization-1,6,7 spatial = bool(node.get_attr('spatial'))