diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index dc16744fbe24f56011ffa70679572da8c3515b58..662454c4c5737afd7df88ad27ad373e0ba43c22e 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -268,6 +268,8 @@ if(WITH_XPU)
                xpu DEPS ${XPU_PASS_DEPS})
   pass_library(add_activation_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
+  pass_library(add_layernorm_xpu_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
   pass_library(xpu_delete_cast_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(fold_interp_outsize_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
diff --git a/paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..698c0b6c03346bc97e3e76c09327a48ca8f038ef
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/add_layernorm_xpu_fuse_pass.cc
@@ -0,0 +1,244 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "glog/logging.h"
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+/*
+fuse elementwise_add + layer_norm block into add_layernorm_xpu op
+For example:
+graph:
+                  ele_x
+                    |
+             elementwise_add -----ele_y
+                    |
+                layernorm
+                    |
+                  output
+------------------------------------------------------
+After the pass is applied:
+                  ele_x
+                    |     ele_y
+                    |    /
+                    |   /
+  scale---- add_layernorm_fusion ---- bias
+          /         |        \      \
+         /          |         \      \
+   variance         |        mean   z_add
+                  Output
+*/
+struct AddLayernormXPUPattern : public PatternBase {
+  AddLayernormXPUPattern(PDPattern* pattern, const std::string& name_scope);
+  // declare operator node's name
+  PATTERN_DECL_NODE(ele_add);
+  PATTERN_DECL_NODE(l_norm);
+  // declare variable node's name
+  PATTERN_DECL_NODE(ele_x);
+  PATTERN_DECL_NODE(ele_y);
+  PATTERN_DECL_NODE(ele_out);
+  PATTERN_DECL_NODE(norm_bias);
+  PATTERN_DECL_NODE(norm_scale);
+  PATTERN_DECL_NODE(norm_mean);
+  PATTERN_DECL_NODE(norm_variance);
+  PATTERN_DECL_NODE(norm_out);
+};
+
+AddLayernormXPUPattern::AddLayernormXPUPattern(PDPattern* pattern,
+                                               const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto ele_add =
+      pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add");
+  auto ele_x = pattern->NewNode(ele_x_repr())
+                   ->assert_is_op_input("elementwise_add", "X")
+                   ->AsInput();
+  auto ele_y = pattern->NewNode(ele_y_repr())
+                   ->assert_is_op_input("elementwise_add", "Y")
+                   ->AsInput();
+  auto ele_out = pattern->NewNode(ele_out_repr())
+                     ->assert_is_op_output("elementwise_add", "Out")
+                     ->assert_is_op_input("layer_norm", "X")
+                     ->assert_has_n_outputs(1);
+  ele_add->LinksFrom({ele_x, ele_y}).LinksTo({ele_out});
+  auto l_norm = pattern->NewNode(l_norm_repr())->assert_is_op("layer_norm");
+  auto norm_bias = pattern->NewNode(norm_bias_repr())
+                       ->AsInput()
+                       ->assert_is_persistable_var()
+                       ->assert_is_op_input("layer_norm", "Bias");
+  auto norm_scale = pattern->NewNode(norm_scale_repr())
+                        ->AsInput()
+                        ->assert_is_persistable_var()
+                        ->assert_is_op_input("layer_norm", "Scale");
+  auto norm_mean = pattern->NewNode(norm_mean_repr())
+                       ->AsOutput()
+                       ->assert_is_op_output("layer_norm", "Mean");
+  auto norm_variance = pattern->NewNode(norm_variance_repr())
+                           ->AsOutput()
+                           ->assert_is_op_output("layer_norm", "Variance");
+  auto norm_out = pattern->NewNode(norm_out_repr())
+                      ->AsOutput()
+                      ->assert_is_op_output("layer_norm", "Y");
+  l_norm->LinksFrom({ele_out, norm_bias, norm_scale})
+      .LinksTo({norm_out, norm_mean, norm_variance});
+}
+
+}  // namespace patterns
+
+namespace {
+void setIntermediateOut(OpDesc* desc,
+                        const std::string& out_name,
+                        const std::string& scope_name) {
+  std::string new_name = scope_name + "/at." + out_name + ".new";
+  desc->SetOutput(out_name, {new_name});
+}
+
+void addIntermediateOut(Node* op_node,
+                        const std::string& out_name,
+                        const std::string& scope_name,
+                        Graph* graph) {
+  std::string new_name = scope_name + "/at." + out_name + ".new";
+  VarDesc out_var(new_name);
+  out_var.SetPersistable(false);
+  auto* node_var = graph->CreateVarNode(&out_var);
+  IR_NODE_LINK_TO(op_node, node_var);
+}
+
+}  // namespace
+
+class AddLayernormXPUFusePass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  void FuseAddLayernorm(ir::Graph* graph) const;
+
+  const std::string name_scope_{"add_layernorm_xpu_fuse_pass"};
+};
+
+void AddLayernormXPUFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+
+  FuseAddLayernorm(graph);
+}
+
+void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::AddLayernormXPUPattern pattern(gpd.mutable_pattern(), name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle AddLayernormXPUFusePass fuse";
+    // declare operator node's name
+    GET_IR_NODE(ele_add);
+    GET_IR_NODE(l_norm);
+    // declare variable node's name
+    GET_IR_NODE(ele_x);
+    GET_IR_NODE(ele_y);
+    GET_IR_NODE(ele_out);
+    GET_IR_NODE(norm_bias);
+    GET_IR_NODE(norm_scale);
+    GET_IR_NODE(norm_mean);
+    GET_IR_NODE(norm_variance);
+    GET_IR_NODE(norm_out);
+
+    auto* block = ele_add->Op()->Block();
+    auto* scope = param_scope();
+    PADDLE_ENFORCE_NOT_NULL(
+        scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
+    // delete useless node
+    std::unordered_set<const Node*> delete_nodes;
+
+    float eps = PADDLE_GET_CONST(float, l_norm->Op()->GetAttr("epsilon"));
+    int begin_norm_axis =
+        PADDLE_GET_CONST(int, l_norm->Op()->GetAttr("begin_norm_axis"));
+    auto layer_norm_x_dims = ele_out->Var()->GetShape();
+    auto layer_norm_x_mat_dims =
+        phi::flatten_to_2d(phi::make_ddim(layer_norm_x_dims), begin_norm_axis);
+    int64_t m = layer_norm_x_mat_dims[0];
+    int64_t n = layer_norm_x_mat_dims[1];
+
+    std::string fused_op_out_name;
+    fused_op_out_name = norm_out->Name();
+    // Generate add_layernorm fused op
+    framework::OpDesc fused_op_desc(block);
+    fused_op_desc.SetType("add_layernorm_xpu");
+    // set attrs for fused op
+    fused_op_desc.SetInput("x", {ele_x->Name()});
+    fused_op_desc.SetInput("y", {ele_y->Name()});
+    fused_op_desc.SetInput("scale", {norm_scale->Name()});
+    fused_op_desc.SetInput("bias", {norm_bias->Name()});
+    fused_op_desc.SetAttr("m", m);
+    fused_op_desc.SetAttr("n", n);
+    fused_op_desc.SetAttr("epsilon", eps);
+    fused_op_desc.SetOutput("out", {fused_op_out_name});
+    setIntermediateOut(&fused_op_desc, "mean", name_scope_);
+    setIntermediateOut(&fused_op_desc, "variance", name_scope_);
+    setIntermediateOut(&fused_op_desc, "z_add", name_scope_);
+    // relink fused op
+    auto* fused_op = graph->CreateOpNode(&fused_op_desc);
+    IR_NODE_LINK_TO(ele_x, fused_op);
+    IR_NODE_LINK_TO(ele_y, fused_op);
+    IR_NODE_LINK_TO(norm_scale, fused_op);
+    IR_NODE_LINK_TO(norm_bias, fused_op);
+    IR_NODE_LINK_TO(fused_op, norm_out);
+    addIntermediateOut(fused_op, "mean", name_scope_, graph);
+    addIntermediateOut(fused_op, "variance", name_scope_, graph);
+    addIntermediateOut(fused_op, "z_add", name_scope_, graph);
+
+    delete_nodes.insert({ele_add, l_norm, ele_out, norm_mean, norm_variance});
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(add_layernorm_xpu_fuse_pass,
+              paddle::framework::ir::AddLayernormXPUFusePass);
+
+REGISTER_PASS_CAPABILITY(add_layernorm_xpu_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "add_layernorm_xpu", 0));
diff --git a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
index 4d774b62cf8d10a41fb6fc6dd82b4ef65d647ad4..18a573db0c76da45f2f0a3fff96a1ea5256c1401 100644
--- a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
@@ -260,6 +260,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
            "sigmoid",
            "swish",
            "relu6",
+           "leaky_relu",
            "",
        }) {
     found_subgraph_count +=
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 3542322bbb915c69f79a785ea350999c28ad1e33..852c471b9d2dfe89475d78a78cdaca0acb48f397 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -544,6 +544,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "conv2d_xpu_fuse_pass",
       "conv2d_transpose_xpu_fuse_pass",
       "add_activation_xpu_fuse_pass",
+      "add_layernorm_xpu_fuse_pass",
       "yolo_box_xpu_fuse_pass",
       "link_xpu_op_max_pass",
       "inplace_op_var_pass",
diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml
index 64a5d2bb00aae8e2b2b1a458fec9c5b8afbd22d7..f0fda9e3d2d1e849ef3c6269a08e88ace07c31ac 100644
--- a/paddle/phi/api/yaml/fused_ops.yaml
+++ b/paddle/phi/api/yaml/fused_ops.yaml
@@ -14,6 +14,15 @@
     data_type : x
   optional : x_max, y_max
 
+- op : add_layernorm_xpu
+  args : (Tensor x, Tensor y, Tensor scale, Tensor bias, int64_t m, int64_t n, float epsilon)
+  output : Tensor(out), Tensor(mean), Tensor(variance), Tensor(z_add)
+  infer_meta :
+    func : AddLayernormXPUInferMeta
+  kernel :
+    func : add_layernorm_xpu
+    data_type : x
+
 - op : conv2d_transpose_xpu
   args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format, bool has_bias, bool with_act, str act_type)
   output : Tensor(out), Tensor(out_max)
diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc
index cbd495135db7fabbc28f2d64f290f63f6cc66c57..1e32cdba98df82b05ec89164b7b7fc441a61f5a3 100644
--- a/paddle/phi/backends/xpu/xpu2_op_list.cc
+++ b/paddle/phi/backends/xpu/xpu2_op_list.cc
@@ -24,6 +24,8 @@ XPUOpMap& get_kl2_ops() {
   static XPUOpMap s_xpu2_kernels{
       {"add_act_xpu",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
+      {"add_layernorm_xpu",
+       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
       {"abs", XPUKernelSet({phi::DataType::FLOAT32})},
       {"abs_grad",
        XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc
index be9506511a9a46a575377bad0766e12e19b014aa..97702d513e4425f08d2c85fe7f31435ac71a3b0a 100644
--- a/paddle/phi/infermeta/fusion.cc
+++ b/paddle/phi/infermeta/fusion.cc
@@ -92,6 +92,40 @@ void AddActXPUInferMeta(const MetaTensor& x,
   out_max->set_layout(x.layout());
 }
 
+void AddLayernormXPUInferMeta(const MetaTensor& x,
+                              const MetaTensor& y,
+                              const MetaTensor& scale,
+                              const MetaTensor& bias,
+                              int64_t m,
+                              int64_t n,
+                              float epsilon,
+                              MetaTensor* out,
+                              MetaTensor* mean,
+                              MetaTensor* variance,
+                              MetaTensor* z_add) {
+  int axis = -1;
+  auto x_dims = x.dims();
+  auto y_dims = y.dims();
+  if (x_dims != y_dims) {
+    auto out_dims = BroadCastInferShape(x_dims, y_dims, axis);
+    out->set_dims(out_dims);
+  } else {
+    out->set_dims(x_dims);
+  }
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+  out->share_lod(x);
+  mean->set_dims(phi::make_ddim({m}));
+  mean->set_dtype(DataType::FLOAT32);
+  mean->set_layout(x.layout());
+  variance->set_dims(phi::make_ddim({m}));
+  variance->set_dtype(DataType::FLOAT32);
+  variance->set_layout(x.layout());
+  z_add->set_dims(phi::make_ddim({m, n}));
+  z_add->set_dtype(x.dtype());
+  z_add->set_layout(x.layout());
+}
+
 inline int ConvOutSize(int input_size,
                        int filter_size,
                        int dilation,
diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h
index 8fc311ebdd89c4cb1b523395ab3ca7e2143d0a09..921b5b6a021e03f27c32326ce6a47748864345e0 100644
--- a/paddle/phi/infermeta/fusion.h
+++ b/paddle/phi/infermeta/fusion.h
@@ -30,6 +30,18 @@ void AddActXPUInferMeta(const MetaTensor& x,
                         MetaTensor* out,
                         MetaTensor* out_max);
 
+void AddLayernormXPUInferMeta(const MetaTensor& x,
+                              const MetaTensor& y,
+                              const MetaTensor& scale,
+                              const MetaTensor& bias,
+                              int64_t m,
+                              int64_t n,
+                              float epsilon,
+                              MetaTensor* out,
+                              MetaTensor* mean,
+                              MetaTensor* variance,
+                              MetaTensor* z_add);
+
 void Conv2dXPUInferMeta(const MetaTensor& x,
                         const MetaTensor& x_max,
                         const MetaTensor& filter,
diff --git a/paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..66220d1187331360426eb52f4adc6ffb1776a82a
--- /dev/null
+++ b/paddle/phi/kernels/fusion/xpu/add_layernorm_xpu_kernel.cc
@@ -0,0 +1,70 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+namespace fusion {
+
+template <typename T, typename Context>
+void AddLayernormXPUKernel(const Context& ctx,
+                           const DenseTensor& x,
+                           const DenseTensor& y,
+                           const DenseTensor& scale,
+                           const DenseTensor& bias,
+                           int64_t m,
+                           int64_t n,
+                           float epsilon,
+                           DenseTensor* out,
+                           DenseTensor* mean,
+                           DenseTensor* variance,
+                           DenseTensor* z_add) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+
+  auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
+  auto* y_data = reinterpret_cast<const XPUType*>(y.data<T>());
+  const float* scale_data = scale.data<float>();
+  const float* bias_data = bias.data<float>();
+
+  auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
+  float* mean_data = ctx.template Alloc<float>(mean);
+  float* variance_data = ctx.template Alloc<float>(variance);
+  auto* z_add_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(z_add));
+
+  int r = xpu::add_layer_norm_fusion<XPUType>(  // T
+      /* baidu::xpu::api::Context* ctx */ ctx.x_context(),
+      /* const T* x */ x_data,
+      /* const T* y */ y_data,
+      /* T* z */ out_data,
+      /* int64_t m */ m,
+      /* int64_t n */ n,
+      /* float epsilon */ epsilon,
+      /* const float* scale */ scale_data,
+      /* const float* bias */ bias_data,
+      /* float* mean */ mean_data,
+      /* float* variance */ variance_data,
+      /* T* z_add */ z_add_data);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "add_layernorm_xpu");
+}
+
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(add_layernorm_xpu,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::fusion::AddLayernormXPUKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/test/ir/inference/test_xpu_add_layernorm_fuse_pass.py b/test/ir/inference/test_xpu_add_layernorm_fuse_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6b0d386cdd586996819f9df08575a2eabf78d6d
--- /dev/null
+++ b/test/ir/inference/test_xpu_add_layernorm_fuse_pass.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import hypothesis.strategies as st
+import numpy as np
+from auto_scan_test import PassAutoScanTest
+from program_config import OpConfig, ProgramConfig, TensorConfig
+
+
+class TestAddLayernormXPUFusePass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_xpu=True)
+        yield config, ["add_layernorm_xpu"], (1e-3, 1e-3)
+
+    def sample_program_config(self, draw):
+        batch_size = draw(st.integers(min_value=1, max_value=50))
+        x_shape = [batch_size, 16, 128]
+        y_shape = x_shape
+
+        axis = -1
+
+        epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
+        # begin_norm_axis has to be 2
+        begin_norm_axis = 2
+        # Here we compose a program.
+        # There is still some risk that the program is invalid or causes a bug while running.
+        # Use function `is_program_valid` to filter out invalid programs before running.
+        # Use function `add_skip_pass_case` to ignore programs even if they cause bugs while running.
+        elementwise_op = OpConfig(
+            type='elementwise_add',
+            inputs={'X': ['eltwise_X'], 'Y': ['eltwise_Y']},
+            outputs={'Out': ['eltwise_output']},
+            axis=axis,
+        )
+        layer_norm_op = OpConfig(
+            "layer_norm",
+            inputs={
+                "X": ["eltwise_output"],
+                "Scale": ["layer_norm_scale"],
+                "Bias": ["layer_norm_bias"],
+            },
+            outputs={
+                "Y": ["layer_norm_out"],
+                "Mean": ["layer_norm_mean"],
+                "Variance": ["layer_norm_var"],
+            },
+            begin_norm_axis=begin_norm_axis,
+            epsilon=epsilon,
+        )
+        mini_graph = [elementwise_op, layer_norm_op]
+
+        program_config = ProgramConfig(
+            ops=mini_graph,
+            weights={
+                "layer_norm_scale": TensorConfig(shape=[x_shape[2]]),
+                "layer_norm_bias": TensorConfig(shape=[x_shape[2]]),
+            },
+            inputs={
+                "eltwise_X": TensorConfig(shape=x_shape),
+                "eltwise_Y": TensorConfig(shape=y_shape),
+            },
+            outputs=mini_graph[-1].outputs["Y"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=25,
+            passes=["add_layernorm_xpu_fuse_pass"],
+        )
+
+
+if __name__ == "__main__":
+    np.random.seed(200)
+    unittest.main()
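Reviewer note (not part of the patch): the replaced subgraph computes out = layer_norm(x + y), with mean and variance taken over the trailing axes starting at begin_norm_axis, and z_add holding the intermediate sum. Below is a minimal numpy sketch of that reference computation for sanity checking; it is illustrative only, not the XPU kernel, and the helper name and shapes are assumptions that simply mirror the unit test's [batch, 16, 128] inputs with begin_norm_axis = 2.

import numpy as np


def add_layernorm_reference(x, y, scale, bias, begin_norm_axis=2, epsilon=1e-5):
    # z_add is the intermediate elementwise sum kept as an extra output.
    z_add = x + y
    # Flatten to [m, n] the same way the pass derives m and n via
    # flatten_to_2d(x_dims, begin_norm_axis).
    m = int(np.prod(z_add.shape[:begin_norm_axis]))
    n = int(np.prod(z_add.shape[begin_norm_axis:]))
    z2d = z_add.reshape(m, n).astype(np.float32)
    mean = z2d.mean(axis=1)                  # shape [m], matches InferMeta
    var = z2d.var(axis=1)                    # shape [m], matches InferMeta
    normed = (z2d - mean[:, None]) / np.sqrt(var[:, None] + epsilon)
    out = normed * scale.reshape(1, n) + bias.reshape(1, n)
    return out.reshape(z_add.shape), mean, var, z_add


# Example with the unit test's shapes: x/y are [batch, 16, 128],
# scale/bias are [128], so m = batch * 16 and n = 128.
x = np.random.rand(2, 16, 128).astype(np.float32)
y = np.random.rand(2, 16, 128).astype(np.float32)
scale = np.random.rand(128).astype(np.float32)
bias = np.random.rand(128).astype(np.float32)
out, mean, var, z_add = add_layernorm_reference(x, y, scale, bias)
print(out.shape, mean.shape, var.shape, z_add.shape)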