未验证 提交 99e2ea67 编写于 作者: W WJJ1995 提交者: GitHub

Solved the same data is used for multiple OPs. (#728)

* add scatter mapper

* solve same param is used for multiple OPs

* Add unique_name support

* Simplified code

* Add PR link

* fixed bug for CI
上级 b492545f
......@@ -45,7 +45,8 @@ def _const_weight_or_none(node, necessary=False):
def _rename_or_remove_weight(weights,
origin_name,
target_name=None,
is_remove=True):
is_remove=True,
rename_mapper=None):
'''
Rename parameters by Paddle's naming rule of parameters.
......@@ -56,13 +57,16 @@ def _rename_or_remove_weight(weights,
{target_name:weights[origin_name]} to weights, and target_name must follow paddle's
naming rule of parameters. Default: None.
is_remove: if is_remove is True, remove origin key-value pair. Default: True.
rename_mapper: Solved the same data is used for multiple OPs, key is old_name, value is new_name.
Returns:
None
'''
if rename_mapper is not None and origin_name in rename_mapper:
origin_name = rename_mapper[origin_name]
is_remove = False
if origin_name not in weights:
raise KeyError('{} not a key in {}'.format(origin_name, weights.keys()))
if is_remove:
# TODO There may be problems when the same data is used as an argument to multiple OPs.
# remove weight
data = weights.pop(origin_name)
else:
......@@ -70,6 +74,7 @@ def _rename_or_remove_weight(weights,
if target_name is not None:
# rename weight
weights[target_name] = data
rename_mapper[origin_name] = target_name
def _is_static_shape(shape):
......@@ -182,6 +187,9 @@ class OpSet9():
self.weights = dict()
self.nn_name2id = dict()
self.done_weight_list = list()
# solve for same data is used as an argument to multiple OPs.
# PR link(wangjunjie06): https://github.com/PaddlePaddle/X2Paddle/pull/728
self.rename_mapper = dict()
@print_mapping_info
def directly_map(self, node, *args, **kwargs):
......@@ -1680,13 +1688,27 @@ class OpSet9():
epsilon = node.get_attr('epsilon', 1e-5)
c = val_x.out_shapes[0][1]
_rename_or_remove_weight(self.weights, val_scale.name,
op_name + '.weight')
_rename_or_remove_weight(self.weights, val_b.name, op_name + '.bias')
_rename_or_remove_weight(self.weights, val_var.name,
op_name + '._variance')
_rename_or_remove_weight(self.weights, val_mean.name,
op_name + '._mean')
# solved the same data is used as an argument to multiple OPs.
_rename_or_remove_weight(
self.weights,
val_scale.name,
op_name + '.weight',
rename_mapper=self.rename_mapper)
_rename_or_remove_weight(
self.weights,
val_b.name,
op_name + '.bias',
rename_mapper=self.rename_mapper)
_rename_or_remove_weight(
self.weights,
val_var.name,
op_name + '._variance',
rename_mapper=self.rename_mapper)
_rename_or_remove_weight(
self.weights,
val_mean.name,
op_name + '._mean',
rename_mapper=self.rename_mapper)
# Attribute: spatial is used in BatchNormalization-1,6,7
spatial = bool(node.get_attr('spatial'))
......@@ -2228,14 +2250,22 @@ class OpSet9():
remove_weight = True if val_w.name in self.done_weight_list else False
if remove_weight:
self.done_weight_list.append(val_w.name)
_rename_or_remove_weight(self.weights, val_w.name, op_name + '.weight',
remove_weight)
_rename_or_remove_weight(
self.weights,
val_w.name,
op_name + '.weight',
remove_weight,
rename_mapper=self.rename_mapper)
if has_bias:
remove_bias = True if val_b.name in self.done_weight_list else False
if remove_bias:
self.done_weight_list.append(val_b_name)
_rename_or_remove_weight(self.weights, val_b.name,
op_name + '.bias', remove_bias)
self.done_weight_list.append(val_b.name)
_rename_or_remove_weight(
self.weights,
val_b.name,
op_name + '.bias',
remove_bias,
rename_mapper=self.rename_mapper)
else:
layer_attrs["bias_attr"] = False
if reduce(lambda x, y: x * y,
......@@ -2355,10 +2385,14 @@ class OpSet9():
_rename_or_remove_weight(
self.weights,
val_w.name,
op_name + '.weight', )
op_name + '.weight',
rename_mapper=self.rename_mapper)
if val_b is not None:
_rename_or_remove_weight(self.weights, val_b.name,
op_name + '.bias')
_rename_or_remove_weight(
self.weights,
val_b.name,
op_name + '.bias',
rename_mapper=self.rename_mapper)
else:
layer_attrs["bias_attr"] = False
self.paddle_graph.add_layer(
......
......@@ -15,7 +15,7 @@
import math
from functools import reduce
import paddle
from paddle.fluid import framework
from paddle.fluid import framework, unique_name
from paddle.fluid.core import VarDesc
from paddle.fluid.initializer import XavierInitializer, MSRAInitializer
from paddle.fluid.data_feeder import check_variable_and_dtype
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册