Unverified commit e2cdd4a3 authored by Z zhupengyang, committed by GitHub

[xpu] optimize multi_encoder_xpu_fuse_pass performance (#51346)

Parent e6ca78c2
......@@ -23,26 +23,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/concat_kernel.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
......@@ -515,175 +504,26 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern(
} // namespace patterns
/*
step1: fuse single ops to single_encoder_xpu
step2: fuse multiple single_encoder_xpu to multi_encoder_xpu
1. step1
Origin subgraph:
------------ input_variable*
| / | \
| / | \
| v_matmul q_matmul k_matmul
| | | |
| | | |
| v_add q_add k_add
| | | |
| | | |
| v_reshape q_reshape k_reshape
| | | |
| | | |
| v_transpose q_transpose k_transpose
| | | |
| | \ /
| | qk_matmul
| | |
| | |
| | qk_add
| | |
| | |
| | qk_softmax
| | |
| | |
| ---------qkv_matmul_0
| |
| |
| qkv_transpose
| |
| |
| qkv_reshape
| |
| |
| qkv_matmul_1
| |
| |
| qkv_add_0
| |
| |
----------------------qkv_add_1
|
|
layer_norm_1
/ \
| |
| qkv_matmul_2
| |
| |
| qkv_add_2
| |
| |
| qkv_act
| |
| |
| qkv_matmul_3
| |
| |
| qkv_add_3
| |
\ /
qkv_add_4
|
layer_norm
Fused subgraph:
single_encoder_xpu
2. step2
Origin subgraph:
...
|
single_encoder_xpu
|
(single_encoder_xpu)
|
(single_encoder_xpu)
|
...
Fused subgraph:
multi_encoder_xpu
*/
class MultiEncoderXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
int ApplySingleEncoderXPUFuse(ir::Graph* graph,
const std::string& act_type,
const std::string& matmul_type_0,
const std::string& matmul_type_1,
const std::string& matmul_type_2,
bool norm_before,
bool with_q_scale,
bool with_mask) const;
bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const;
// Mask must be fp32 even if model is fp16
int CastMask(ir::Graph* graph) const;
// 1. Transpose q_w, k_w, v_w
// 2. Concat q_w, k_w, v_w
// 3. Generate qkv_w_max tensor
// 4. Quant qkv_w to int16
void PrepareQKVWeight(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* q_w,
Node* k_w,
Node* v_w,
Node** qkv_w,
Node** qkv_w_max) const;
// 1. Cast bias to fp32
// 2. Concat q/k/v bias
void PrepareQKVBias(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* q_bias,
Node* k_bias,
Node* v_bias,
Node** qkv_bias) const;
const std::string name_scope_{"multi_encoder_xpu_fuse_pass"};
};
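Note: steps 3 and 4 listed for PrepareQKVWeight above boil down to symmetric absmax quantization. Below is a minimal self-contained sketch of that idea, assuming plain std::vector buffers rather than phi::DenseTensor; it is illustrative only, not the quant_utils implementation the pass actually calls.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Step 3: derive the per-tensor absmax (what qkv_w_max conceptually holds).
// Step 4: symmetrically quantize fp32 weights into int16 using that absmax.
void QuantWeightToInt16(const std::vector<float>& w,
                        std::vector<int16_t>* q,
                        float* w_max) {
  float m = 0.f;
  for (float v : w) m = std::max(m, std::fabs(v));
  *w_max = m;  // kernels later dequantize with w[i] ~= q[i] * m / 32767
  const float scale = (m > 0.f) ? 32767.f / m : 0.f;
  q->resize(w.size());
  for (std::size_t i = 0; i < w.size(); ++i) {
    (*q)[i] = static_cast<int16_t>(std::lround(w[i] * scale));
  }
}
```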
void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
std::vector<std::string> act_types{"gelu", "relu"};
std::vector<std::string> matmul_types_0{"matmul_v2", "matmul", "mul"};
std::vector<std::string> matmul_types_1{"matmul_v2", "matmul"};
std::vector<std::string> matmul_types_2{"matmul_v2", "matmul"};
std::vector<bool> norm_befores{true, false};
std::vector<bool> with_q_scales{true, false};
std::vector<bool> with_masks{true, false};
int single_encoder_fused_counts = 0;
int multi_encoder_fused_counts = 0;
for (auto act_type : act_types) {
for (auto matmul_type_0 : matmul_types_0) {
for (auto matmul_type_1 : matmul_types_1) {
for (auto matmul_type_2 : matmul_types_2) {
for (auto norm_before : norm_befores) {
for (auto with_q_scale : with_q_scales) {
for (auto with_mask : with_masks) {
single_encoder_fused_counts +=
ApplySingleEncoderXPUFuse(graph,
act_type,
matmul_type_0,
matmul_type_1,
matmul_type_2,
norm_before,
with_q_scale,
with_mask);
while (ApplyMultiEncoderXPUFuse(graph)) {
multi_encoder_fused_counts++;
}
}
}
}
}
}
}
}
auto pattern_params = GeneratePatternParams();
for (auto pattern_param : pattern_params) {
single_encoder_fused_counts +=
ApplySingleEncoderXPUFuse(graph,
pattern_param.act_type,
pattern_param.matmul_type_0,
pattern_param.matmul_type_1,
pattern_param.matmul_type_2,
pattern_param.norm_before,
pattern_param.with_q_scale,
pattern_param.with_mask);
while (ApplyMultiEncoderXPUFuse(graph)) {
multi_encoder_fused_counts++;
}
}
int cast_mask_counts = CastMask(graph);
......@@ -1372,6 +1212,13 @@ int MultiEncoderXPUFusePass::CastMask(ir::Graph* graph) const {
return cast_counts;
}
std::vector<PatternParam> MultiEncoderXPUFusePass::GeneratePatternParams()
const {
return std::vector<PatternParam>{
// Params are arranged in alphabetic order
{"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true}};
}
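Note: this curated list is the heart of the optimization. The brute-force search that ApplyImpl previously performed enumerates 2 act_types × 3 matmul_types_0 × 2 matmul_types_1 × 2 matmul_types_2 × 2 norm_befores × 2 with_q_scales × 2 with_masks = 192 parameter combinations, each triggering a full graph-pattern match over the whole graph. GeneratePatternParams replaces that with the few combinations seen in practice, and supporting a new model layout only requires appending one PatternParam entry.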
} // namespace ir
} // namespace framework
} // namespace paddle
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
/*
step1: fuse single ops to single_encoder_xpu
step2: fuse multiple single_encoder_xpu to multi_encoder_xpu
1. step1
Origin subgraph:
------------ input_variable*
| / | \
| / | \
| v_matmul q_matmul k_matmul
| | | |
| | | |
| v_add q_add k_add
| | | |
| | | |
| v_reshape q_reshape k_reshape
| | | |
| | | |
| v_transpose q_transpose k_transpose
| | | |
| | \ /
| | qk_matmul
| | |
| | |
| | qk_add
| | |
| | |
| | qk_softmax
| | |
| | |
| ---------qkv_matmul_0
| |
| |
| qkv_transpose
| |
| |
| qkv_reshape
| |
| |
| qkv_matmul_1
| |
| |
| qkv_add_0
| |
| |
----------------------qkv_add_1
|
|
layer_norm_1
/ \
| |
| qkv_matmul_2
| |
| |
| qkv_add_2
| |
| |
| qkv_act
| |
| |
| qkv_matmul_3
| |
| |
| qkv_add_3
| |
\ /
qkv_add_4
|
layer_norm
Fused subgraph:
single_encoder_xpu
2. step2
Origin subgraph:
...
|
single_encoder_xpu
|
(single_encoder_xpu)
|
(single_encoder_xpu)
|
...
Fused subgraph:
multi_encoder_xpu
*/
struct PatternParam {
std::string act_type; // "gelu", "relu"
std::string matmul_type_0; // "matmul_v2", "matmul", "mul"
std::string matmul_type_1; // "matmul_v2", "matmul"
std::string matmul_type_2; // "matmul_v2", "matmul"
bool norm_before;
bool with_q_scale;
bool with_mask;
};
class MultiEncoderXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
int ApplySingleEncoderXPUFuse(ir::Graph* graph,
const std::string& act_type,
const std::string& matmul_type_0,
const std::string& matmul_type_1,
const std::string& matmul_type_2,
bool norm_before,
bool with_q_scale,
bool with_mask) const;
bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const;
// Mask must be fp32 even if model is fp16
int CastMask(ir::Graph* graph) const;
// 1. Transpose q_w, k_w, v_w
// 2. Concat q_w, k_w, v_w
// 3. Generate qkv_w_max tensor
// 4. Quant qkv_w to int16
void PrepareQKVWeight(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* q_w,
Node* k_w,
Node* v_w,
Node** qkv_w,
Node** qkv_w_max) const;
// 1. Cast bias to fp32
// 2. Concat q/k/v bias
void PrepareQKVBias(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* q_bias,
Node* k_bias,
Node* v_bias,
Node** qkv_bias) const;
// Iterating all attrs costs too much time.
// Just provide several cases.
std::vector<PatternParam> GeneratePatternParams() const;
const std::string name_scope_{"multi_encoder_xpu_fuse_pass"};
};
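CastMask's comment states that the attention mask must stay fp32 even when the model runs in fp16. As background, widening IEEE binary16 to binary32 is lossless; the following standalone sketch applies that conversion to a raw mask buffer. It is illustrative only, not how the pass itself performs the cast.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Convert one IEEE 754 binary16 value (stored as uint16_t) to float.
float Fp16ToFp32(uint16_t h) {
  uint32_t sign = static_cast<uint32_t>(h & 0x8000) << 16;
  uint32_t exp = (h >> 10) & 0x1F;
  uint32_t frac = h & 0x3FF;
  uint32_t bits;
  if (exp == 0) {
    if (frac == 0) {
      bits = sign;  // signed zero
    } else {
      // Subnormal half: renormalize into a normal float.
      exp = 127 - 15 + 1;
      while ((frac & 0x400) == 0) {
        frac <<= 1;
        --exp;
      }
      frac &= 0x3FF;
      bits = sign | (exp << 23) | (frac << 13);
    }
  } else if (exp == 0x1F) {
    bits = sign | 0x7F800000 | (frac << 13);  // inf / NaN
  } else {
    bits = sign | ((exp - 15 + 127) << 23) | (frac << 13);  // normal number
  }
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

std::vector<float> CastMaskBufferToFp32(const std::vector<uint16_t>& mask) {
  std::vector<float> out(mask.size());
  for (std::size_t i = 0; i < mask.size(); ++i) out[i] = Fp16ToFp32(mask[i]);
  return out;
}
```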
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -105,6 +105,12 @@ size_t HashTensor(const phi::DenseTensor& in) {
template size_t HashTensor<int16_t>(const phi::DenseTensor& in);
template size_t HashTensor<float>(const phi::DenseTensor& in);
std::string GetPrefixWithoutHash(const std::string& name,
const phi::DenseTensor& tensor) {
std::size_t found = name.find("_#");
return found == std::string::npos ? name : name.substr(0, found);
}
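The "_#" sentinel introduced here is what keeps repeated pass runs from stacking hash suffixes onto weight names: the stable prefix is recovered before a new hash is appended. A small sketch of the intended string behavior follows (the unused tensor parameter is dropped since only the string logic matters, and the names are made up):

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Same string logic as the helper above: strip from the "_#" sentinel on.
std::string PrefixWithoutHash(const std::string& name) {
  std::size_t found = name.find("_#");
  return found == std::string::npos ? name : name.substr(0, found);
}

int main() {
  // First run: no sentinel yet, so the whole name is the prefix.
  assert(PrefixWithoutHash("qkv_w") == "qkv_w");
  // Later runs: an already-renamed weight keeps a stable prefix, so its next
  // name becomes "qkv_w_#<new_hash>" instead of "qkv_w_#<old>_#<new>".
  assert(PrefixWithoutHash("qkv_w_#12345") == "qkv_w");
  return 0;
}
```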
template <typename T>
void PrepareWeight(Graph* graph,
Scope* scope,
......@@ -122,8 +128,9 @@ void PrepareWeight(Graph* graph,
size_t dst_hash = HashTensor<T>(dst_tensor);
size_t dst_max_hash = HashTensor<float>(dst_max_tensor);
std::string dst_name = src_name + "_" + std::to_string(dst_hash);
std::string dst_max_name = src_name + "_max_" + std::to_string(dst_max_hash);
std::string pre_name = GetPrefixWithoutHash(src_name, *src_tensor);
std::string dst_name = pre_name + "_#" + std::to_string(dst_hash);
std::string dst_max_name = pre_name + "_max_#" + std::to_string(dst_max_hash);
*dst = FindNodeWithName(graph, dst_name);
if (*dst == nullptr) {
// Create dst node
......@@ -199,7 +206,8 @@ void PrepareBias(
phi::DenseTensor dst_tensor;
CastToFp32(src_tensor, &dst_tensor);
size_t dst_hash = HashTensor<float>(dst_tensor);
std::string dst_name = src_name + "_" + std::to_string(dst_hash);
std::string pre_name = GetPrefixWithoutHash(src_name, *src_tensor);
std::string dst_name = pre_name + "_#" + std::to_string(dst_hash);
*dst = FindNodeWithName(graph, dst_name);
if (*dst == nullptr) {
// Create dst node
......
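Taken together, the rename plus the FindNodeWithName lookup behaves like a content-addressed cache: a prepared weight is named after the hash of its processed contents, so repeated fuses that produce identical tensors reuse one node instead of materializing duplicates. A simplified sketch of that pattern, with a hypothetical map standing in for ir::Graph:

```cpp
#include <cstddef>
#include <string>
#include <unordered_map>

struct Node {  // hypothetical stand-in for ir::Node
  std::string name;
};

Node* FindOrCreate(std::unordered_map<std::string, Node>* graph,
                   const std::string& prefix,
                   std::size_t content_hash) {
  const std::string name = prefix + "_#" + std::to_string(content_hash);
  auto it = graph->find(name);
  if (it != graph->end()) return &it->second;  // already prepared: reuse it
  auto inserted = graph->emplace(name, Node{name});  // prepare exactly once
  return &inserted.first->second;
}
```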