diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc index 1053e20150afad5467e22a009babe54dfa3c200f..2b5beed6f57724a28c953569ea063e3c79bf0b47 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc @@ -23,26 +23,15 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h" #include -#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/xpu/pass_utils.h" #include "paddle/fluid/framework/ir/xpu/quant_utils.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/phi/kernels/concat_kernel.h" -namespace phi { -class DenseTensor; -} // namespace phi - -namespace paddle { -namespace framework { -class Scope; -} // namespace framework -} // namespace paddle - namespace paddle { namespace framework { namespace ir { @@ -515,175 +504,26 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( } // namespace patterns -/* -step1: fuse single ops to single_encoder_xpu -step2: fuse mutitl single_encoder_xpu to multi_encoder_xpu - -1. 
step1 -Origin subgraph: - ------------ input_variable* - | / | \ - | / | \ - | v_matmul q_matmul k_matmul - | | | | - | | | | - | v_add q_add add - | | | | - | | | | - | v_reshape q_reshape k_reshape - | | | | - | | | | - | v_transpose q_transpose k_transpose - | | | | - | | \ / - | | qk_matmul - | | | - | | | - | | qk_add - | | | - | | | - | | qk_softmax - | | | - | | | - | ---------qkv_matmul_0 - | | - | | - | qkv_transpose - | | - | | - | qkv_reshape - | | - | | - | qkv_matmul_1 - | | - | | - | qkv_add_0 - | | - | | - ----------------------qkv_add_1 - | - | - layer_norm_1 - / \ - | | - | qkv_matmul_2 - | | - | | - | qkv_add_2 - | | - | | - | qkv_act - | | - | | - | qkv_matmul_3 - | | - | | - | qkv_add_3 - | | - \ / - qkv_add_4 - | - layer_norm - -Fused subgraph: - single_encoder_xpu - -2. step2 -Origin subgraph: - ... - | - single_encoder_xpu - | - (single_encoder_xpu) - | - (single_encoder_xpu) - | - ... -Fused subgraph: - multi_encoder_xpu -*/ -class MultiEncoderXPUFusePass : public FusePassBase { - protected: - void ApplyImpl(ir::Graph* graph) const override; - - private: - int ApplySingleEncoderXPUFuse(ir::Graph* graph, - const std::string& act_type, - const std::string& matmul_type_0, - const std::string& matmul_type_1, - const std::string& matmul_type_2, - bool norm_before, - bool with_q_scale, - bool with_mask) const; - - bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const; - - // Mask must be fp32 even if model is fp16 - int CastMask(ir::Graph* graph) const; - - // 1. Transpose q_w, k_w, v_w - // 2. Concat q_w, k_w, v_w - // 3. Generate qkv_w_max tensor - // 4. Quant qkv_w to int16 - void PrepareQKVWeight(Graph* graph, - Scope* scope, - BlockDesc* block, - Node* q_w, - Node* k_w, - Node* v_w, - Node** qkv_w, - Node** qkv_w_max) const; - - // 1. Cast bias to fp32 - // 2. 
Concat q/k/v bias - void PrepareQKVBias(Graph* graph, - Scope* scope, - BlockDesc* block, - Node* q_bias, - Node* k_bias, - Node* v_bias, - Node** qkv_bias) const; - - const std::string name_scope_{"multi_encoder_xpu_fuse_pass"}; -}; - void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); Init(name_scope_, graph); - std::vector act_types{"gelu", "relu"}; - std::vector matmul_types_0{"matmul_v2", "matmul", "mul"}; - std::vector matmul_types_1{"matmul_v2", "matmul"}; - std::vector matmul_types_2{"matmul_v2", "matmul"}; - std::vector norm_befores{true, false}; - std::vector with_q_scales{true, false}; - std::vector with_masks{true, false}; + int single_encoder_fused_counts = 0; int multi_encoder_fused_counts = 0; - for (auto act_type : act_types) { - for (auto matmul_type_0 : matmul_types_0) { - for (auto matmul_type_1 : matmul_types_1) { - for (auto matmul_type_2 : matmul_types_2) { - for (auto norm_before : norm_befores) { - for (auto with_q_scale : with_q_scales) { - for (auto with_mask : with_masks) { - single_encoder_fused_counts += - ApplySingleEncoderXPUFuse(graph, - act_type, - matmul_type_0, - matmul_type_1, - matmul_type_2, - norm_before, - with_q_scale, - with_mask); - while (ApplyMultiEncoderXPUFuse(graph)) { - multi_encoder_fused_counts++; - } - } - } - } - } - } + auto pattern_params = GeneratePatternParams(); + for (auto pattern_param : pattern_params) { + single_encoder_fused_counts += + ApplySingleEncoderXPUFuse(graph, + pattern_param.act_type, + pattern_param.matmul_type_0, + pattern_param.matmul_type_1, + pattern_param.matmul_type_2, + pattern_param.norm_before, + pattern_param.with_q_scale, + pattern_param.with_mask); + while (ApplyMultiEncoderXPUFuse(graph)) { + multi_encoder_fused_counts++; } } int cast_mask_counts = CastMask(graph); @@ -1372,6 +1212,13 @@ int MultiEncoderXPUFusePass::CastMask(ir::Graph* graph) const { return 
cast_counts;
 }
 
+std::vector<PatternParam> MultiEncoderXPUFusePass::GeneratePatternParams()
+    const {
+  return std::vector<PatternParam>{
+      // Params are arranged in alphabetic order
+      {"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true}};
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..c4d4a0f8cc1747c5af43eaab84ebd25246578617
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h
@@ -0,0 +1,194 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once +#include +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { + +/* +step1: fuse single ops to single_encoder_xpu +step2: fuse mutitl single_encoder_xpu to multi_encoder_xpu + +1. step1 +Origin subgraph: + ------------ input_variable* + | / | \ + | / | \ + | v_matmul q_matmul k_matmul + | | | | + | | | | + | v_add q_add add + | | | | + | | | | + | v_reshape q_reshape k_reshape + | | | | + | | | | + | v_transpose q_transpose k_transpose + | | | | + | | \ / + | | qk_matmul + | | | + | | | + | | qk_add + | | | + | | | + | | qk_softmax + | | | + | | | + | ---------qkv_matmul_0 + | | + | | + | qkv_transpose + | | + | | + | qkv_reshape + | | + | | + | qkv_matmul_1 + | | + | | + | qkv_add_0 + | | + | | + ----------------------qkv_add_1 + | + | + layer_norm_1 + / \ + | | + | qkv_matmul_2 + | | + | | + | qkv_add_2 + | | + | | + | qkv_act + | | + | | + | qkv_matmul_3 + | | + | | + | qkv_add_3 + | | + \ / + qkv_add_4 + | + layer_norm + +Fused subgraph: + single_encoder_xpu + +2. step2 +Origin subgraph: + ... + | + single_encoder_xpu + | + (single_encoder_xpu) + | + (single_encoder_xpu) + | + ... 
+Fused subgraph: + multi_encoder_xpu +*/ + +struct PatternParam { + std::string act_type; // "gelu", "relu" + std::string matmul_type_0; // "matmul_v2", "matmul", "mul" + std::string matmul_type_1; // "matmul_v2", "matmul" + std::string matmul_type_2; // "matmul_v2", "matmul" + bool norm_before; + bool with_q_scale; + bool with_mask; +}; + +class MultiEncoderXPUFusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + int ApplySingleEncoderXPUFuse(ir::Graph* graph, + const std::string& act_type, + const std::string& matmul_type_0, + const std::string& matmul_type_1, + const std::string& matmul_type_2, + bool norm_before, + bool with_q_scale, + bool with_mask) const; + + bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const; + + // Mask must be fp32 even if model is fp16 + int CastMask(ir::Graph* graph) const; + + // 1. Transpose q_w, k_w, v_w + // 2. Concat q_w, k_w, v_w + // 3. Generate qkv_w_max tensor + // 4. Quant qkv_w to int16 + void PrepareQKVWeight(Graph* graph, + Scope* scope, + BlockDesc* block, + Node* q_w, + Node* k_w, + Node* v_w, + Node** qkv_w, + Node** qkv_w_max) const; + + // 1. Cast bias to fp32 + // 2. Concat q/k/v bias + void PrepareQKVBias(Graph* graph, + Scope* scope, + BlockDesc* block, + Node* q_bias, + Node* k_bias, + Node* v_bias, + Node** qkv_bias) const; + + // Iterating all attrs costs too much time. + // Just provide several cases. 
+  std::vector<PatternParam> GeneratePatternParams() const;
+
+  const std::string name_scope_{"multi_encoder_xpu_fuse_pass"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.cc b/paddle/fluid/framework/ir/xpu/pass_utils.cc
index d23049eb92c391eb97feba166fbc16c1ad72cad1..cb905235f617871e4977551ad181ebc0698cabf6 100644
--- a/paddle/fluid/framework/ir/xpu/pass_utils.cc
+++ b/paddle/fluid/framework/ir/xpu/pass_utils.cc
@@ -105,6 +105,12 @@ size_t HashTensor(const phi::DenseTensor& in) {
 template size_t HashTensor<int16_t>(const phi::DenseTensor& in);
 template size_t HashTensor<float>(const phi::DenseTensor& in);
 
+std::string GetPrefixWithoutHash(const std::string& name,
+                                 const phi::DenseTensor& tensor) {
+  std::size_t found = name.find("_#");
+  return found == std::string::npos ? name : name.substr(0, found);
+}
+
 template <typename T>
 void PrepareWeight(Graph* graph,
                    Scope* scope,
@@ -122,8 +128,9 @@ void PrepareWeight(Graph* graph,
   size_t dst_hash = HashTensor<T>(dst_tensor);
   size_t dst_max_hash = HashTensor<float>(dst_max_tensor);
 
-  std::string dst_name = src_name + "_" + std::to_string(dst_hash);
-  std::string dst_max_name = src_name + "_max_" + std::to_string(dst_max_hash);
+  std::string pre_name = GetPrefixWithoutHash(src_name, *src_tensor);
+  std::string dst_name = pre_name + "_#" + std::to_string(dst_hash);
+  std::string dst_max_name = pre_name + "_max_#" + std::to_string(dst_max_hash);
   *dst = FindNodeWithName(graph, dst_name);
   if (*dst == nullptr) {
     // Create dst node
@@ -199,7 +206,8 @@ void PrepareBias(
   phi::DenseTensor dst_tensor;
   CastToFp32(src_tensor, &dst_tensor);
   size_t dst_hash = HashTensor<float>(dst_tensor);
-  std::string dst_name = src_name + "_" + std::to_string(dst_hash);
+  std::string pre_name = GetPrefixWithoutHash(src_name, *src_tensor);
+  std::string dst_name = pre_name + "_#" + std::to_string(dst_hash);
   *dst = FindNodeWithName(graph, dst_name);
   if (*dst == nullptr) {
     // Create dst node