提交 d8b7d67c 编写于 作者: W wjj19950828

Add gelu pass

上级 2c1d8e82
......@@ -64,9 +64,7 @@ class GeluFuser(FuseBase):
"y": gen_name(0)},
outputs=[gen_name(3)])
self.pattern.add_layer(
"paddle.erf",
inputs={"x": gen_name(3)},
outputs=[gen_name(4)])
"paddle.erf", inputs={"x": gen_name(3)}, outputs=[gen_name(4)])
self.pattern.add_layer(
"paddle.add",
inputs={"x": gen_name(4),
......@@ -88,7 +86,7 @@ class GeluFuser(FuseBase):
new_layer, new_layer_id = self.gen_new_layer(parameters, matches)
graph.layers[new_layer_id] = new_layer
matches.pop(new_layer_id)
def gen_new_layer(self, parameters, matches):
layer_id_list = list(matches.keys())
layer_id_list.sort(key=int)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册