未验证 提交 703e2c98 编写于 作者: J Jason 提交者: GitHub

Merge pull request #777 from PaddlePaddle/fixed_LN_fuser

Fixed LN fuser pass
......@@ -56,11 +56,7 @@ class LayerNormFuser(FuseBase):
shape=[1],
fill_value=0.5)
self.pattern.add_layer(
"paddle.full",
inputs={},
outputs=[gen_name(3)],
shape=[1],
fill_value=9.999999747378752e-06)
"paddle.full", inputs={}, outputs=[gen_name(3)], shape=[1])
self.pattern.add_layer(
"paddle.mean",
inputs={"x": "layernorm-input-0"},
......@@ -122,6 +118,7 @@ class LayerNormFuser(FuseBase):
layer_inputs = list()
layer_inputs_ids = list()
param_name = list()
fill_value_list = list()
for layer_id, layer in matches.items():
if layer.kernel == "paddle.mean":
layer_inputs.append(layer.inputs)
......@@ -130,6 +127,8 @@ class LayerNormFuser(FuseBase):
param_name.append(layer.outputs[0])
if layer.kernel == "paddle.add":
output_name = layer.outputs[0]
if layer.kernel == "paddle.full":
fill_value_list.append(layer.attrs["fill_value"])
param = parameters[param_name[0]]
c = param.shape[0]
weight_param = parameters.pop(param_name[0])
......@@ -141,5 +140,6 @@ class LayerNormFuser(FuseBase):
"paddle.nn.LayerNorm",
inputs=layer_inputs[0],
outputs=[output_name],
normalized_shape=[c])
normalized_shape=[c],
epsilon=fill_value_list[-1])
return new_layer, layer_inputs_ids[0]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册