未验证 提交 99e2ea67 编写于 作者: W WJJ1995 提交者: GitHub

Fix the case where the same data is used by multiple OPs. (#728)

* add scatter mapper

* solve the case where the same param is used by multiple OPs

* Add unique_name support

* Simplified code

* Add PR link

* fixed bug for CI
上级 b492545f
...@@ -45,7 +45,8 @@ def _const_weight_or_none(node, necessary=False): ...@@ -45,7 +45,8 @@ def _const_weight_or_none(node, necessary=False):
def _rename_or_remove_weight(weights, def _rename_or_remove_weight(weights,
origin_name, origin_name,
target_name=None, target_name=None,
is_remove=True): is_remove=True,
rename_mapper=None):
''' '''
Rename parameters by Paddle's naming rule of parameters. Rename parameters by Paddle's naming rule of parameters.
...@@ -56,13 +57,16 @@ def _rename_or_remove_weight(weights, ...@@ -56,13 +57,16 @@ def _rename_or_remove_weight(weights,
{target_name:weights[origin_name]} to weights, and target_name must follow paddle's {target_name:weights[origin_name]} to weights, and target_name must follow paddle's
naming rule of parameters. Default: None. naming rule of parameters. Default: None.
is_remove: if is_remove is True, remove origin key-value pair. Default: True. is_remove: if is_remove is True, remove origin key-value pair. Default: True.
rename_mapper: Handles the case where the same data is used by multiple OPs; key is old_name, value is new_name.
Returns: Returns:
None None
''' '''
if rename_mapper is not None and origin_name in rename_mapper:
origin_name = rename_mapper[origin_name]
is_remove = False
if origin_name not in weights: if origin_name not in weights:
raise KeyError('{} not a key in {}'.format(origin_name, weights.keys())) raise KeyError('{} not a key in {}'.format(origin_name, weights.keys()))
if is_remove: if is_remove:
# TODO There may be problems when the same data is used as an argument to multiple OPs.
# remove weight # remove weight
data = weights.pop(origin_name) data = weights.pop(origin_name)
else: else:
...@@ -70,6 +74,7 @@ def _rename_or_remove_weight(weights, ...@@ -70,6 +74,7 @@ def _rename_or_remove_weight(weights,
if target_name is not None: if target_name is not None:
# rename weight # rename weight
weights[target_name] = data weights[target_name] = data
rename_mapper[origin_name] = target_name
def _is_static_shape(shape): def _is_static_shape(shape):
...@@ -182,6 +187,9 @@ class OpSet9(): ...@@ -182,6 +187,9 @@ class OpSet9():
self.weights = dict() self.weights = dict()
self.nn_name2id = dict() self.nn_name2id = dict()
self.done_weight_list = list() self.done_weight_list = list()
# solve for same data is used as an argument to multiple OPs.
# PR link(wangjunjie06): https://github.com/PaddlePaddle/X2Paddle/pull/728
self.rename_mapper = dict()
@print_mapping_info @print_mapping_info
def directly_map(self, node, *args, **kwargs): def directly_map(self, node, *args, **kwargs):
...@@ -1680,13 +1688,27 @@ class OpSet9(): ...@@ -1680,13 +1688,27 @@ class OpSet9():
epsilon = node.get_attr('epsilon', 1e-5) epsilon = node.get_attr('epsilon', 1e-5)
c = val_x.out_shapes[0][1] c = val_x.out_shapes[0][1]
_rename_or_remove_weight(self.weights, val_scale.name, # solved the same data is used as an argument to multiple OPs.
op_name + '.weight') _rename_or_remove_weight(
_rename_or_remove_weight(self.weights, val_b.name, op_name + '.bias') self.weights,
_rename_or_remove_weight(self.weights, val_var.name, val_scale.name,
op_name + '._variance') op_name + '.weight',
_rename_or_remove_weight(self.weights, val_mean.name, rename_mapper=self.rename_mapper)
op_name + '._mean') _rename_or_remove_weight(
self.weights,
val_b.name,
op_name + '.bias',
rename_mapper=self.rename_mapper)
_rename_or_remove_weight(
self.weights,
val_var.name,
op_name + '._variance',
rename_mapper=self.rename_mapper)
_rename_or_remove_weight(
self.weights,
val_mean.name,
op_name + '._mean',
rename_mapper=self.rename_mapper)
# Attribute: spatial is used in BatchNormalization-1,6,7 # Attribute: spatial is used in BatchNormalization-1,6,7
spatial = bool(node.get_attr('spatial')) spatial = bool(node.get_attr('spatial'))
...@@ -2228,14 +2250,22 @@ class OpSet9(): ...@@ -2228,14 +2250,22 @@ class OpSet9():
remove_weight = True if val_w.name in self.done_weight_list else False remove_weight = True if val_w.name in self.done_weight_list else False
if remove_weight: if remove_weight:
self.done_weight_list.append(val_w.name) self.done_weight_list.append(val_w.name)
_rename_or_remove_weight(self.weights, val_w.name, op_name + '.weight', _rename_or_remove_weight(
remove_weight) self.weights,
val_w.name,
op_name + '.weight',
remove_weight,
rename_mapper=self.rename_mapper)
if has_bias: if has_bias:
remove_bias = True if val_b.name in self.done_weight_list else False remove_bias = True if val_b.name in self.done_weight_list else False
if remove_bias: if remove_bias:
self.done_weight_list.append(val_b_name) self.done_weight_list.append(val_b.name)
_rename_or_remove_weight(self.weights, val_b.name, _rename_or_remove_weight(
op_name + '.bias', remove_bias) self.weights,
val_b.name,
op_name + '.bias',
remove_bias,
rename_mapper=self.rename_mapper)
else: else:
layer_attrs["bias_attr"] = False layer_attrs["bias_attr"] = False
if reduce(lambda x, y: x * y, if reduce(lambda x, y: x * y,
...@@ -2355,10 +2385,14 @@ class OpSet9(): ...@@ -2355,10 +2385,14 @@ class OpSet9():
_rename_or_remove_weight( _rename_or_remove_weight(
self.weights, self.weights,
val_w.name, val_w.name,
op_name + '.weight', ) op_name + '.weight',
rename_mapper=self.rename_mapper)
if val_b is not None: if val_b is not None:
_rename_or_remove_weight(self.weights, val_b.name, _rename_or_remove_weight(
op_name + '.bias') self.weights,
val_b.name,
op_name + '.bias',
rename_mapper=self.rename_mapper)
else: else:
layer_attrs["bias_attr"] = False layer_attrs["bias_attr"] = False
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import math import math
from functools import reduce from functools import reduce
import paddle import paddle
from paddle.fluid import framework from paddle.fluid import framework, unique_name
from paddle.fluid.core import VarDesc from paddle.fluid.core import VarDesc
from paddle.fluid.initializer import XavierInitializer, MSRAInitializer from paddle.fluid.initializer import XavierInitializer, MSRAInitializer
from paddle.fluid.data_feeder import check_variable_and_dtype from paddle.fluid.data_feeder import check_variable_and_dtype
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册