From d8b7d67c8e108244ff9046a9268eff176242d589 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Tue, 7 Jun 2022 21:55:09 +0800 Subject: [PATCH] Add gelu pass --- x2paddle/optimizer/fusion/onnx_gelu_fuser.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/x2paddle/optimizer/fusion/onnx_gelu_fuser.py b/x2paddle/optimizer/fusion/onnx_gelu_fuser.py index 11b7b79..e8b48bb 100644 --- a/x2paddle/optimizer/fusion/onnx_gelu_fuser.py +++ b/x2paddle/optimizer/fusion/onnx_gelu_fuser.py @@ -64,9 +64,7 @@ class GeluFuser(FuseBase): "y": gen_name(0)}, outputs=[gen_name(3)]) self.pattern.add_layer( - "paddle.erf", - inputs={"x": gen_name(3)}, - outputs=[gen_name(4)]) + "paddle.erf", inputs={"x": gen_name(3)}, outputs=[gen_name(4)]) self.pattern.add_layer( "paddle.add", inputs={"x": gen_name(4), @@ -88,7 +86,7 @@ class GeluFuser(FuseBase): new_layer, new_layer_id = self.gen_new_layer(parameters, matches) graph.layers[new_layer_id] = new_layer matches.pop(new_layer_id) - + def gen_new_layer(self, parameters, matches): layer_id_list = list(matches.keys()) layer_id_list.sort(key=int) -- GitLab