diff --git a/x2paddle/optimizer/fusion/onnx_gelu_fuser.py b/x2paddle/optimizer/fusion/onnx_gelu_fuser.py index 11b7b798c1134a95ce0d9b3539088c823f8438bd..e8b48bbf8dc980fff2dfcc2fd23fd26e8259ab78 100644 --- a/x2paddle/optimizer/fusion/onnx_gelu_fuser.py +++ b/x2paddle/optimizer/fusion/onnx_gelu_fuser.py @@ -64,9 +64,7 @@ class GeluFuser(FuseBase): "y": gen_name(0)}, outputs=[gen_name(3)]) self.pattern.add_layer( - "paddle.erf", - inputs={"x": gen_name(3)}, - outputs=[gen_name(4)]) + "paddle.erf", inputs={"x": gen_name(3)}, outputs=[gen_name(4)]) self.pattern.add_layer( "paddle.add", inputs={"x": gen_name(4), @@ -88,7 +86,7 @@ class GeluFuser(FuseBase): new_layer, new_layer_id = self.gen_new_layer(parameters, matches) graph.layers[new_layer_id] = new_layer matches.pop(new_layer_id) - + def gen_new_layer(self, parameters, matches): layer_id_list = list(matches.keys()) layer_id_list.sort(key=int)