未验证 提交 703e2c98 编写于 作者: J Jason 提交者: GitHub

Merge pull request #777 from PaddlePaddle/fixed_LN_fuser

Fixed LN fuser pass
...@@ -56,11 +56,7 @@ class LayerNormFuser(FuseBase): ...@@ -56,11 +56,7 @@ class LayerNormFuser(FuseBase):
shape=[1], shape=[1],
fill_value=0.5) fill_value=0.5)
self.pattern.add_layer( self.pattern.add_layer(
"paddle.full", "paddle.full", inputs={}, outputs=[gen_name(3)], shape=[1])
inputs={},
outputs=[gen_name(3)],
shape=[1],
fill_value=9.999999747378752e-06)
self.pattern.add_layer( self.pattern.add_layer(
"paddle.mean", "paddle.mean",
inputs={"x": "layernorm-input-0"}, inputs={"x": "layernorm-input-0"},
...@@ -122,6 +118,7 @@ class LayerNormFuser(FuseBase): ...@@ -122,6 +118,7 @@ class LayerNormFuser(FuseBase):
layer_inputs = list() layer_inputs = list()
layer_inputs_ids = list() layer_inputs_ids = list()
param_name = list() param_name = list()
fill_value_list = list()
for layer_id, layer in matches.items(): for layer_id, layer in matches.items():
if layer.kernel == "paddle.mean": if layer.kernel == "paddle.mean":
layer_inputs.append(layer.inputs) layer_inputs.append(layer.inputs)
...@@ -130,6 +127,8 @@ class LayerNormFuser(FuseBase): ...@@ -130,6 +127,8 @@ class LayerNormFuser(FuseBase):
param_name.append(layer.outputs[0]) param_name.append(layer.outputs[0])
if layer.kernel == "paddle.add": if layer.kernel == "paddle.add":
output_name = layer.outputs[0] output_name = layer.outputs[0]
if layer.kernel == "paddle.full":
fill_value_list.append(layer.attrs["fill_value"])
param = parameters[param_name[0]] param = parameters[param_name[0]]
c = param.shape[0] c = param.shape[0]
weight_param = parameters.pop(param_name[0]) weight_param = parameters.pop(param_name[0])
...@@ -141,5 +140,6 @@ class LayerNormFuser(FuseBase): ...@@ -141,5 +140,6 @@ class LayerNormFuser(FuseBase):
"paddle.nn.LayerNorm", "paddle.nn.LayerNorm",
inputs=layer_inputs[0], inputs=layer_inputs[0],
outputs=[output_name], outputs=[output_name],
normalized_shape=[c]) normalized_shape=[c],
epsilon=fill_value_list[-1])
return new_layer, layer_inputs_ids[0] return new_layer, layer_inputs_ids[0]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册