提交 a2346f04 编写于 作者: W wjj19950828

solve same param is used for multiple OPs

上级 e33dba30
......@@ -62,7 +62,6 @@ def _rename_or_remove_weight(weights,
if origin_name not in weights:
raise KeyError('{} not a key in {}'.format(origin_name, weights.keys()))
if is_remove:
# TODO There may be problems when the same data is used as an argument to multiple OPs.
# remove weight
data = weights.pop(origin_name)
else:
......@@ -182,6 +181,8 @@ class OpSet9():
self.weights = dict()
self.nn_name2id = dict()
self.done_weight_list = list()
# solve for same data is used as an argument to multiple OPs.
self.rename_mapper = dict()
@print_mapping_info
def directly_map(self, node, *args, **kwargs):
......@@ -1680,13 +1681,39 @@ class OpSet9():
epsilon = node.get_attr('epsilon', 1e-5)
c = val_x.out_shapes[0][1]
_rename_or_remove_weight(self.weights, val_scale.name,
op_name + '.weight')
_rename_or_remove_weight(self.weights, val_b.name, op_name + '.bias')
_rename_or_remove_weight(self.weights, val_var.name,
op_name + '._variance')
_rename_or_remove_weight(self.weights, val_mean.name,
op_name + '._mean')
# solved the same data is used as an argument to multiple OPs.
if val_scale.name in self.rename_mapper:
new_name = self.rename_mapper[val_scale.name]
_rename_or_remove_weight(self.weights, new_name,
op_name + '.weight', False)
else:
_rename_or_remove_weight(self.weights, val_scale.name,
op_name + '.weight')
self.rename_mapper[val_scale.name] = op_name + '.weight'
if val_b.name in self.rename_mapper:
new_name = self.rename_mapper[val_b.name]
_rename_or_remove_weight(self.weights, new_name, op_name + '.bias',
False)
else:
_rename_or_remove_weight(self.weights, val_b.name,
op_name + '.bias')
self.rename_mapper[val_b.name] = op_name + '.bias'
if val_var.name in self.rename_mapper:
new_name = self.rename_mapper[val_var.name]
_rename_or_remove_weight(self.weights, new_name,
op_name + '._variance', False)
else:
_rename_or_remove_weight(self.weights, val_var.name,
op_name + '._variance')
self.rename_mapper[val_var.name] = op_name + '._variance'
if val_mean.name in self.rename_mapper:
new_name = self.rename_mapper[val_mean.name]
_rename_or_remove_weight(self.weights, new_name, op_name + '._mean',
False)
else:
_rename_or_remove_weight(self.weights, val_mean.name,
op_name + '._mean')
self.rename_mapper[val_mean.name] = op_name + '._mean'
# Attribute: spatial is used in BatchNormalization-1,6,7
spatial = bool(node.get_attr('spatial'))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册