提交 acbb1cac 编写于 作者: W wjj19950828

Simplified code

上级 53f8175d
......@@ -45,7 +45,8 @@ def _const_weight_or_none(node, necessary=False):
def _rename_or_remove_weight(weights,
origin_name,
target_name=None,
is_remove=True):
is_remove=True,
rename_mapper=None):
'''
Rename parameters by Paddle's naming rule of parameters.
......@@ -56,9 +57,13 @@ def _rename_or_remove_weight(weights,
{target_name:weights[origin_name]} to weights, and target_name must follow paddle's
naming rule of parameters. Default: None.
is_remove: if is_remove is True, remove origin key-value pair. Default: True.
rename_mapper: Solved the same data is used for multiple OPs, key is old_name, value is new_name.
Returns:
None
'''
if origin_name in rename_mapper:
origin_name = rename_mapper[origin_name]
is_remove = False
if origin_name not in weights:
raise KeyError('{} not a key in {}'.format(origin_name, weights.keys()))
if is_remove:
......@@ -69,6 +74,7 @@ def _rename_or_remove_weight(weights,
if target_name is not None:
# rename weight
weights[target_name] = data
rename_mapper[origin_name] = target_name
def _is_static_shape(shape):
......@@ -1682,38 +1688,26 @@ class OpSet9():
c = val_x.out_shapes[0][1]
# solved the same data is used as an argument to multiple OPs.
if val_scale.name in self.rename_mapper:
new_name = self.rename_mapper[val_scale.name]
_rename_or_remove_weight(self.weights, new_name,
op_name + '.weight', False)
else:
_rename_or_remove_weight(self.weights, val_scale.name,
op_name + '.weight')
self.rename_mapper[val_scale.name] = op_name + '.weight'
if val_b.name in self.rename_mapper:
new_name = self.rename_mapper[val_b.name]
_rename_or_remove_weight(self.weights, new_name, op_name + '.bias',
False)
else:
_rename_or_remove_weight(self.weights, val_b.name,
op_name + '.bias')
self.rename_mapper[val_b.name] = op_name + '.bias'
if val_var.name in self.rename_mapper:
new_name = self.rename_mapper[val_var.name]
_rename_or_remove_weight(self.weights, new_name,
op_name + '._variance', False)
else:
_rename_or_remove_weight(self.weights, val_var.name,
op_name + '._variance')
self.rename_mapper[val_var.name] = op_name + '._variance'
if val_mean.name in self.rename_mapper:
new_name = self.rename_mapper[val_mean.name]
_rename_or_remove_weight(self.weights, new_name, op_name + '._mean',
False)
else:
_rename_or_remove_weight(self.weights, val_mean.name,
op_name + '._mean')
self.rename_mapper[val_mean.name] = op_name + '._mean'
_rename_or_remove_weight(
self.weights,
val_scale.name,
op_name + '.weight',
rename_mapper=self.rename_mapper)
_rename_or_remove_weight(
self.weights,
val_b.name,
op_name + '.bias',
rename_mapper=self.rename_mapper)
_rename_or_remove_weight(
self.weights,
val_var.name,
op_name + '._variance',
rename_mapper=self.rename_mapper)
_rename_or_remove_weight(
self.weights,
val_mean.name,
op_name + '._mean',
rename_mapper=self.rename_mapper)
# Attribute: spatial is used in BatchNormalization-1,6,7
spatial = bool(node.get_attr('spatial'))
......@@ -2255,14 +2249,22 @@ class OpSet9():
remove_weight = True if val_w.name in self.done_weight_list else False
if remove_weight:
self.done_weight_list.append(val_w.name)
_rename_or_remove_weight(self.weights, val_w.name, op_name + '.weight',
remove_weight)
_rename_or_remove_weight(
self.weights,
val_w.name,
op_name + '.weight',
remove_weight,
rename_mapper=self.rename_mapper)
if has_bias:
remove_bias = True if val_b.name in self.done_weight_list else False
if remove_bias:
self.done_weight_list.append(val_b_name)
_rename_or_remove_weight(self.weights, val_b.name,
op_name + '.bias', remove_bias)
self.done_weight_list.append(val_b.name)
_rename_or_remove_weight(
self.weights,
val_b.name,
op_name + '.bias',
remove_bias,
rename_mapper=self.rename_mapper)
else:
layer_attrs["bias_attr"] = False
if reduce(lambda x, y: x * y,
......@@ -2382,10 +2384,14 @@ class OpSet9():
_rename_or_remove_weight(
self.weights,
val_w.name,
op_name + '.weight', )
op_name + '.weight',
rename_mapper=self.rename_mapper)
if val_b is not None:
_rename_or_remove_weight(self.weights, val_b.name,
op_name + '.bias')
_rename_or_remove_weight(
self.weights,
val_b.name,
op_name + '.bias',
rename_mapper=self.rename_mapper)
else:
layer_attrs["bias_attr"] = False
self.paddle_graph.add_layer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册