提交 acbb1cac 编写于 作者: W wjj19950828

Simplified code

上级 53f8175d
...@@ -45,7 +45,8 @@ def _const_weight_or_none(node, necessary=False): ...@@ -45,7 +45,8 @@ def _const_weight_or_none(node, necessary=False):
def _rename_or_remove_weight(weights, def _rename_or_remove_weight(weights,
origin_name, origin_name,
target_name=None, target_name=None,
is_remove=True): is_remove=True,
rename_mapper=None):
''' '''
Rename parameters by Paddle's naming rule of parameters. Rename parameters by Paddle's naming rule of parameters.
...@@ -56,9 +57,13 @@ def _rename_or_remove_weight(weights, ...@@ -56,9 +57,13 @@ def _rename_or_remove_weight(weights,
{target_name:weights[origin_name]} to weights, and target_name must follow paddle's {target_name:weights[origin_name]} to weights, and target_name must follow paddle's
naming rule of parameters. Default: None. naming rule of parameters. Default: None.
is_remove: if is_remove is True, remove origin key-value pair. Default: True. is_remove: if is_remove is True, remove origin key-value pair. Default: True.
rename_mapper: Solved the same data is used for multiple OPs, key is old_name, value is new_name.
Returns: Returns:
None None
''' '''
if origin_name in rename_mapper:
origin_name = rename_mapper[origin_name]
is_remove = False
if origin_name not in weights: if origin_name not in weights:
raise KeyError('{} not a key in {}'.format(origin_name, weights.keys())) raise KeyError('{} not a key in {}'.format(origin_name, weights.keys()))
if is_remove: if is_remove:
...@@ -69,6 +74,7 @@ def _rename_or_remove_weight(weights, ...@@ -69,6 +74,7 @@ def _rename_or_remove_weight(weights,
if target_name is not None: if target_name is not None:
# rename weight # rename weight
weights[target_name] = data weights[target_name] = data
rename_mapper[origin_name] = target_name
def _is_static_shape(shape): def _is_static_shape(shape):
...@@ -1682,38 +1688,26 @@ class OpSet9(): ...@@ -1682,38 +1688,26 @@ class OpSet9():
c = val_x.out_shapes[0][1] c = val_x.out_shapes[0][1]
# solved the same data is used as an argument to multiple OPs. # solved the same data is used as an argument to multiple OPs.
if val_scale.name in self.rename_mapper: _rename_or_remove_weight(
new_name = self.rename_mapper[val_scale.name] self.weights,
_rename_or_remove_weight(self.weights, new_name, val_scale.name,
op_name + '.weight', False) op_name + '.weight',
else: rename_mapper=self.rename_mapper)
_rename_or_remove_weight(self.weights, val_scale.name, _rename_or_remove_weight(
op_name + '.weight') self.weights,
self.rename_mapper[val_scale.name] = op_name + '.weight' val_b.name,
if val_b.name in self.rename_mapper: op_name + '.bias',
new_name = self.rename_mapper[val_b.name] rename_mapper=self.rename_mapper)
_rename_or_remove_weight(self.weights, new_name, op_name + '.bias', _rename_or_remove_weight(
False) self.weights,
else: val_var.name,
_rename_or_remove_weight(self.weights, val_b.name, op_name + '._variance',
op_name + '.bias') rename_mapper=self.rename_mapper)
self.rename_mapper[val_b.name] = op_name + '.bias' _rename_or_remove_weight(
if val_var.name in self.rename_mapper: self.weights,
new_name = self.rename_mapper[val_var.name] val_mean.name,
_rename_or_remove_weight(self.weights, new_name, op_name + '._mean',
op_name + '._variance', False) rename_mapper=self.rename_mapper)
else:
_rename_or_remove_weight(self.weights, val_var.name,
op_name + '._variance')
self.rename_mapper[val_var.name] = op_name + '._variance'
if val_mean.name in self.rename_mapper:
new_name = self.rename_mapper[val_mean.name]
_rename_or_remove_weight(self.weights, new_name, op_name + '._mean',
False)
else:
_rename_or_remove_weight(self.weights, val_mean.name,
op_name + '._mean')
self.rename_mapper[val_mean.name] = op_name + '._mean'
# Attribute: spatial is used in BatchNormalization-1,6,7 # Attribute: spatial is used in BatchNormalization-1,6,7
spatial = bool(node.get_attr('spatial')) spatial = bool(node.get_attr('spatial'))
...@@ -2255,14 +2249,22 @@ class OpSet9(): ...@@ -2255,14 +2249,22 @@ class OpSet9():
remove_weight = True if val_w.name in self.done_weight_list else False remove_weight = True if val_w.name in self.done_weight_list else False
if remove_weight: if remove_weight:
self.done_weight_list.append(val_w.name) self.done_weight_list.append(val_w.name)
_rename_or_remove_weight(self.weights, val_w.name, op_name + '.weight', _rename_or_remove_weight(
remove_weight) self.weights,
val_w.name,
op_name + '.weight',
remove_weight,
rename_mapper=self.rename_mapper)
if has_bias: if has_bias:
remove_bias = True if val_b.name in self.done_weight_list else False remove_bias = True if val_b.name in self.done_weight_list else False
if remove_bias: if remove_bias:
self.done_weight_list.append(val_b_name) self.done_weight_list.append(val_b.name)
_rename_or_remove_weight(self.weights, val_b.name, _rename_or_remove_weight(
op_name + '.bias', remove_bias) self.weights,
val_b.name,
op_name + '.bias',
remove_bias,
rename_mapper=self.rename_mapper)
else: else:
layer_attrs["bias_attr"] = False layer_attrs["bias_attr"] = False
if reduce(lambda x, y: x * y, if reduce(lambda x, y: x * y,
...@@ -2382,10 +2384,14 @@ class OpSet9(): ...@@ -2382,10 +2384,14 @@ class OpSet9():
_rename_or_remove_weight( _rename_or_remove_weight(
self.weights, self.weights,
val_w.name, val_w.name,
op_name + '.weight', ) op_name + '.weight',
rename_mapper=self.rename_mapper)
if val_b is not None: if val_b is not None:
_rename_or_remove_weight(self.weights, val_b.name, _rename_or_remove_weight(
op_name + '.bias') self.weights,
val_b.name,
op_name + '.bias',
rename_mapper=self.rename_mapper)
else: else:
layer_attrs["bias_attr"] = False layer_attrs["bias_attr"] = False
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册