From 2c1d8e8262902cbd211862d4c0a08621f980c015 Mon Sep 17 00:00:00 2001
From: wjj19950828
Date: Tue, 7 Jun 2022 21:52:59 +0800
Subject: [PATCH] Add gelu pass

---
 x2paddle/optimizer/fusion/__init__.py        |   2 +
 .../optimizer/fusion/onnx_gelu_fuse_pass.py  |  33 +++++
 x2paddle/optimizer/fusion/onnx_gelu_fuser.py | 116 ++++++++++++++++++
 x2paddle/optimizer/optimizer.py              |   2 +-
 4 files changed, 152 insertions(+), 1 deletion(-)
 create mode 100644 x2paddle/optimizer/fusion/onnx_gelu_fuse_pass.py
 create mode 100644 x2paddle/optimizer/fusion/onnx_gelu_fuser.py

diff --git a/x2paddle/optimizer/fusion/__init__.py b/x2paddle/optimizer/fusion/__init__.py
index 86f78e5..3483374 100644
--- a/x2paddle/optimizer/fusion/__init__.py
+++ b/x2paddle/optimizer/fusion/__init__.py
@@ -40,3 +40,5 @@ from .trace_fc_fuser import TraceFcFuser
 from .trace_fc_fuse_pass import TraceFcFusePass
 from .onnx_layernorm_fuser import LayerNormFuser
 from .onnx_layernorm_fuse_pass import LayerNormFusePass
+from .onnx_gelu_fuser import GeluFuser
+from .onnx_gelu_fuse_pass import GeluFusePass
diff --git a/x2paddle/optimizer/fusion/onnx_gelu_fuse_pass.py b/x2paddle/optimizer/fusion/onnx_gelu_fuse_pass.py
new file mode 100644
index 0000000..8143068
--- /dev/null
+++ b/x2paddle/optimizer/fusion/onnx_gelu_fuse_pass.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from x2paddle.optimizer.pass_ import Pass
+from x2paddle.optimizer.fusion import GeluFuser
+from x2paddle.optimizer.pass_manager import pass_register
+
+
+@pass_register
+class GeluFusePass(Pass):
+    name = "onnx_gelu_fuse_pass"
+
+    def __init__(self):
+        Pass.__init__(self)
+
+    def apply(self, graph):
+        fuser = GeluFuser()
+        fuser.operate(graph, match_kind="edge")
+
+
+# Register the gelu pass.
+onnx_gelu_fuse_pass = GeluFusePass()
diff --git a/x2paddle/optimizer/fusion/onnx_gelu_fuser.py b/x2paddle/optimizer/fusion/onnx_gelu_fuser.py
new file mode 100644
index 0000000..11b7b79
--- /dev/null
+++ b/x2paddle/optimizer/fusion/onnx_gelu_fuser.py
@@ -0,0 +1,116 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
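+
+# This fuser matches the erf-based GELU decomposition emitted by the
+# ONNX frontend, gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), and
+# replaces it with a single paddle.nn.GELU(approximate=False) layer.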
+
+import copy
+import numpy as np
+from collections import OrderedDict
+from x2paddle.optimizer.pattern_matcher import FuseBase
+from x2paddle.core.program import PaddleGraph, PaddleLayer
+from x2paddle.core.util import *
+
+
+class GeluFuser(FuseBase):
+    def __init__(self):
+        super(GeluFuser, self).__init__()
+
+    def build_pattern(self):
+        """
+        The pattern to be matched (erf-based GELU decomposition), e.g.:
+        x2paddle_332 = paddle.full(dtype='float32', shape=[1], fill_value=1.4142135381698608)
+        x2paddle_335 = paddle.full(dtype='float32', shape=[1], fill_value=1.0)
+        x2paddle_338 = paddle.full(dtype='float32', shape=[1], fill_value=0.5)
+        x2paddle_333 = paddle.divide(x=x2paddle_331, y=x2paddle_332)
+        x2paddle_334 = paddle.erf(x=x2paddle_333)
+        x2paddle_336 = paddle.add(x=x2paddle_334, y=x2paddle_335)
+        x2paddle_337 = paddle.multiply(x=x2paddle_331, y=x2paddle_336)
+        x2paddle_339 = paddle.multiply(x=x2paddle_337, y=x2paddle_338)
+        """
+
+        def gen_name(id):
+            return "x" + str(id)
+
+        self.pattern.add_layer(
+            "paddle.full",
+            inputs={},
+            outputs=[gen_name(0)],
+            shape=[1],
+            fill_value=1.4142135381698608)
+        self.pattern.add_layer(
+            "paddle.full",
+            inputs={},
+            outputs=[gen_name(1)],
+            shape=[1],
+            fill_value=1.0)
+        self.pattern.add_layer(
+            "paddle.full",
+            inputs={},
+            outputs=[gen_name(2)],
+            shape=[1],
+            fill_value=0.5)
+        self.pattern.add_layer(
+            "paddle.divide",
+            inputs={"x": "gelu-input-0",
+                    "y": gen_name(0)},
+            outputs=[gen_name(3)])
+        self.pattern.add_layer(
+            "paddle.erf",
+            inputs={"x": gen_name(3)},
+            outputs=[gen_name(4)])
+        self.pattern.add_layer(
+            "paddle.add",
+            inputs={"x": gen_name(4),
+                    "y": gen_name(1)},
+            outputs=[gen_name(5)])
+        self.pattern.add_layer(
+            "paddle.multiply",
+            inputs={"x": "gelu-input-0",
+                    "y": gen_name(5)},
+            outputs=[gen_name(6)])
+        self.pattern.add_layer(
+            "paddle.multiply",
+            inputs={"x": gen_name(6),
+                    "y": gen_name(2)},
+            outputs=[gen_name(7)])
+        self.pattern.build(inputs={"input-0": "gelu-input-0", })
+
+    def insert_new_layer(self, graph, parameters, matches):
+        new_layer, new_layer_id = self.gen_new_layer(parameters, matches)
+        graph.layers[new_layer_id] = new_layer
+        # Drop the reused id from matches so the new layer survives cleanup.
+        matches.pop(new_layer_id)
+
+    def gen_new_layer(self, parameters, matches):
+        layer_id_list = list(matches.keys())
+        layer_id_list.sort(key=int)
+        layer_inputs = list()
+        layer_inputs_ids = list()
+        for layer_id, layer in matches.items():
+            # The divide layer's "x" is the original input of the GELU.
+            if layer.kernel == "paddle.divide":
+                layer_inputs.append(layer.inputs["x"])
+                layer_inputs_ids.append(layer_id)
+            if layer.kernel == "paddle.multiply":
+                output_name = layer.outputs[0]
+        # An erf-based GELU maps to paddle.nn.GELU(approximate=False).
+        new_layer = PaddleLayer(
+            layer_id_list[0],
+            "paddle.nn.GELU",
+            inputs={"x": layer_inputs[0]},
+            outputs=[output_name],
+            approximate=False)
+        return new_layer, layer_inputs_ids[0]
diff --git a/x2paddle/optimizer/optimizer.py b/x2paddle/optimizer/optimizer.py
index 49c6238..57eff57 100644
--- a/x2paddle/optimizer/optimizer.py
+++ b/x2paddle/optimizer/optimizer.py
@@ -37,7 +37,7 @@ class GraphOptimizer(object):
                 "prelu_fuse_pass", "transpose_eliminate_pass"
             ]
         elif source_frame == "onnx":
-            self.passes = ["onnx_layernorm_fuse_pass"]
+            self.passes = ["onnx_layernorm_fuse_pass", "onnx_gelu_fuse_pass"]
         else:
             self.passes = []
 
-- 
GitLab
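
Reviewer note (not part of the patch): a minimal sanity check, assuming paddle and
numpy are installed, that the subgraph matched by the pattern and the fused layer
produced by the pass are numerically equivalent.

import numpy as np
import paddle

x = paddle.to_tensor(np.random.randn(4, 8).astype("float32"))

# The subgraph the pattern matches: 0.5 * x * (1 + erf(x / sqrt(2))).
sqrt2 = paddle.full(dtype='float32', shape=[1], fill_value=1.4142135381698608)
one = paddle.full(dtype='float32', shape=[1], fill_value=1.0)
half = paddle.full(dtype='float32', shape=[1], fill_value=0.5)
subgraph_out = paddle.multiply(
    paddle.multiply(x, paddle.add(paddle.erf(paddle.divide(x, sqrt2)), one)),
    half)

# The single layer the pass emits instead.
fused_out = paddle.nn.GELU(approximate=False)(x)

assert bool(paddle.allclose(subgraph_out, fused_out, atol=1e-6))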