Unverified commit c8aa6405 authored by zhupengyang, committed by GitHub

[XPU] fix dropout pass; add multi_encoder_xpu_fuse_pass & multi_encoder_xpu kernel (#50499)

Parent df207283
@@ -213,10 +213,17 @@ endif()
 if(WITH_XPU)
   cc_library(
-    quant_utils
+    xpu_quant_utils
     SRCS xpu/quant_utils.cc
     DEPS pass)
-  pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS quant_utils)
+  cc_library(
+    xpu_pass_utils
+    SRCS xpu/pass_utils.cc
+    DEPS pass)
+  set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
+  pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
 endif()
 cc_library(
...
@@ -25,71 +25,52 @@ namespace paddle {
 namespace framework {
 namespace ir {

-#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
-#define GET_NODES                  \
-  GET_IR_NODE(any_op_out);         \
-  GET_IR_NODE(dropout_op);         \
-  GET_IR_NODE(dropout_op_out);     \
-  GET_IR_NODE(dropout_op_outmask); \
-  GET_IR_NODE(any_op2);
+#define GET_IR_NODE(node_) GET_IR_NODE_FROM_SUBGRAPH(node_, node_, pattern)

 void DeleteDropoutOpPass::ApplyImpl(ir::Graph* graph) const {
   const std::string pattern_name = "delete_dropout_op_pattern";
   FusePassBase::Init(pattern_name, graph);

   GraphPatternDetector gpd;
   patterns::DeleteDropoutOpPattern pattern(gpd.mutable_pattern(), pattern_name);
   pattern();

+  int found_subgraph_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    GET_NODES;
-    IR_NODE_LINK_TO(any_op_out, any_op2);
-    std::string any_op_out_name = any_op_out->Var()->Name();
-    std::string dropout_op_out_name = dropout_op_out->Var()->Name();
-    // any_op2
-    auto* any_op2_desc = any_op2->Op();
-    auto var_map = any_op2_desc->Inputs();
-    std::string arg_name = "";
-    for (auto& name_m : var_map) {
-      if (std::find(name_m.second.begin(),
-                    name_m.second.end(),
-                    dropout_op_out_name) != name_m.second.end()) {
-        arg_name = name_m.first;
-      }
-    }
-    if (arg_name.size() == 0) {
-      LOG(INFO) << "Delete dropout op pass: can not find the input "
-                << dropout_op_out_name;
-      return;
-    }
-    // modify the any_op2's inputs
-    for (auto& name_m : var_map) {
-      if (std::find(name_m.second.begin(),
-                    name_m.second.end(),
-                    dropout_op_out_name) != name_m.second.end()) {
-        std::vector<std::string> new_inputs;
-        for (auto& i_n : name_m.second) {
-          if (i_n != dropout_op_out_name) {
-            new_inputs.push_back(i_n);
-          }
-        }
-        new_inputs.push_back(any_op_out_name);
-        any_op2_desc->SetInput(name_m.first, new_inputs);
-        any_op2_desc->Flush();
-      }
-    }
-    any_op2_desc->Flush();
-    // Delete the unneeded nodes.
-    GraphSafeRemoveNodes(graph,
-                         {dropout_op, dropout_op_out, dropout_op_outmask});
+    GET_IR_NODE(dropout_op_x);
+    GET_IR_NODE(dropout_op);
+    GET_IR_NODE(dropout_op_out);
+    GET_IR_NODE(dropout_op_mask);
+    // link dropout_op_out to pre_op
+    auto dropout_op_x_name = dropout_op_x->Var()->Name();
+    auto dropout_op_out_name = dropout_op_out->Var()->Name();
+    auto pre_ops = dropout_op_x->inputs;
+    if (pre_ops.empty()) return;
+    auto pre_op_desc = pre_ops[0]->Op();
+    auto pre_op_outs = pre_op_desc->Outputs();
+    for (auto& out_var : pre_op_outs) {
+      auto names = out_var.second;
+      for (size_t i = 0; i < names.size(); i++) {
+        if (names[i] == dropout_op_x_name) {
+          names[i] = dropout_op_out_name;
+          pre_op_desc->SetOutput(out_var.first, names);
+          break;
+        }
+      }
+    }
+    IR_NODE_LINK_TO(pre_ops[0], dropout_op_out);
+    // delete useless node
+    std::unordered_set<const Node*> delete_nodes{
+        dropout_op_x, dropout_op, dropout_op_mask};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
   };

   gpd(graph, handler);
+  AddStatis(found_subgraph_count);
 }

 DeleteDropoutOpXPass::DeleteDropoutOpXPass() {
@@ -279,6 +260,10 @@ void DeleteDropoutOpXPass::ReplaceOutputVar(Node* op,

 REGISTER_PASS(delete_dropout_op_pass,
               paddle::framework::ir::DeleteDropoutOpPass);
+REGISTER_PASS_CAPABILITY(delete_dropout_op_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "dropout", 0));
 REGISTER_PASS(delete_dropout_op_x_pass,
               paddle::framework::ir::DeleteDropoutOpXPass);
...
@@ -3034,26 +3034,19 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
 }

 void patterns::DeleteDropoutOpPattern::operator()() {
-  auto any_op_out = pattern->NewNode(any_op_out_repr())
-                        ->assert_is_op_input("dropout", "X")
-                        ->AsInput();
-
-  auto dropout_op =
-      pattern->NewNode(dropout_op_repr())->assert_is_op("dropout");
-
-  auto dropout_op_out = pattern->NewNode(dropout_op_out_repr())
-                            ->assert_is_op_output("dropout", "Out")
-                            ->AsIntermediate();
-
-  auto dropout_op_outmask = pattern->NewNode(dropout_op_outmask_repr())
-                                ->assert_is_op_output("dropout", "Mask")
-                                ->AsOutput();
-  auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
-
-  dropout_op->LinksFrom({any_op_out});
-  dropout_op_out->LinksFrom({dropout_op});
-  dropout_op_outmask->LinksFrom({dropout_op});
-  any_op2->LinksFrom({dropout_op_out});
+  auto dropout_op_x = pattern->NewNode(dropout_op_x_repr())
+                          ->assert_is_op_input("dropout", "X")
+                          ->AsInput();
+  auto dropout_op = pattern->NewNode(dropout_op_repr())
+                        ->assert_is_op("dropout")
+                        ->assert_op_attr("dropout_implementation",
+                                         std::string("upscale_in_train"));
+  auto dropout_op_out = pattern->NewNode(dropout_op_out_repr())
+                            ->assert_is_op_output("dropout", "Out");
+  auto dropout_op_mask = pattern->NewNode(dropout_op_mask_repr())
+                             ->assert_is_op_output("dropout", "Mask");
+  dropout_op->LinksFrom({dropout_op_x})
+      .LinksTo({dropout_op_out, dropout_op_mask});
 }

 void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node,
...
@@ -1763,11 +1763,10 @@ struct DeleteDropoutOpPattern : public PatternBase {
   void operator()();

-  PATTERN_DECL_NODE(any_op_out);
+  PATTERN_DECL_NODE(dropout_op_x);
   PATTERN_DECL_NODE(dropout_op);
   PATTERN_DECL_NODE(dropout_op_out);
-  PATTERN_DECL_NODE(dropout_op_outmask);
-  PATTERN_DECL_NODE(any_op2);
+  PATTERN_DECL_NODE(dropout_op_mask);
 };

 struct DeleteQuantDequantOpPattern : public PatternBase {
...
@@ -176,15 +176,6 @@ class FcXPUFusePass : public FusePassBase {
                  const std::string& act_type) const;

   const std::string name_scope_{"fc_xpu_fuse_pass"};
-  const std::map<std::string, int> act_map_{{"", 0},
-                                            {"relu", 1},
-                                            {"sigmoid", 2},
-                                            {"tanh", 3},
-                                            {"gelu", 4},
-                                            {"leaky_relu", 5},
-                                            {"hard_swish", 14},
-                                            {"hard_sigmoid", 15},
-                                            {"relu6", 17}};
 };

 void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
@@ -246,17 +237,13 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
       mul_w_max_var->SetPersistable(true);
       auto mul_w_max_tensor =
           scope->Var(mul_w_max_name)->GetMutable<phi::DenseTensor>();
-      auto* xpu_ctx = static_cast<phi::XPUContext*>(
-          platform::DeviceContextPool::Instance().Get(phi::XPUPlace()));
-      int max_ptr_size = xpu_ctx->x_context()->max_ptr_size();
       bool transpose_w = false;
       if (mul_type == "matmul") {
         transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y"));
       } else if (mul_type == "matmul_v2") {
         transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y"));
       }
-      QuantWeight<int16_t>(
-          mul_w_tensor, mul_w_max_tensor, !transpose_w, max_ptr_size);
+      QuantWeight<int16_t>(mul_w_tensor, mul_w_max_tensor, !transpose_w);
     }

     // Generate fc_xpu op
@@ -288,7 +275,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     fc_xpu_op_desc.SetAttr("act_type", 0);
     fc_xpu_op_desc.SetAttr("act_alpha", 0.f);
     if (act) {
-      fc_xpu_op_desc.SetAttr("act_type", act_map_.at(act_type));
+      fc_xpu_op_desc.SetAttr("act_type", ConvertActivationType(act_type));
       if (act_type == "leaky_relu") {
         fc_xpu_op_desc.SetAttr(
             "act_alpha", PADDLE_GET_CONST(float, act->Op()->GetAttr("alpha")));
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct SingleEncoderXPUPattern : public PatternBase {
SingleEncoderXPUPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& act_type,
const std::string& matmul_type_0,
const std::string& matmul_type_1,
const std::string& matmul_type_2,
bool norm_before,
bool with_mask);
// declare operator node's name
// If norm_before, use ln_0 & ln_1.
// If not norm_before, use ln_1 & ln_2.
PATTERN_DECL_NODE(ln_0);
PATTERN_DECL_NODE(ln_1);
PATTERN_DECL_NODE(ln_2);
PATTERN_DECL_NODE(q_matmul);
PATTERN_DECL_NODE(q_add);
PATTERN_DECL_NODE(q_reshape);
PATTERN_DECL_NODE(q_transpose);
PATTERN_DECL_NODE(k_matmul);
PATTERN_DECL_NODE(k_add);
PATTERN_DECL_NODE(k_reshape);
PATTERN_DECL_NODE(k_transpose);
PATTERN_DECL_NODE(v_matmul);
PATTERN_DECL_NODE(v_add);
PATTERN_DECL_NODE(v_reshape);
PATTERN_DECL_NODE(v_transpose);
PATTERN_DECL_NODE(qk_matmul);
PATTERN_DECL_NODE(qk_add);
PATTERN_DECL_NODE(qk_softmax);
PATTERN_DECL_NODE(qkv_matmul_0);
PATTERN_DECL_NODE(qkv_transpose);
PATTERN_DECL_NODE(qkv_reshape);
PATTERN_DECL_NODE(qkv_matmul_1);
PATTERN_DECL_NODE(qkv_add_0);
PATTERN_DECL_NODE(qkv_add_1);
PATTERN_DECL_NODE(qkv_matmul_2);
PATTERN_DECL_NODE(qkv_add_2);
PATTERN_DECL_NODE(qkv_act);
PATTERN_DECL_NODE(qkv_matmul_3);
PATTERN_DECL_NODE(qkv_add_3);
PATTERN_DECL_NODE(qkv_add_4);
// declare variable node's name
PATTERN_DECL_NODE(ln_0_x);
PATTERN_DECL_NODE(ln_0_bias);
PATTERN_DECL_NODE(ln_0_scale);
PATTERN_DECL_NODE(ln_0_out);
PATTERN_DECL_NODE(ln_0_mean);
PATTERN_DECL_NODE(ln_0_variance);
PATTERN_DECL_NODE(q_matmul_w);
PATTERN_DECL_NODE(q_matmul_out);
PATTERN_DECL_NODE(q_add_bias);
PATTERN_DECL_NODE(q_add_out);
PATTERN_DECL_NODE(q_reshape_out);
PATTERN_DECL_NODE(q_reshape_xshape);
PATTERN_DECL_NODE(q_transpose_out);
PATTERN_DECL_NODE(q_transpose_xshape);
PATTERN_DECL_NODE(k_matmul_w);
PATTERN_DECL_NODE(k_matmul_out);
PATTERN_DECL_NODE(k_add_bias);
PATTERN_DECL_NODE(k_add_out);
PATTERN_DECL_NODE(k_reshape_out);
PATTERN_DECL_NODE(k_reshape_xshape);
PATTERN_DECL_NODE(k_transpose_out);
PATTERN_DECL_NODE(k_transpose_xshape);
PATTERN_DECL_NODE(v_matmul_w);
PATTERN_DECL_NODE(v_matmul_out);
PATTERN_DECL_NODE(v_add_bias);
PATTERN_DECL_NODE(v_add_out);
PATTERN_DECL_NODE(v_reshape_out);
PATTERN_DECL_NODE(v_reshape_xshape);
PATTERN_DECL_NODE(v_transpose_out);
PATTERN_DECL_NODE(v_transpose_xshape);
PATTERN_DECL_NODE(qk_matmul_out);
PATTERN_DECL_NODE(qk_add_mask);
PATTERN_DECL_NODE(qk_add_out);
PATTERN_DECL_NODE(qk_softmax_out);
PATTERN_DECL_NODE(qkv_matmul_0_out);
PATTERN_DECL_NODE(qkv_transpose_out);
PATTERN_DECL_NODE(qkv_transpose_xshape);
PATTERN_DECL_NODE(qkv_reshape_out);
PATTERN_DECL_NODE(qkv_reshape_xshape);
PATTERN_DECL_NODE(qkv_matmul_1_w);
PATTERN_DECL_NODE(qkv_matmul_1_out);
PATTERN_DECL_NODE(qkv_add_0_bias);
PATTERN_DECL_NODE(qkv_add_0_out);
PATTERN_DECL_NODE(qkv_add_1_out);
PATTERN_DECL_NODE(ln_1_bias);
PATTERN_DECL_NODE(ln_1_scale);
PATTERN_DECL_NODE(ln_1_out);
PATTERN_DECL_NODE(ln_1_mean);
PATTERN_DECL_NODE(ln_1_variance);
PATTERN_DECL_NODE(qkv_matmul_2_w);
PATTERN_DECL_NODE(qkv_matmul_2_out);
PATTERN_DECL_NODE(qkv_add_2_bias);
PATTERN_DECL_NODE(qkv_add_2_out);
PATTERN_DECL_NODE(qkv_act_out);
PATTERN_DECL_NODE(qkv_matmul_3_w);
PATTERN_DECL_NODE(qkv_matmul_3_out);
PATTERN_DECL_NODE(qkv_add_3_bias);
PATTERN_DECL_NODE(qkv_add_3_out);
PATTERN_DECL_NODE(qkv_add_4_out);
PATTERN_DECL_NODE(ln_2_x);
PATTERN_DECL_NODE(ln_2_bias);
PATTERN_DECL_NODE(ln_2_scale);
PATTERN_DECL_NODE(ln_2_out);
PATTERN_DECL_NODE(ln_2_mean);
PATTERN_DECL_NODE(ln_2_variance);
private:
std::string act_type_;
std::string matmul_type_0_;
std::string matmul_type_1_;
std::string matmul_type_2_;
bool norm_before_{true};
bool with_mask_{true};
};
SingleEncoderXPUPattern::SingleEncoderXPUPattern(
PDPattern* pattern,
const std::string& name_scope,
const std::string& act_type,
const std::string& matmul_type_0,
const std::string& matmul_type_1,
const std::string& matmul_type_2,
bool norm_before,
bool with_mask)
: PatternBase(pattern, name_scope, name_scope),
act_type_(act_type),
matmul_type_0_(matmul_type_0),
matmul_type_1_(matmul_type_1),
matmul_type_2_(matmul_type_2),
norm_before_(norm_before),
with_mask_(with_mask) {
// layer_norm 0
PDNode* ln_0_x = pattern->NewNode(ln_0_x_repr());
PDNode* ln_0_out = nullptr;
if (norm_before_) {
ln_0_x->assert_is_op_input("layer_norm", "X")->assert_var_not_persistable();
auto* ln_0_bias = pattern->NewNode(ln_0_bias_repr())
->assert_is_op_input("layer_norm", "Bias")
->assert_is_persistable_var();
auto* ln_0_scale = pattern->NewNode(ln_0_scale_repr())
->assert_is_op_input("layer_norm", "Scale")
->assert_is_persistable_var();
auto* ln_0 = pattern->NewNode(ln_0_repr())->assert_is_op("layer_norm");
ln_0_out = pattern->NewNode(ln_0_out_repr())
->assert_is_op_output("layer_norm", "Y")
->assert_var_not_persistable();
auto* ln_0_mean = pattern->NewNode(ln_0_mean_repr())
->assert_is_op_output("layer_norm", "Mean")
->assert_var_not_persistable();
auto* ln_0_variance = pattern->NewNode(ln_0_variance_repr())
->assert_is_op_output("layer_norm", "Variance")
->assert_var_not_persistable();
ln_0->LinksFrom({ln_0_x, ln_0_bias, ln_0_scale})
.LinksTo({ln_0_out, ln_0_mean, ln_0_variance});
}
// q: matmul + add + reshape + transpose
auto q_matmul_w = pattern->NewNode(q_matmul_w_repr())
->assert_is_op_input(matmul_type_0_, "Y")
->assert_is_persistable_var();
auto* q_matmul =
pattern->NewNode(q_matmul_repr())->assert_is_op(matmul_type_0_);
auto* q_matmul_out = pattern->NewNode(q_matmul_out_repr())
->assert_is_op_output(matmul_type_0_, "Out")
->assert_var_not_persistable();
auto q_add_bias = pattern->NewNode(q_add_bias_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_is_persistable_var();
auto* q_add = pattern->NewNode(q_add_repr())->assert_is_op("elementwise_add");
auto* q_add_out = pattern->NewNode(q_add_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_var_not_persistable();
auto* q_reshape =
pattern->NewNode(q_reshape_repr())->assert_is_op("reshape2");
auto* q_reshape_out = pattern->NewNode(q_reshape_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_var_not_persistable();
auto* q_reshape_xshape = pattern->NewNode(q_reshape_xshape_repr())
->assert_is_op_output("reshape2", "XShape")
->assert_var_not_persistable();
auto* q_transpose =
pattern->NewNode(q_transpose_repr())->assert_is_op("transpose2");
auto* q_transpose_out = pattern->NewNode(q_transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input(matmul_type_1_, "X")
->assert_var_not_persistable();
auto* q_transpose_xshape = pattern->NewNode(q_transpose_xshape_repr())
->assert_is_op_output("transpose2", "XShape")
->assert_var_not_persistable();
// k: matmul + add + reshape + transpose
auto k_matmul_w = pattern->NewNode(k_matmul_w_repr())
->assert_is_op_input(matmul_type_0_, "Y")
->assert_is_persistable_var();
auto* k_matmul =
pattern->NewNode(k_matmul_repr())->assert_is_op(matmul_type_0_);
auto* k_matmul_out = pattern->NewNode(k_matmul_out_repr())
->assert_is_op_output(matmul_type_0_, "Out")
->assert_var_not_persistable();
auto k_add_bias = pattern->NewNode(k_add_bias_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_is_persistable_var();
auto* k_add = pattern->NewNode(k_add_repr())->assert_is_op("elementwise_add");
auto* k_add_out = pattern->NewNode(k_add_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_var_not_persistable();
auto* k_reshape =
pattern->NewNode(k_reshape_repr())->assert_is_op("reshape2");
auto* k_reshape_out = pattern->NewNode(k_reshape_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_var_not_persistable();
auto* k_reshape_xshape = pattern->NewNode(k_reshape_xshape_repr())
->assert_is_op_output("reshape2", "XShape")
->assert_var_not_persistable();
auto* k_transpose =
pattern->NewNode(k_transpose_repr())->assert_is_op("transpose2");
auto* k_transpose_out = pattern->NewNode(k_transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input(matmul_type_1_, "Y")
->assert_var_not_persistable();
auto* k_transpose_xshape = pattern->NewNode(k_transpose_xshape_repr())
->assert_is_op_output("transpose2", "XShape")
->assert_var_not_persistable();
// qk: matmul + add + softmax
auto* qk_matmul =
pattern->NewNode(qk_matmul_repr())->assert_is_op(matmul_type_1_);
auto* qk_matmul_out = pattern->NewNode(qk_matmul_out_repr())
->assert_is_op_output(matmul_type_1_, "Out")
->assert_var_not_persistable();
PDNode* qk_add_out = nullptr;
if (with_mask_) {
auto qk_add_mask = pattern->NewNode(qk_add_mask_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_var_not_persistable();
auto* qk_add =
pattern->NewNode(qk_add_repr())->assert_is_op("elementwise_add");
qk_add_out = pattern->NewNode(qk_add_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_var_not_persistable();
qk_add->LinksFrom({qk_matmul_out, qk_add_mask}).LinksTo({qk_add_out});
}
auto* qk_softmax =
pattern->NewNode(qk_softmax_repr())->assert_is_op("softmax");
auto* qk_softmax_out = pattern->NewNode(qk_softmax_out_repr())
->assert_is_op_output("softmax", "Out")
->assert_is_op_input(matmul_type_2_, "X")
->assert_var_not_persistable();
// v: matmul + add + reshape + transpose
auto v_matmul_w = pattern->NewNode(v_matmul_w_repr())
->assert_is_op_input(matmul_type_0_, "Y")
->assert_is_persistable_var();
auto* v_matmul =
pattern->NewNode(v_matmul_repr())->assert_is_op(matmul_type_0_);
auto* v_matmul_out = pattern->NewNode(v_matmul_out_repr())
->assert_is_op_output(matmul_type_0_, "Out")
->assert_var_not_persistable();
auto v_add_bias = pattern->NewNode(v_add_bias_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_is_persistable_var();
auto* v_add = pattern->NewNode(v_add_repr())->assert_is_op("elementwise_add");
auto* v_add_out = pattern->NewNode(v_add_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_var_not_persistable();
auto* v_reshape =
pattern->NewNode(v_reshape_repr())->assert_is_op("reshape2");
auto* v_reshape_out = pattern->NewNode(v_reshape_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_var_not_persistable();
auto* v_reshape_xshape = pattern->NewNode(v_reshape_xshape_repr())
->assert_is_op_output("reshape2", "XShape")
->assert_var_not_persistable();
auto* v_transpose =
pattern->NewNode(v_transpose_repr())->assert_is_op("transpose2");
auto* v_transpose_out = pattern->NewNode(v_transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input(matmul_type_2_, "Y")
->assert_var_not_persistable();
auto* v_transpose_xshape = pattern->NewNode(v_transpose_xshape_repr())
->assert_is_op_output("transpose2", "XShape")
->assert_var_not_persistable();
// qkv
auto* qkv_matmul_0 =
pattern->NewNode(qkv_matmul_0_repr())->assert_is_op(matmul_type_2_);
auto* qkv_matmul_0_out = pattern->NewNode(qkv_matmul_0_out_repr())
->assert_is_op_output(matmul_type_2_, "Out")
->assert_var_not_persistable();
auto* qkv_transpose =
pattern->NewNode(qkv_transpose_repr())->assert_is_op("transpose2");
auto* qkv_transpose_out = pattern->NewNode(qkv_transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_var_not_persistable();
auto* qkv_transpose_xshape = pattern->NewNode(qkv_transpose_xshape_repr())
->assert_is_op_output("transpose2", "XShape")
->assert_var_not_persistable();
auto* qkv_reshape =
pattern->NewNode(qkv_reshape_repr())->assert_is_op("reshape2");
auto* qkv_reshape_out = pattern->NewNode(qkv_reshape_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_var_not_persistable();
auto* qkv_reshape_xshape = pattern->NewNode(qkv_reshape_xshape_repr())
->assert_is_op_output("reshape2", "XShape")
->assert_var_not_persistable();
auto qkv_matmul_1_w = pattern->NewNode(qkv_matmul_1_w_repr())
->assert_is_op_input(matmul_type_0_, "Y")
->assert_is_persistable_var();
auto* qkv_matmul_1 =
pattern->NewNode(qkv_matmul_1_repr())->assert_is_op(matmul_type_0_);
auto* qkv_matmul_1_out = pattern->NewNode(qkv_matmul_1_out_repr())
->assert_is_op_output(matmul_type_0_, "Out")
->assert_var_not_persistable();
auto qkv_add_0_bias = pattern->NewNode(qkv_add_0_bias_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_is_persistable_var();
auto* qkv_add_0 =
pattern->NewNode(qkv_add_0_repr())->assert_is_op("elementwise_add");
auto* qkv_add_0_out = pattern->NewNode(qkv_add_0_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_var_not_persistable();
auto* qkv_add_1 =
pattern->NewNode(qkv_add_1_repr())->assert_is_op("elementwise_add");
auto* qkv_add_1_out = pattern->NewNode(qkv_add_1_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("layer_norm", "X")
->assert_var_not_persistable();
auto* ln_1_bias = pattern->NewNode(ln_1_bias_repr())
->assert_is_op_input("layer_norm", "Bias")
->assert_is_persistable_var();
auto* ln_1_scale = pattern->NewNode(ln_1_scale_repr())
->assert_is_op_input("layer_norm", "Scale")
->assert_is_persistable_var();
auto* ln_1 = pattern->NewNode(ln_1_repr())->assert_is_op("layer_norm");
auto* ln_1_out = pattern->NewNode(ln_1_out_repr())
->assert_is_op_output("layer_norm", "Y")
->assert_var_not_persistable();
auto* ln_1_mean = pattern->NewNode(ln_1_mean_repr())
->assert_is_op_output("layer_norm", "Mean")
->assert_var_not_persistable();
auto* ln_1_variance = pattern->NewNode(ln_1_variance_repr())
->assert_is_op_output("layer_norm", "Variance")
->assert_var_not_persistable();
auto qkv_matmul_2_w = pattern->NewNode(qkv_matmul_2_w_repr())
->assert_is_op_input(matmul_type_0_, "Y")
->assert_is_persistable_var();
auto* qkv_matmul_2 =
pattern->NewNode(qkv_matmul_2_repr())->assert_is_op(matmul_type_0_);
auto* qkv_matmul_2_out = pattern->NewNode(qkv_matmul_2_out_repr())
->assert_is_op_output(matmul_type_0_, "Out")
->assert_var_not_persistable();
auto qkv_add_2_bias = pattern->NewNode(qkv_add_2_bias_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_is_persistable_var();
auto* qkv_add_2 =
pattern->NewNode(qkv_add_2_repr())->assert_is_op("elementwise_add");
auto* qkv_add_2_out = pattern->NewNode(qkv_add_2_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_var_not_persistable();
auto* qkv_act = pattern->NewNode(qkv_act_repr())->assert_is_op(act_type_);
auto* qkv_act_out = pattern->NewNode(qkv_act_out_repr())
->assert_is_op_output(act_type_, "Out")
->assert_var_not_persistable();
auto qkv_matmul_3_w = pattern->NewNode(qkv_matmul_3_w_repr())
->assert_is_op_input(matmul_type_0_, "Y")
->assert_is_persistable_var();
auto* qkv_matmul_3 =
pattern->NewNode(qkv_matmul_3_repr())->assert_is_op(matmul_type_0_);
auto* qkv_matmul_3_out = pattern->NewNode(qkv_matmul_3_out_repr())
->assert_is_op_output(matmul_type_0_, "Out")
->assert_var_not_persistable();
auto qkv_add_3_bias = pattern->NewNode(qkv_add_3_bias_repr())
->assert_is_op_input("elementwise_add", "Y")
->assert_is_persistable_var();
auto* qkv_add_3 =
pattern->NewNode(qkv_add_3_repr())->assert_is_op("elementwise_add");
auto* qkv_add_3_out = pattern->NewNode(qkv_add_3_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_var_not_persistable();
auto* qkv_add_4 =
pattern->NewNode(qkv_add_4_repr())->assert_is_op("elementwise_add");
auto* qkv_add_4_out = pattern->NewNode(qkv_add_4_out_repr())
->assert_is_op_output("elementwise_add", "Out")
->assert_var_not_persistable();
PDNode* ln_2_out = nullptr;
if (!norm_before_) {
auto* ln_2_bias = pattern->NewNode(ln_2_bias_repr())
->assert_is_op_input("layer_norm", "Bias")
->assert_is_persistable_var();
auto* ln_2_scale = pattern->NewNode(ln_2_scale_repr())
->assert_is_op_input("layer_norm", "Scale")
->assert_is_persistable_var();
auto* ln_2 = pattern->NewNode(ln_2_repr())->assert_is_op("layer_norm");
ln_2_out = pattern->NewNode(ln_2_out_repr())
->assert_is_op_output("layer_norm", "Y")
->assert_var_not_persistable();
auto* ln_2_mean = pattern->NewNode(ln_2_mean_repr())
->assert_is_op_output("layer_norm", "Mean")
->assert_var_not_persistable();
auto* ln_2_variance = pattern->NewNode(ln_2_variance_repr())
->assert_is_op_output("layer_norm", "Variance")
->assert_var_not_persistable();
ln_2->LinksFrom({qkv_add_4_out, ln_2_bias, ln_2_scale})
.LinksTo({ln_2_out, ln_2_mean, ln_2_variance});
}
// link nodes
PDNode* q_matmul_x = ln_0_x;
if (norm_before_) q_matmul_x = ln_0_out;
q_matmul->LinksFrom({q_matmul_x, q_matmul_w}).LinksTo({q_matmul_out});
q_add->LinksFrom({q_matmul_out, q_add_bias}).LinksTo({q_add_out});
q_reshape->LinksFrom({q_add_out}).LinksTo({q_reshape_out, q_reshape_xshape});
q_transpose->LinksFrom({q_reshape_out})
.LinksTo({q_transpose_out, q_transpose_xshape});
k_matmul->LinksFrom({q_matmul_x, k_matmul_w}).LinksTo({k_matmul_out});
k_add->LinksFrom({k_matmul_out, k_add_bias}).LinksTo({k_add_out});
k_reshape->LinksFrom({k_add_out}).LinksTo({k_reshape_out, k_reshape_xshape});
k_transpose->LinksFrom({k_reshape_out})
.LinksTo({k_transpose_out, k_transpose_xshape});
qk_matmul->LinksFrom({q_transpose_out, k_transpose_out})
.LinksTo({qk_matmul_out});
PDNode* qk_softmax_x = qk_matmul_out;
if (with_mask_) qk_softmax_x = qk_add_out;
qk_softmax->LinksFrom({qk_softmax_x}).LinksTo({qk_softmax_out});
v_matmul->LinksFrom({q_matmul_x, v_matmul_w}).LinksTo({v_matmul_out});
v_add->LinksFrom({v_matmul_out, v_add_bias}).LinksTo({v_add_out});
v_reshape->LinksFrom({v_add_out}).LinksTo({v_reshape_out, v_reshape_xshape});
v_transpose->LinksFrom({v_reshape_out})
.LinksTo({v_transpose_out, v_transpose_xshape});
qkv_matmul_0->LinksFrom({qk_softmax_out, v_transpose_out})
.LinksTo({qkv_matmul_0_out});
qkv_transpose->LinksFrom({qkv_matmul_0_out})
.LinksTo({qkv_transpose_out, qkv_transpose_xshape});
qkv_reshape->LinksFrom({qkv_transpose_out})
.LinksTo({qkv_reshape_out, qkv_reshape_xshape});
qkv_matmul_1->LinksFrom({qkv_reshape_out, qkv_matmul_1_w})
.LinksTo({qkv_matmul_1_out});
qkv_add_0->LinksFrom({qkv_matmul_1_out, qkv_add_0_bias})
.LinksTo({qkv_add_0_out});
qkv_add_1->LinksFrom({qkv_add_0_out, q_matmul_x}).LinksTo({qkv_add_1_out});
ln_1->LinksFrom({qkv_add_1_out, ln_1_bias, ln_1_scale})
.LinksTo({ln_1_out, ln_1_mean, ln_1_variance});
qkv_matmul_2->LinksFrom({ln_1_out, qkv_matmul_2_w})
.LinksTo({qkv_matmul_2_out});
qkv_add_2->LinksFrom({qkv_matmul_2_out, qkv_add_2_bias})
.LinksTo({qkv_add_2_out});
qkv_act->LinksFrom({qkv_add_2_out}).LinksTo({qkv_act_out});
qkv_matmul_3->LinksFrom({qkv_act_out, qkv_matmul_3_w})
.LinksTo({qkv_matmul_3_out});
qkv_add_3->LinksFrom({qkv_matmul_3_out, qkv_add_3_bias})
.LinksTo({qkv_add_3_out});
if (norm_before_) {
qkv_add_4->LinksFrom({qkv_add_3_out, qkv_add_1_out})
.LinksTo({qkv_add_4_out});
} else {
qkv_add_4->LinksFrom({qkv_add_3_out, ln_1_out}).LinksTo({qkv_add_4_out});
}
}
} // namespace patterns
/*
step1: fuse individual ops into single_encoder_xpu
step2: fuse multiple single_encoder_xpu into multi_encoder_xpu

1. step1
Origin subgraph:
        ------------ input_variable*
       |              /          |           \
       |             /           |            \
       |       v_matmul      q_matmul      k_matmul
       |           |             |             |
       |           |             |             |
       |        v_add         q_add         k_add
       |           |             |             |
       |           |             |             |
       |      v_reshape     q_reshape     k_reshape
       |           |             |             |
       |           |             |             |
       |     v_transpose   q_transpose   k_transpose
       |           |             |             |
       |           |              \           /
       |           |               qk_matmul
       |           |                   |
       |           |                   |
       |           |                qk_add
       |           |                   |
       |           |                   |
       |           |               qk_softmax
       |           |                   |
       |           |                   |
       |           ------------- qkv_matmul_0
       |                               |
       |                               |
       |                        qkv_transpose
       |                               |
       |                               |
       |                         qkv_reshape
       |                               |
       |                               |
       |                         qkv_matmul_1
       |                               |
       |                               |
       |                          qkv_add_0
       |                               |
       |                               |
       ------------------------- qkv_add_1
                                       |
                                       |
                                 layer_norm_1
                                 /           \
                                |             |
                                |       qkv_matmul_2
                                |             |
                                |             |
                                |         qkv_add_2
                                |             |
                                |             |
                                |          qkv_act
                                |             |
                                |             |
                                |       qkv_matmul_3
                                |             |
                                |             |
                                |         qkv_add_3
                                |             |
                                 \           /
                                   qkv_add_4
                                       |
                                  layer_norm

Fused subgraph:
        single_encoder_xpu

2. step2
Origin subgraph:
           ...
            |
    single_encoder_xpu
            |
   (single_encoder_xpu)
            |
   (single_encoder_xpu)
            |
           ...

Fused subgraph:
    multi_encoder_xpu
*/
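// For orientation: ignoring reshape/transpose bookkeeping and the exact
// attention scaling (which rides on the matched matmul ops' attributes),
// each matched encoder layer computes, in the post-norm case:
//   attn = softmax(q @ k^T (+ mask)) @ v                    // qk_* / qkv_matmul_0
//   h    = layer_norm(x + attn @ W_o + b_o)                 // qkv_matmul_1, ln_1
//   y    = layer_norm(h + act(h @ W_1 + b_1) @ W_2 + b_2)   // ffn, final layer_norm
// where q/k/v are linear projections of x (of ln_0(x) when norm_before is
// set). W_o/W_1/W_2 are illustrative names only, not variables in this pass.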
class MultiEncoderXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
int ApplySingleEncoderXPUFuse(ir::Graph* graph,
const std::string& act_type,
const std::string& matmul_type_0,
const std::string& matmul_type_1,
const std::string& matmul_type_2,
bool norm_before,
bool with_mask) const;
bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const;
// 1. Transpose q_w, k_w, v_w
// 2. Concat q_w, k_w, v_w
// 3. Generate qkv_w_max tensor
// 4. Quant qkv_w to int16
void PrepareQKVWeight(const phi::DenseTensor& q_w,
const phi::DenseTensor& k_w,
const phi::DenseTensor& v_w,
phi::DenseTensor* qkv_w,
phi::DenseTensor* qkv_w_max) const;
void ConcatQKVBias(const phi::DenseTensor& q_bias,
const phi::DenseTensor& k_bias,
const phi::DenseTensor& v_bias,
phi::DenseTensor* qkv_bias) const;
const std::string name_scope_{"multi_encoder_xpu_fuse_pass"};
};
void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
std::vector<std::string> act_types{"gelu", "relu"};
std::vector<std::string> matmul_types_0{"mul", "matmul", "matmul_v2"};
std::vector<std::string> matmul_types_1{"matmul", "matmul_v2"};
std::vector<std::string> matmul_types_2{"matmul", "matmul_v2"};
std::vector<bool> norm_befores{true, false};
std::vector<bool> with_masks{true, false};
int single_encoder_fused_counts = 0;
int multi_encoder_fused_counts = 0;
for (auto act_type : act_types) {
for (auto matmul_type_0 : matmul_types_0) {
for (auto matmul_type_1 : matmul_types_1) {
for (auto matmul_type_2 : matmul_types_2) {
for (auto norm_before : norm_befores) {
for (auto with_mask : with_masks) {
single_encoder_fused_counts +=
ApplySingleEncoderXPUFuse(graph,
act_type,
matmul_type_0,
matmul_type_1,
matmul_type_2,
norm_before,
with_mask);
while (ApplyMultiEncoderXPUFuse(graph)) {
multi_encoder_fused_counts++;
}
}
}
}
}
}
}
AddStatis(single_encoder_fused_counts);
AddStatis(multi_encoder_fused_counts);
}
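// Weight layout note: PrepareQKVWeight() below transposes each of q_w/k_w/v_w
// with Transpose2D (so the output dimension becomes the leading one), stacks
// the three along dim 0 -- qkv_w has shape [3 * out_dim, in_dim], with q rows
// first, then k, then v -- and quantizes the concatenated tensor to int16
// using a single shared max-abs value stored in qkv_w_max. "out_dim"/"in_dim"
// are illustrative names for the two weight dimensions.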
void MultiEncoderXPUFusePass::PrepareQKVWeight(
const phi::DenseTensor& q_w,
const phi::DenseTensor& k_w,
const phi::DenseTensor& v_w,
phi::DenseTensor* qkv_w,
phi::DenseTensor* qkv_w_max) const {
// Transpose
phi::DenseTensor q_w_trans;
phi::DenseTensor k_w_trans;
phi::DenseTensor v_w_trans;
Transpose2D<float>(q_w, &q_w_trans);
Transpose2D<float>(k_w, &k_w_trans);
Transpose2D<float>(v_w, &v_w_trans);
// Concat
auto q_w_trans_dims = q_w_trans.dims();
auto k_w_trans_dims = k_w_trans.dims();
auto v_w_trans_dims = v_w_trans.dims();
qkv_w->Resize(DDim({q_w_trans_dims[0] + k_w_trans_dims[0] + v_w_trans_dims[0],
q_w_trans_dims[1]}));
qkv_w->set_type(q_w.type());
auto* dev_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
int size = q_w.numel();
auto* qkv_w_data = dev_ctx->Alloc<float>(qkv_w);
memcpy(qkv_w_data, q_w_trans.data(), size * sizeof(float));
qkv_w_data += size;
memcpy(qkv_w_data, k_w_trans.data(), size * sizeof(float));
qkv_w_data += size;
memcpy(qkv_w_data, v_w_trans.data(), size * sizeof(float));
// Quant to int16
QuantWeight<int16_t>(qkv_w, qkv_w_max, false);
}
void MultiEncoderXPUFusePass::ConcatQKVBias(const phi::DenseTensor& q_bias,
const phi::DenseTensor& k_bias,
const phi::DenseTensor& v_bias,
phi::DenseTensor* qkv_bias) const {
int q_bias_size = q_bias.numel();
qkv_bias->Resize(DDim({q_bias_size * 3}));
qkv_bias->set_type(q_bias.type());
auto* dev_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
auto* qkv_bias_data = dev_ctx->Alloc<float>(qkv_bias);
memcpy(qkv_bias_data, q_bias.data(), q_bias_size * sizeof(float));
qkv_bias_data += q_bias_size;
memcpy(qkv_bias_data, k_bias.data(), q_bias_size * sizeof(float));
qkv_bias_data += q_bias_size;
memcpy(qkv_bias_data, v_bias.data(), q_bias_size * sizeof(float));
}
int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
ir::Graph* graph,
const std::string& act_type,
const std::string& matmul_type_0,
const std::string& matmul_type_1,
const std::string& matmul_type_2,
bool norm_before,
bool with_mask) const {
GraphPatternDetector gpd;
patterns::SingleEncoderXPUPattern pattern(gpd.mutable_pattern(),
name_scope_,
act_type,
matmul_type_0,
matmul_type_1,
matmul_type_2,
norm_before,
with_mask);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
// VLOG(4) << "handle MultiEncoderXPUFusePass fuse, step1";
GET_IR_NODE(ln_0);
GET_IR_NODE(ln_1);
GET_IR_NODE(ln_2);
GET_IR_NODE(q_matmul);
GET_IR_NODE(q_add);
GET_IR_NODE(q_reshape);
GET_IR_NODE(q_transpose);
GET_IR_NODE(k_matmul);
GET_IR_NODE(k_add);
GET_IR_NODE(k_reshape);
GET_IR_NODE(k_transpose);
GET_IR_NODE(v_matmul);
GET_IR_NODE(v_add);
GET_IR_NODE(v_reshape);
GET_IR_NODE(v_transpose);
GET_IR_NODE(qk_matmul);
GET_IR_NODE(qk_add);
GET_IR_NODE(qk_softmax);
GET_IR_NODE(qkv_matmul_0);
GET_IR_NODE(qkv_transpose);
GET_IR_NODE(qkv_reshape);
GET_IR_NODE(qkv_matmul_1);
GET_IR_NODE(qkv_add_0);
GET_IR_NODE(qkv_add_1);
GET_IR_NODE(qkv_matmul_2);
GET_IR_NODE(qkv_add_2);
GET_IR_NODE(qkv_act);
GET_IR_NODE(qkv_matmul_3);
GET_IR_NODE(qkv_add_3);
GET_IR_NODE(qkv_add_4);
GET_IR_NODE(ln_0_x);
GET_IR_NODE(ln_0_bias);
GET_IR_NODE(ln_0_scale);
GET_IR_NODE(ln_0_out);
GET_IR_NODE(ln_0_mean);
GET_IR_NODE(ln_0_variance);
GET_IR_NODE(q_matmul_w);
GET_IR_NODE(q_matmul_out);
GET_IR_NODE(q_add_bias);
GET_IR_NODE(q_add_out);
GET_IR_NODE(q_reshape_out);
GET_IR_NODE(q_reshape_xshape);
GET_IR_NODE(q_transpose_out);
GET_IR_NODE(q_transpose_xshape);
GET_IR_NODE(k_matmul_w);
GET_IR_NODE(k_matmul_out);
GET_IR_NODE(k_add_bias);
GET_IR_NODE(k_add_out);
GET_IR_NODE(k_reshape_out);
GET_IR_NODE(k_reshape_xshape);
GET_IR_NODE(k_transpose_out);
GET_IR_NODE(k_transpose_xshape);
GET_IR_NODE(v_matmul_w);
GET_IR_NODE(v_matmul_out);
GET_IR_NODE(v_add_bias);
GET_IR_NODE(v_add_out);
GET_IR_NODE(v_reshape_out);
GET_IR_NODE(v_reshape_xshape);
GET_IR_NODE(v_transpose_out);
GET_IR_NODE(v_transpose_xshape);
GET_IR_NODE(qk_matmul_out);
GET_IR_NODE(qk_add_mask);
GET_IR_NODE(qk_add_out);
GET_IR_NODE(qk_softmax_out);
GET_IR_NODE(qkv_matmul_0_out);
GET_IR_NODE(qkv_transpose_out);
GET_IR_NODE(qkv_transpose_xshape);
GET_IR_NODE(qkv_reshape_out);
GET_IR_NODE(qkv_reshape_xshape);
GET_IR_NODE(qkv_matmul_1_w);
GET_IR_NODE(qkv_matmul_1_out);
GET_IR_NODE(qkv_add_0_bias);
GET_IR_NODE(qkv_add_0_out);
GET_IR_NODE(qkv_add_1_out);
GET_IR_NODE(ln_1_bias);
GET_IR_NODE(ln_1_scale);
GET_IR_NODE(ln_1_out);
GET_IR_NODE(ln_1_mean);
GET_IR_NODE(ln_1_variance);
GET_IR_NODE(qkv_matmul_2_w);
GET_IR_NODE(qkv_matmul_2_out);
GET_IR_NODE(qkv_add_2_bias);
GET_IR_NODE(qkv_add_2_out);
GET_IR_NODE(qkv_act_out);
GET_IR_NODE(qkv_matmul_3_w);
GET_IR_NODE(qkv_matmul_3_out);
GET_IR_NODE(qkv_add_3_bias);
GET_IR_NODE(qkv_add_3_out);
GET_IR_NODE(qkv_add_4_out);
GET_IR_NODE(ln_2_x);
GET_IR_NODE(ln_2_bias);
GET_IR_NODE(ln_2_scale);
GET_IR_NODE(ln_2_out);
GET_IR_NODE(ln_2_mean);
GET_IR_NODE(ln_2_variance);
auto* block = q_matmul->Op()->Block();
auto* scope = param_scope();
// Prepare q,k,v weight
std::string q_w_name = q_matmul_w->Name();
std::string k_w_name = k_matmul_w->Name();
std::string v_w_name = v_matmul_w->Name();
std::string qkv_w_name = q_w_name + "_" + k_w_name + "_" + v_w_name;
VarDesc qkv_w_desc(qkv_w_name);
qkv_w_desc.SetPersistable(true);
auto* qkv_w = graph->CreateVarNode(&qkv_w_desc);
auto* qkv_w_var = block->Var(qkv_w_name);
qkv_w_var->SetPersistable(true);
std::string qkv_w_max_name = qkv_w_name + "_max";
VarDesc qkv_w_max_desc(qkv_w_max_name);
qkv_w_max_desc.SetPersistable(true);
auto* qkv_w_max = graph->CreateVarNode(&qkv_w_max_desc);
auto* qkv_w_max_var = block->Var(qkv_w_max_name);
qkv_w_max_var->SetPersistable(true);
PrepareQKVWeight(
scope->FindVar(q_w_name)->Get<phi::DenseTensor>(),
scope->FindVar(k_w_name)->Get<phi::DenseTensor>(),
scope->FindVar(v_w_name)->Get<phi::DenseTensor>(),
scope->Var(qkv_w_name)->GetMutable<phi::DenseTensor>(),
scope->Var(qkv_w_max_name)->GetMutable<phi::DenseTensor>());
// Prepare qkv_matmul_1_w, qkv_matmul_2_w, qkv_matmul_3_w
#define PREPARE_QKV_MATMUL_W(idx_) \
std::string qkv_matmul_##idx_##_w_name = qkv_matmul_##idx_##_w->Name(); \
std::string qkv_matmul_##idx_##_w_max_name = \
qkv_matmul_##idx_##_w_name + "_max"; \
VarDesc qkv_matmul_##idx_##_w_max_desc(qkv_matmul_##idx_##_w_max_name); \
qkv_matmul_##idx_##_w_max_desc.SetPersistable(true); \
auto qkv_matmul_##idx_##_w_max = \
graph->CreateVarNode(&qkv_matmul_##idx_##_w_max_desc); \
auto qkv_matmul_##idx_##_w_max_var = \
block->Var(qkv_matmul_##idx_##_w_max_name); \
qkv_matmul_##idx_##_w_max_var->SetPersistable(true); \
auto qkv_matmul_##idx_##_w_max_tensor = \
scope->Var(qkv_matmul_##idx_##_w_max_name) \
->GetMutable<phi::DenseTensor>(); \
auto qkv_matmul_##idx_##_w_tensor = \
scope->Var(qkv_matmul_##idx_##_w_name)->GetMutable<phi::DenseTensor>(); \
QuantWeight<int16_t>( \
qkv_matmul_##idx_##_w_tensor, qkv_matmul_##idx_##_w_max_tensor, true);
PREPARE_QKV_MATMUL_W(1);
PREPARE_QKV_MATMUL_W(2);
PREPARE_QKV_MATMUL_W(3);
#undef PREPARE_QKV_MATMUL_W
// Concat q_add_bias, k_add_bias, v_add_bias
std::string q_add_bias_name = q_add_bias->Name();
std::string k_add_bias_name = k_add_bias->Name();
std::string v_add_bias_name = v_add_bias->Name();
std::string qkv_add_bias_name =
q_add_bias_name + "_" + k_add_bias_name + "_" + v_add_bias_name;
VarDesc qkv_add_bias_desc(qkv_add_bias_name);
qkv_add_bias_desc.SetPersistable(true);
auto* qkv_add_bias = graph->CreateVarNode(&qkv_add_bias_desc);
auto* qkv_add_bias_var = block->Var(qkv_add_bias_name);
qkv_add_bias_var->SetPersistable(true);
ConcatQKVBias(
scope->FindVar(q_add_bias_name)->Get<phi::DenseTensor>(),
scope->FindVar(k_add_bias_name)->Get<phi::DenseTensor>(),
scope->FindVar(v_add_bias_name)->Get<phi::DenseTensor>(),
scope->Var(qkv_add_bias_name)->GetMutable<phi::DenseTensor>());
// Generate single_encoder_xpu op
framework::OpDesc op_desc(block);
op_desc.SetType("single_encoder_xpu");
op_desc.SetInput("x", {ln_0_x->Name()});
op_desc.SetInput("fc_weight",
{qkv_w_name,
qkv_matmul_1_w_name,
qkv_matmul_2_w_name,
qkv_matmul_3_w_name});
op_desc.SetInput("fc_weight_max",
{qkv_w_max_name,
qkv_matmul_1_w_max_name,
qkv_matmul_2_w_max_name,
qkv_matmul_3_w_max_name});
op_desc.SetInput("fc_bias",
{qkv_add_bias_name,
qkv_add_0_bias->Name(),
qkv_add_2_bias->Name(),
qkv_add_3_bias->Name()});
if (norm_before) {
op_desc.SetInput("ln_scale", {ln_0_scale->Name(), ln_1_scale->Name()});
op_desc.SetInput("ln_bias", {ln_0_bias->Name(), ln_1_bias->Name()});
} else {
op_desc.SetInput("ln_scale", {ln_1_scale->Name(), ln_2_scale->Name()});
op_desc.SetInput("ln_bias", {ln_1_bias->Name(), ln_2_bias->Name()});
}
if (with_mask) {
op_desc.SetInput("mask", {qk_add_mask->Name()});
}
op_desc.SetAttr("norm_before", norm_before);
op_desc.SetAttr("hidden_dim",
static_cast<int>(q_matmul_w->Var()->GetShape()[0]));
auto q_reshape_shape =
PADDLE_GET_CONST(std::vector<int>, q_reshape->Op()->GetAttr("shape"));
op_desc.SetAttr("head_num", q_reshape_shape[2]);
op_desc.SetAttr("size_per_head", q_reshape_shape[3]);
auto qkv_matmul_2_w_shape = qkv_matmul_2_w->Var()->GetShape();
op_desc.SetAttr(
"ffn_hidden_dim_scale",
static_cast<int>(qkv_matmul_2_w_shape[1] / qkv_matmul_2_w_shape[0]));
op_desc.SetAttr("act_type", ConvertActivationType(act_type));
op_desc.SetAttr("relative_type", static_cast<int>(0));
if (norm_before) {
op_desc.SetOutput("out", {qkv_add_4_out->Name()});
} else {
op_desc.SetOutput("out", {ln_2_out->Name()});
}
auto* single_encoder_xpu = graph->CreateOpNode(&op_desc);
// Link nodes
SAFE_IR_NODE_LINK_TO(ln_0_x, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qkv_w, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qkv_w_max, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qkv_matmul_1_w, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qkv_matmul_1_w_max, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qkv_matmul_2_w, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qkv_matmul_2_w_max, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qkv_matmul_3_w, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qkv_matmul_3_w_max, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qkv_add_bias, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qkv_add_0_bias, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qkv_add_2_bias, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qkv_add_3_bias, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(ln_0_scale, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(ln_0_bias, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(ln_1_scale, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(ln_1_bias, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(ln_2_scale, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(ln_2_bias, single_encoder_xpu);
SAFE_IR_NODE_LINK_TO(qk_add_mask, single_encoder_xpu);
if (norm_before) {
SAFE_IR_NODE_LINK_TO(single_encoder_xpu, qkv_add_4_out);
} else {
SAFE_IR_NODE_LINK_TO(single_encoder_xpu, ln_2_out);
}
// Delete nodes
std::unordered_set<const Node*> delete_nodes{ln_1,
q_matmul,
q_add,
q_reshape,
q_transpose,
k_matmul,
k_add,
k_reshape,
k_transpose,
v_matmul,
v_add,
v_reshape,
v_transpose,
qk_matmul,
qk_softmax,
qkv_matmul_0,
qkv_transpose,
qkv_reshape,
qkv_matmul_1,
qkv_add_0,
qkv_add_1,
qkv_matmul_2,
qkv_add_2,
qkv_act,
qkv_matmul_3,
qkv_add_3,
qkv_add_4,
q_matmul_w,
q_matmul_out,
q_add_out,
q_reshape_out,
q_reshape_xshape,
q_transpose_out,
q_transpose_xshape,
k_matmul_w,
k_matmul_out,
k_add_out,
k_reshape_out,
k_reshape_xshape,
k_transpose_out,
k_transpose_xshape,
v_matmul_w,
v_matmul_out,
v_add_out,
v_reshape_out,
v_reshape_xshape,
v_transpose_out,
v_transpose_xshape,
qk_matmul_out,
qk_softmax_out,
qkv_matmul_0_out,
qkv_transpose_out,
qkv_transpose_xshape,
qkv_reshape_out,
qkv_reshape_xshape,
qkv_matmul_1_out,
qkv_add_0_out,
qkv_add_1_out,
ln_1_out,
ln_1_mean,
ln_1_variance,
qkv_matmul_2_out,
qkv_add_2_out,
qkv_act_out,
qkv_matmul_3_out,
qkv_add_3_out};
if (norm_before) {
delete_nodes.insert(ln_0);
delete_nodes.insert(ln_0_mean);
delete_nodes.insert(ln_0_variance);
delete_nodes.insert(ln_0_out);
} else {
delete_nodes.insert(qkv_add_4_out);
delete_nodes.insert(ln_2);
delete_nodes.insert(ln_2_mean);
delete_nodes.insert(ln_2_variance);
}
if (with_mask) {
delete_nodes.insert(qk_add);
delete_nodes.insert(qk_add_out);
}
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
static std::vector<Node*> GetSingleEncoders(ir::Graph* graph) {
std::vector<Node*> single_encoders;
for (auto* node : graph->Nodes()) {
    // Find the first single_encoder_xpu
if (node->IsVar() || node->Op()->Type() != "single_encoder_xpu") continue;
bool is_first_encoder = true;
for (auto* in_node : node->inputs) {
if (in_node->Var()->Persistable()) continue;
if (in_node->inputs[0]->Op()->Type() == "single_encoder_xpu") {
is_first_encoder = false;
break;
}
}
if (!is_first_encoder) continue;
// Add continuous single_encoder_xpu
single_encoders.push_back(node);
while (true) {
auto next_ops = single_encoders.back()->outputs[0]->outputs;
if (next_ops.empty()) break;
auto next_op_type = next_ops[0]->Op()->Type();
if (next_op_type != "single_encoder_xpu") break;
single_encoders.push_back(next_ops[0]);
}
break;
}
return single_encoders;
}
bool MultiEncoderXPUFusePass::ApplyMultiEncoderXPUFuse(ir::Graph* graph) const {
auto single_encoders = GetSingleEncoders(graph);
if (single_encoders.empty()) return false;
// Prepare inputs/outputs names/nodes
std::string x_name = single_encoders[0]->Op()->Inputs().at("x")[0];
std::vector<std::string> arg_names{
"fc_weight", "fc_weight_max", "fc_bias", "ln_scale", "ln_bias"};
std::map<std::string, std::vector<std::string>> arg_names_map;
std::string mask_name = single_encoders[0]->Op()->Inputs().count("mask") > 0
? single_encoders[0]->Op()->Inputs().at("mask")[0]
: "";
std::string out_name = single_encoders.back()->Op()->Outputs().at("out")[0];
std::vector<Node*> in_nodes;
for (auto* in_node : single_encoders[0]->inputs) {
if (in_node->Var()->Name() == x_name ||
in_node->Var()->Name() == mask_name) {
in_nodes.push_back(in_node);
}
}
for (auto* single_encoder : single_encoders) {
auto single_encoder_in_nodes = single_encoder->inputs;
for (auto arg_name : arg_names) {
auto var_names = single_encoder->Op()->Inputs().at(arg_name);
for (auto var_name : var_names) {
arg_names_map[arg_name].push_back(var_name);
for (auto in_node : single_encoder_in_nodes) {
if (in_node->Var()->Name() == var_name) {
in_nodes.push_back(in_node);
}
}
}
}
}
std::vector<Node*> out_nodes;
for (auto* out_node : single_encoders.back()->outputs) {
if (out_node->Var()->Name() == out_name) {
out_nodes.push_back(out_node);
break;
}
}
auto* block = single_encoders[0]->Op()->Block();
auto* scope = param_scope();
  // Create x_fp16 variable/node/tensor
std::string x_fp16_name = x_name + "_fp16";
VarDesc x_fp16_desc(x_fp16_name);
auto* x_fp16 = graph->CreateVarNode(&x_fp16_desc);
block->Var(x_fp16_name);
scope->Var(x_fp16_name)->GetMutable<phi::DenseTensor>();
out_nodes.push_back(x_fp16);
  // Create out_fp16 variable/node/tensor
std::string out_fp16_name = out_name + "_fp16";
VarDesc out_fp16_desc(out_fp16_name);
auto* out_fp16 = graph->CreateVarNode(&out_fp16_desc);
block->Var(out_fp16_name);
scope->Var(out_fp16_name)->GetMutable<phi::DenseTensor>();
out_nodes.push_back(out_fp16);
// Generate multi_encoder_xpu op
framework::OpDesc op_desc(block);
op_desc.SetType("multi_encoder_xpu");
op_desc.SetInput("x", {x_name});
for (auto arg_name : arg_names) {
op_desc.SetInput(arg_name, arg_names_map[arg_name]);
}
if (!mask_name.empty()) {
op_desc.SetInput("mask", {mask_name});
}
op_desc.SetAttr("layer_num", static_cast<int>(single_encoders.size()));
op_desc.SetAttr(
"norm_before",
PADDLE_GET_CONST(bool, single_encoders[0]->Op()->GetAttr("norm_before")));
for (auto attr_name : {"hidden_dim",
"head_num",
"size_per_head",
"ffn_hidden_dim_scale",
"act_type",
"relative_type"}) {
op_desc.SetAttr(
attr_name,
PADDLE_GET_CONST(int, single_encoders[0]->Op()->GetAttr(attr_name)));
}
op_desc.SetAttr("slice_idx", static_cast<int>(-1));
op_desc.SetOutput("out", {out_name});
op_desc.SetOutput("x_fp16", {x_fp16_name});
op_desc.SetOutput("out_fp16", {out_fp16_name});
auto* multi_encoder_xpu = graph->CreateOpNode(&op_desc);
for (auto* in_node : in_nodes) {
IR_NODE_LINK_TO(in_node, multi_encoder_xpu);
}
for (auto* out_node : out_nodes) {
IR_NODE_LINK_TO(multi_encoder_xpu, out_node);
}
// delete useless node
std::unordered_set<const Node*> delete_nodes(single_encoders.begin(),
single_encoders.end());
for (int i = 0; i < static_cast<int>(single_encoders.size()) - 1; i++) {
std::string out_name = single_encoders[i]->Op()->Outputs().at("out")[0];
for (auto* out_node : single_encoders[i]->outputs) {
if (out_node->Var()->Name() != out_name) {
delete_nodes.insert(out_node);
}
}
}
GraphSafeRemoveNodes(graph, delete_nodes);
return true;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(multi_encoder_xpu_fuse_pass,
paddle::framework::ir::MultiEncoderXPUFusePass);
REGISTER_PASS_CAPABILITY(multi_encoder_xpu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"multi_encoder_xpu", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/xpu/xpu_api_wrapper.h"
namespace paddle {
namespace framework {
namespace ir {
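// ConvertActivationType() below maps an activation name to the integer value
// of the matching xpu::Activation_t enum entry. Judging from the act_map_
// table this commit removes from fc_xpu_fuse_pass, these presumably evaluate
// to 0 (linear), 1 (relu), 2 (sigmoid), 3 (tanh), 4 (gelu), 5 (leaky_relu),
// 14 (hard_swish), 15 (hard_sigmoid) and 17 (relu6).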
int ConvertActivationType(std::string act_type) {
if (act_type == "") {
return static_cast<int>(xpu::Activation_t::LINEAR);
} else if (act_type == "relu") {
return static_cast<int>(xpu::Activation_t::RELU);
} else if (act_type == "sigmoid") {
return static_cast<int>(xpu::Activation_t::SIGMOID);
} else if (act_type == "tanh") {
return static_cast<int>(xpu::Activation_t::TANH);
} else if (act_type == "gelu") {
return static_cast<int>(xpu::Activation_t::GELU);
} else if (act_type == "leaky_relu") {
return static_cast<int>(xpu::Activation_t::LEAKY_RELU);
} else if (act_type == "exp") {
return static_cast<int>(xpu::Activation_t::EXP);
} else if (act_type == "hard_swish") {
return static_cast<int>(xpu::Activation_t::HARD_SWISH);
} else if (act_type == "hard_sigmoid") {
return static_cast<int>(xpu::Activation_t::HARD_SIGMOID);
} else if (act_type == "swish") {
return static_cast<int>(xpu::Activation_t::SWISH);
} else if (act_type == "relu6") {
return static_cast<int>(xpu::Activation_t::RELU6);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Not support convert activation_type(%s).", act_type));
}
return -1;
}
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -13,6 +13,7 @@
 // limitations under the License.

 #pragma once
+#include <string>

 namespace paddle {
 namespace framework {
@@ -42,6 +43,8 @@ namespace ir {
     IR_NODE_LINK_TO(a, b)             \
   }

+int ConvertActivationType(std::string act_type);
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
@@ -16,20 +16,34 @@
 #include <vector>
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/phi/core/enforce.h"
+#include "paddle/phi/kernels/funcs/math_function.h"

 namespace paddle {
 namespace framework {
 namespace ir {

 template <typename T>
-static void Transpose(const T* in, T* out, int h, int w) {
-  for (int h1 = 0; h1 < w; ++h1) {
-    for (int w1 = 0; w1 < h; ++w1) {
-      out[h1 * h + w1] = in[w1 * w + h1];
-    }
-  }
+void Transpose2D(const phi::DenseTensor& in, phi::DenseTensor* out) {
+  auto in_dims = in.dims();
+  PADDLE_ENFORCE_EQ(
+      in_dims.size(),
+      2,
+      platform::errors::InvalidArgument(
+          "In dims rank should be 2, but received in dims size is [%d].",
+          in_dims.size()));
+  out->Resize({in_dims[1], in_dims[0]});
+  out->set_type(in.type());
+  auto* dev_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+  dev_ctx->Alloc<T>(out);
+  std::vector<int> axis{1, 0};
+  phi::funcs::Transpose<phi::CPUContext, T, 2> trans2d;
+  trans2d(*dev_ctx, in, out, axis);
 }

+template void Transpose2D<float>(const phi::DenseTensor& in,
+                                 phi::DenseTensor* out);
+
 static float FindMaxAbs(const float* data, int len) {
   float max_f = 0.0f;
   for (int i = 0; i < len; ++i) {
@@ -136,25 +150,20 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr,
 template <typename T>
 void QuantWeight(phi::DenseTensor* weight,
                  phi::DenseTensor* weight_max,
-                 bool transpose,
-                 int max_ptr_size) {
+                 bool transpose) {
   // Transpose
   auto* weight_data = weight->data<float>();
-  auto dims = weight->dims();
-  auto size = weight->numel();
-  std::vector<float> transpose_data(weight_data, weight_data + size);
+  phi::DenseTensor weight_trans;
   if (transpose) {
-    PADDLE_ENFORCE_EQ(
-        dims.size(),
-        2,
-        platform::errors::InvalidArgument(
-            "Only support 2D weight, but received weight rank is [%d].",
-            dims.size()));
-    Transpose(weight_data, transpose_data.data(), dims[0], dims[1]);
-    weight->Resize({dims[1], dims[0]});
+    Transpose2D<float>(*weight, &weight_trans);
+    weight_data = weight_trans.data<float>();
+    weight->Resize(weight_trans.dims());
   }
-  weight_data = transpose_data.data();
   // Find max
+  auto* xpu_ctx = static_cast<phi::XPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::XPUPlace()));
+  int max_ptr_size = xpu_ctx->x_context()->max_ptr_size();
+  int size = weight->numel();
   float max_val = FindMaxAbs(weight_data, size);
   std::vector<float> max_vec(max_ptr_size, max_val);
   weight_max->set_type(paddle::experimental::CppTypeToDataType<float>::Type());
@@ -173,8 +182,7 @@ void QuantWeight(phi::DenseTensor* weight,

 template void QuantWeight<int16_t>(phi::DenseTensor* weight,
                                    phi::DenseTensor* weight_max,
-                                   bool transpose,
-                                   int max_ptr_size);
+                                   bool transpose);

 }  // namespace ir
 }  // namespace framework
...
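For intuition, the max-abs scheme behind FindMaxAbs/QuantWeight amounts to the
following minimal sketch (an illustration only: QuantMaxAbsInt16 is a
hypothetical helper assuming symmetric rounding to the int16 range, and the
actual QuantFP32ToIntX implementation is not shown in this diff):

#include <cmath>
#include <cstdint>
#include <vector>

// Symmetric max-abs quantization: every element is scaled by 32767 / max|w|,
// so the largest-magnitude weight maps to +/-32767.
std::vector<int16_t> QuantMaxAbsInt16(const std::vector<float>& w,
                                      float max_abs) {
  std::vector<int16_t> q(w.size());
  for (size_t i = 0; i < w.size(); ++i) {
    q[i] = static_cast<int16_t>(std::round(w[i] / max_abs * 32767.0f));
  }
  return q;
}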
...@@ -19,14 +19,16 @@ namespace paddle { ...@@ -19,14 +19,16 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
+template <typename T>
+void Transpose2D(const phi::DenseTensor& in, phi::DenseTensor* out);
+
// 1. Quant weight from fp32 to int16/int31
// 2. Weight data is in-place update.
// 3. Generate weight max tensor
template <typename T>
void QuantWeight(phi::DenseTensor* weight,
                 phi::DenseTensor* weight_max,
-                 bool transpose,
-                 int max_ptr_size);
+                 bool transpose);

}  // namespace ir
}  // namespace framework
...
...@@ -2745,8 +2745,6 @@ void OperatorWithKernel::ParseMultiInputDataType(
  const phi::DenseTensor* t = nullptr;
  if (var->IsType<phi::DenseTensor>()) {
    t = &var->Get<phi::DenseTensor>();
-  } else if (var->IsType<phi::DenseTensor>()) {
-    t = &var->Get<phi::DenseTensor>();
  } else if (var->IsType<phi::SelectedRows>()) {
    t = &(var->Get<phi::SelectedRows>().value());
  } else if (var->IsType<phi::SparseCooTensor>()) {
...@@ -2866,8 +2864,6 @@ phi::DenseTensor* OperatorWithKernel::GetTensorFormInputSafely(
  phi::DenseTensor* t = nullptr;
  if (var->IsType<phi::DenseTensor>()) {
    t = var->GetMutable<phi::DenseTensor>();
-  } else if (var->IsType<phi::DenseTensor>()) {
-    t = var->GetMutable<phi::DenseTensor>();
  } else if (var->IsType<phi::SelectedRows>()) {
    t = var->GetMutable<phi::SelectedRows>()->mutable_value();
  } else {
...
...@@ -517,7 +517,7 @@ void CpuPassStrategy::EraseFcMkldnnPasses() {
XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
  passes_.assign({
      "delete_dropout_op_pass",
-      // "multi_encoder_xpu_fuse_pass",
+      "multi_encoder_xpu_fuse_pass",
      // "embedding_with_eltwise_add_xpu_fuse_pass",
      "fc_xpu_fuse_pass",
      // "multi_encoder_slice_link_xpu_fuse_pass",
...
...@@ -5,8 +5,19 @@
    func : FcXPUInferMeta
  kernel :
    func : fc_xpu
+    data_type : x
  optional : bias
+
+- op : multi_encoder_xpu
+  args : (Tensor x, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor mask, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx)
+  output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16)
+  infer_meta :
+    func : MultiEncoderXPUInferMeta
+  kernel :
+    func : multi_encoder_xpu
+    data_type : x
+  optional : mask, x_fp16, out_fp16
- op : share_buffer
  args : (Tensor[] x, bool[] share_dims_and_dtype={})
  output : Tensor[](out){x.size()}, Tensor[](xout){x.size()}
...
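For reference, the integer attributes of multi_encoder_xpu are related to one another. A hedged reading based on the shapes used by the unit test at the end of this diff (12 heads of 64, FFN width 3072):

head_num, size_per_head = 12, 64
hidden_dim = head_num * size_per_head          # 768
ffn_hidden_dim_scale = 3072 // hidden_dim      # FFN width / hidden width
assert (hidden_dim, ffn_hidden_dim_scale) == (768, 4)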
...@@ -421,6 +421,7 @@ XPUOpMap& get_kl2_ops() {
                    phi::DataType::FLOAT16,
                    phi::DataType::INT32,
                    phi::DataType::INT64})},
+    {"multi_encoder_xpu", XPUKernelSet({phi::DataType::FLOAT32})},
    {"nearest_interp_v2", XPUKernelSet({phi::DataType::FLOAT32})},
    {"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})},
    {"not_equal",
...
...@@ -42,4 +42,41 @@ void FcXPUInferMeta(const MetaTensor& x,
  out->set_layout(x.layout());
}
void MultiEncoderXPUInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& fc_weight,
const std::vector<const MetaTensor*>& fc_weight_max,
const std::vector<const MetaTensor*>& fc_bias,
const std::vector<const MetaTensor*>& ln_scale,
const std::vector<const MetaTensor*>& ln_bias,
const MetaTensor& mask,
int layer_num,
bool norm_before,
int hidden_dim,
int head_num,
int size_per_head,
int ffn_hidden_dim_scale,
int act_type,
int relative_type,
int slice_idx,
MetaTensor* out,
MetaTensor* x_fp16,
MetaTensor* out_fp16) {
auto x_dims = x.dims();
x_fp16->set_dims(x_dims);
x_fp16->set_dtype(DataType::FLOAT16);
x_fp16->set_layout(x.layout());
out->set_dtype(x.dtype());
out->set_layout(x.layout());
out_fp16->set_dtype(DataType::FLOAT16);
out_fp16->set_layout(x.layout());
if (slice_idx == -1) {
out->set_dims(x_dims);
out_fp16->set_dims(x_dims);
} else {
out->set_dims({x_dims[0], x_dims[2]});
out_fp16->set_dims({x_dims[0], x_dims[2]});
}
}
}  // namespace phi
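The only shape logic in MultiEncoderXPUInferMeta is the slice_idx branch: slice_idx == -1 keeps the full [batch, seq_len, hidden] shape, while any other value slices out one token per sequence, dropping the seq_len axis. A runnable sketch of the same rule (multi_encoder_out_shape is an illustrative name, not a Paddle API):

def multi_encoder_out_shape(x_shape, slice_idx):
    # Mirrors the slice_idx branch above: x is [batch, seq_len, hidden];
    # slice_idx == -1 keeps the whole sequence, anything else keeps one
    # token per sequence.
    batch, seq_len, hidden = x_shape
    return list(x_shape) if slice_idx == -1 else [batch, hidden]

assert multi_encoder_out_shape([4, 128, 768], -1) == [4, 128, 768]
assert multi_encoder_out_shape([4, 128, 768], 0) == [4, 768]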
...@@ -34,4 +34,25 @@ void FcXPUInferMeta(const MetaTensor& x,
                    float act_alpha,
                    MetaTensor* out);
void MultiEncoderXPUInferMeta(
const MetaTensor& x,
const std::vector<const MetaTensor*>& fc_weight,
const std::vector<const MetaTensor*>& fc_weight_max,
const std::vector<const MetaTensor*>& fc_bias,
const std::vector<const MetaTensor*>& ln_scale,
const std::vector<const MetaTensor*>& ln_bias,
const MetaTensor& mask,
int layer_num,
bool norm_before,
int hidden_dim,
int head_num,
int size_per_head,
int ffn_hidden_dim_scale,
int act_type,
int relative_type,
int slice_idx,
MetaTensor* out,
MetaTensor* x_fp16,
MetaTensor* out_fp16);
}  // namespace phi
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void MultiEncoderXPUKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<const DenseTensor*>& fc_weight,
const std::vector<const DenseTensor*>& fc_weight_max,
const std::vector<const DenseTensor*>& fc_bias,
const std::vector<const DenseTensor*>& ln_scale,
const std::vector<const DenseTensor*>& ln_bias,
const paddle::optional<DenseTensor>& mask,
int layer_num,
bool norm_before,
int hidden_dim,
int head_num,
int size_per_head,
int ffn_hidden_dim_scale,
int act_type,
int relative_type,
int slice_idx,
DenseTensor* out,
DenseTensor* x_fp16,
DenseTensor* out_fp16) {
using float16 = typename XPUTypeTrait<phi::dtype::float16>::Type;
  // XPU2 only supports fp16 input/output.
float16* x_fp16_data = reinterpret_cast<float16*>(
ctx.template Alloc<phi::dtype::float16>(x_fp16));
int r_cast_x = xpu::cast_v2<float, float16>(
ctx.x_context(), x.data<T>(), x_fp16_data, x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r_cast_x,
"multi_encoder_xpu(cast x from fp32 to fp16)");
float16* out_fp16_data = reinterpret_cast<float16*>(
ctx.template Alloc<phi::dtype::float16>(out_fp16));
  // q, k, v weights are fused into one tensor.
  // Each encoder layer's weight list should be: w0, null, null, w3, w4, w5.
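  // That is, the pass hands in four real tensors per layer (fused qkv,
  // attention output projection, ffn1, ffn2); the nullptrs pushed below
  // restore the six-slot layout (separate q/k/v slots) expected by the
  // xpu::transformer_encoder call further down.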
std::vector<const float*> fc_input_max_data;
std::vector<const int16_t*> fc_weight_data;
std::vector<const float*> fc_weight_max_data;
std::vector<const float*> fc_bias_data;
for (size_t i = 0; i < fc_weight.size(); i++) {
fc_weight_data.push_back(fc_weight[i]->data<int16_t>());
fc_weight_max_data.push_back(fc_weight_max[i]->data<float>());
fc_bias_data.push_back(fc_bias[i]->data<float>());
if (i % 4 == 0) {
fc_weight_data.push_back(nullptr);
fc_weight_data.push_back(nullptr);
fc_weight_max_data.push_back(nullptr);
fc_weight_max_data.push_back(nullptr);
fc_bias_data.push_back(nullptr);
fc_bias_data.push_back(nullptr);
}
}
std::vector<const float*> ln_scale_data;
std::vector<const float*> ln_bias_data;
for (size_t i = 0; i < ln_scale.size(); i++) {
ln_scale_data.push_back(ln_scale[i]->data<float>());
ln_bias_data.push_back(ln_bias[i]->data<float>());
}
const T* mask_data =
mask.get_ptr() == nullptr ? nullptr : mask.get_ptr()->data<T>();
xpu::Activation_t qkv_act(static_cast<xpu::Activation_t::act_enum>(act_type));
int batch = x.dims()[0];
int max_seqlen = x.dims()[1];
// matmul_size * layer_num
std::vector<xpu::QuantType> quant_types(8 * layer_num,
xpu::QuantType::NOT_QUANT);
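  // Eight quant-type slots per layer, one per matmul in an encoder layer
  // (presumably the q/k/v and output projections, the two attention
  // batched matmuls, and the two FFN matmuls).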
if (mask_data) {
auto mask_dims = mask.get_ptr()->dims();
std::vector<int> mask_shape(mask_dims.Get(),
mask_dims.Get() + mask_dims.size());
xpu::QKVAttnParam qkv_attn_param(batch,
max_seqlen,
head_num,
size_per_head,
mask_shape,
qkv_act,
slice_idx,
true,
hidden_dim,
norm_before,
false);
qkv_attn_param.quant_type_.assign(quant_types.begin(), quant_types.end());
qkv_attn_param.scale_of_hidden_units = ffn_hidden_dim_scale;
int r =
xpu::transformer_encoder<float16, int16_t, int16_t>(ctx.x_context(),
x_fp16_data,
fc_weight_data,
out_fp16_data,
fc_input_max_data,
fc_weight_max_data,
fc_bias_data,
ln_scale_data,
ln_bias_data,
qkv_attn_param,
mask_data);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "multi_encoder_xpu");
} else {
    // When there is no mask input (e.g. ViT), build a LOD to act as the
    // variable sequence lengths (vsl).
std::vector<int> lod;
for (int i = 0; i < batch + 1; i++) {
lod.push_back(i * max_seqlen);
}
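    // e.g. batch = 2, max_seqlen = 128 -> lod = {0, 128, 256}: every
    // sequence is treated as taking the full length.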
xpu::VectorParam<int> query_lod = {
lod.data(), static_cast<int>(lod.size()), nullptr};
    // No padding is needed, whether the output is sliced or not.
xpu::QKVAttnParam qkv_attn_param(query_lod,
head_num,
size_per_head,
qkv_act,
slice_idx,
true,
-1,
hidden_dim,
norm_before,
false);
qkv_attn_param.quant_type_.assign(quant_types.begin(), quant_types.end());
qkv_attn_param.scale_of_hidden_units = ffn_hidden_dim_scale;
int r =
xpu::transformer_encoder<float16, int16_t, int16_t>(ctx.x_context(),
x_fp16_data,
fc_weight_data,
out_fp16_data,
fc_input_max_data,
fc_weight_max_data,
fc_bias_data,
ln_scale_data,
ln_bias_data,
qkv_attn_param);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "multi_encoder_xpu");
}
int r_cast_out = xpu::cast_v2<float16, float>(
ctx.x_context(), out_fp16_data, ctx.template Alloc<T>(out), out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r_cast_out,
"multi_encoder_xpu(cast out from fp16 to fp32)");
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(multi_encoder_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::MultiEncoderXPUKernel,
float) {}
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import hypothesis.strategies as st
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestDeleteDropoutOpPass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["elementwise_add", "relu", "relu6"], (1e-5, 1e-5)
def sample_program_config(self, draw):
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["add_x"], "Y": ["add_y"]},
outputs={"Out": ["add_out"]},
axis=-1,
)
dropout_op = OpConfig(
"dropout",
inputs={"X": ["add_out"]},
outputs={"Out": ["dropout_out"], "Mask": ["dropout_mask"]},
dropout_implementation="upscale_in_train",
dropout_prob=0.1,
fix_seed=False,
is_test=True,
seed=0,
)
relu_op = OpConfig(
"relu",
inputs={"X": ["dropout_out"]},
outputs={"Out": ["relu_out"]},
)
relu6_op = OpConfig(
"relu6",
inputs={"X": ["dropout_out"]},
outputs={"Out": ["relu6_out"]},
)
ops = [add_op, dropout_op, relu_op, relu6_op]
add_x_shape = draw(
st.lists(
st.integers(min_value=1, max_value=4), min_size=2, max_size=4
)
)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"add_x": TensorConfig(shape=add_x_shape),
"add_y": TensorConfig(shape=add_x_shape),
},
outputs=["relu_out", "relu6_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=1,
min_success_num=1,
passes=["delete_dropout_op_pass"],
)
if __name__ == "__main__":
    unittest.main()
...
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestMultiEncoderXPUFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["multi_encoder_xpu"], (1e-1, 1e-1)
def sample_program_config(self, draw):
# q: matmul+add+reshape+transpose
q_matmul_op = OpConfig(
"matmul_v2",
inputs={"X": ["q_matmul_x"], "Y": ["q_matmul_w"]},
outputs={"Out": ["q_matmul_out"]},
trans_x=False,
trans_y=False,
)
q_add_op = OpConfig(
"elementwise_add",
inputs={"X": ["q_matmul_out"], "Y": ["q_add_bias"]},
outputs={"Out": ["q_add_out"]},
axis=2,
)
q_reshape_op = OpConfig(
"reshape2",
inputs={"X": ["q_add_out"]},
outputs={"Out": ["q_reshape_out"], "XShape": ["q_reshape_xshape"]},
shape=[0, 0, 12, 64],
)
q_transpose_op = OpConfig(
"transpose2",
inputs={"X": ["q_reshape_out"]},
outputs={
"Out": ["q_transpose_out"],
"XShape": ["q_transpose_xshape"],
},
axis=[0, 2, 1, 3],
)
# k: matmul+add+reshape+transpose
k_matmul_op = OpConfig(
"matmul_v2",
inputs={"X": ["q_matmul_x"], "Y": ["k_matmul_w"]},
outputs={"Out": ["k_matmul_out"]},
trans_x=False,
trans_y=False,
)
k_add_op = OpConfig(
"elementwise_add",
inputs={"X": ["k_matmul_out"], "Y": ["k_add_bias"]},
outputs={"Out": ["k_add_out"]},
axis=2,
)
k_reshape_op = OpConfig(
"reshape2",
inputs={"X": ["k_add_out"]},
outputs={"Out": ["k_reshape_out"], "XShape": ["k_reshape_xshape"]},
shape=[0, 0, 12, 64],
)
k_transpose_op = OpConfig(
"transpose2",
inputs={"X": ["k_reshape_out"]},
outputs={
"Out": ["k_transpose_out"],
"XShape": ["k_transpose_xshape"],
},
axis=[0, 2, 1, 3],
)
# v: matmul+add+reshape+transpose
v_matmul_op = OpConfig(
"matmul_v2",
inputs={"X": ["q_matmul_x"], "Y": ["v_matmul_w"]},
outputs={"Out": ["v_matmul_out"]},
trans_x=False,
trans_y=False,
)
v_add_op = OpConfig(
"elementwise_add",
inputs={"X": ["v_matmul_out"], "Y": ["v_add_bias"]},
outputs={"Out": ["v_add_out"]},
axis=2,
)
v_reshape_op = OpConfig(
"reshape2",
inputs={"X": ["v_add_out"]},
outputs={"Out": ["v_reshape_out"], "XShape": ["v_reshape_xshape"]},
shape=[0, 0, 12, 64],
)
v_transpose_op = OpConfig(
"transpose2",
inputs={"X": ["v_reshape_out"]},
outputs={
"Out": ["v_transpose_out"],
"XShape": ["v_transpose_xshape"],
},
axis=[0, 2, 1, 3],
)
# qk: matmul+add+softmax
qk_matmul_op = OpConfig(
"matmul",
inputs={"X": ["q_transpose_out"], "Y": ["k_transpose_out"]},
outputs={"Out": ["qk_matmul_out"]},
alpha=0.125,
transpose_X=False,
transpose_Y=True,
)
qk_add_op = OpConfig(
"elementwise_add",
inputs={"X": ["qk_matmul_out"], "Y": ["qk_add_mask"]},
outputs={"Out": ["qk_add_out"]},
axis=-1,
)
qk_softmax_op = OpConfig(
"softmax",
inputs={"X": ["qk_add_out"]},
outputs={"Out": ["qk_softmax_out"]},
axis=-1,
)
# qkv
qkv_matmul_0_op = OpConfig(
"matmul_v2",
inputs={"X": ["qk_softmax_out"], "Y": ["v_transpose_out"]},
outputs={"Out": ["qkv_matmul_0_out"]},
trans_x=False,
trans_y=False,
)
qkv_transpose_op = OpConfig(
"transpose2",
inputs={"X": ["qkv_matmul_0_out"]},
outputs={
"Out": ["qkv_transpose_out"],
"XShape": ["qkv_transpose_xshape"],
},
axis=[0, 2, 1, 3],
)
qkv_reshape_op = OpConfig(
"reshape2",
inputs={"X": ["qkv_transpose_out"]},
outputs={
"Out": ["qkv_reshape_out"],
"XShape": ["qkv_reshape_xshape"],
},
shape=[0, 0, 768],
)
qkv_matmul_1_op = OpConfig(
"matmul_v2",
inputs={"X": ["qkv_reshape_out"], "Y": ["qkv_matmul_1_w"]},
outputs={"Out": ["qkv_matmul_1_out"]},
trans_x=False,
trans_y=False,
)
qkv_add_0_op = OpConfig(
"elementwise_add",
inputs={"X": ["qkv_matmul_1_out"], "Y": ["qkv_add_0_bias"]},
outputs={"Out": ["qkv_add_0_out"]},
axis=2,
)
qkv_add_1_op = OpConfig(
"elementwise_add",
inputs={"X": ["qkv_add_0_out"], "Y": ["q_matmul_x"]},
outputs={"Out": ["qkv_add_1_out"]},
axis=-1,
)
ln_1_op = OpConfig(
"layer_norm",
inputs={
"X": ["qkv_add_1_out"],
"Bias": ["ln_1_bias"],
"Scale": ["ln_1_scale"],
},
outputs={
"Y": ["ln_1_out"],
"Mean": ["ln_1_mean"],
"Variance": ["ln_1_variance"],
},
begin_norm_axis=2,
epsilon=1e-14,
)
qkv_matmul_2_op = OpConfig(
"matmul_v2",
inputs={"X": ["ln_1_out"], "Y": ["qkv_matmul_2_w"]},
outputs={"Out": ["qkv_matmul_2_out"]},
trans_x=False,
trans_y=False,
)
qkv_add_2_op = OpConfig(
"elementwise_add",
inputs={"X": ["qkv_matmul_2_out"], "Y": ["qkv_add_2_bias"]},
outputs={"Out": ["qkv_add_2_out"]},
axis=2,
)
qkv_act_op = OpConfig(
"gelu",
inputs={"X": ["qkv_add_2_out"]},
outputs={"Out": ["qkv_act_out"]},
approximate=False,
)
qkv_matmul_3_op = OpConfig(
"matmul_v2",
inputs={"X": ["qkv_act_out"], "Y": ["qkv_matmul_3_w"]},
outputs={"Out": ["qkv_matmul_3_out"]},
trans_x=False,
trans_y=False,
)
qkv_add_3_op = OpConfig(
"elementwise_add",
inputs={"X": ["qkv_matmul_3_out"], "Y": ["qkv_add_3_bias"]},
outputs={"Out": ["qkv_add_3_out"]},
axis=2,
)
qkv_add_4_op = OpConfig(
"elementwise_add",
inputs={"X": ["ln_1_out"], "Y": ["qkv_add_3_out"]},
outputs={"Out": ["qkv_add_4_out"]},
axis=-1,
)
ln_2_op = OpConfig(
"layer_norm",
inputs={
"X": ["qkv_add_4_out"],
"Bias": ["ln_2_bias"],
"Scale": ["ln_2_scale"],
},
outputs={
"Y": ["ln_2_out"],
"Mean": ["ln_2_mean"],
"Variance": ["ln_2_variance"],
},
begin_norm_axis=2,
epsilon=1e-14,
)
ops = [
q_matmul_op,
q_add_op,
q_reshape_op,
q_transpose_op,
k_matmul_op,
k_add_op,
k_reshape_op,
k_transpose_op,
v_matmul_op,
v_add_op,
v_reshape_op,
v_transpose_op,
qk_matmul_op,
qk_add_op,
qk_softmax_op,
qkv_matmul_0_op,
qkv_transpose_op,
qkv_reshape_op,
qkv_matmul_1_op,
qkv_add_0_op,
qkv_add_1_op,
ln_1_op,
qkv_matmul_2_op,
qkv_add_2_op,
qkv_act_op,
qkv_matmul_3_op,
qkv_add_3_op,
qkv_add_4_op,
ln_2_op,
]
q_matmul_x_shape = draw(
st.lists(
st.integers(min_value=3, max_value=10), min_size=3, max_size=3
)
)
q_matmul_x_shape[2] = 768
q_matmul_w_shape = [q_matmul_x_shape[2], q_matmul_x_shape[2]]
q_add_bias_shape = [q_matmul_x_shape[2]]
qk_add_mask_shape = [q_matmul_x_shape[0], 1, 1, q_matmul_x_shape[1]]
qkv_matmul_2_w_shape = [q_matmul_x_shape[2], 3072]
qkv_add_2_bias_shape = [qkv_matmul_2_w_shape[1]]
qkv_matmul_3_w_shape = [3072, q_matmul_x_shape[2]]
qkv_add_3_bias_shape = [qkv_matmul_3_w_shape[1]]
ln_1_bias_shape = [q_matmul_x_shape[2]]
# def generate_q_matmul_w():
# return np.random.random(x_shape).astype(np.float32)
program_config = ProgramConfig(
ops=ops,
weights={
"q_matmul_w": TensorConfig(shape=q_matmul_w_shape),
"q_add_bias": TensorConfig(shape=q_add_bias_shape),
"k_matmul_w": TensorConfig(shape=q_matmul_w_shape),
"k_add_bias": TensorConfig(shape=q_add_bias_shape),
"v_matmul_w": TensorConfig(shape=q_matmul_w_shape),
"v_add_bias": TensorConfig(shape=q_add_bias_shape),
"qkv_matmul_1_w": TensorConfig(shape=q_matmul_w_shape),
"qkv_add_0_bias": TensorConfig(shape=q_add_bias_shape),
"qkv_matmul_2_w": TensorConfig(shape=qkv_matmul_2_w_shape),
"qkv_add_2_bias": TensorConfig(shape=qkv_add_2_bias_shape),
"qkv_matmul_3_w": TensorConfig(shape=qkv_matmul_3_w_shape),
"qkv_add_3_bias": TensorConfig(shape=qkv_add_3_bias_shape),
"ln_1_bias": TensorConfig(shape=ln_1_bias_shape),
"ln_1_scale": TensorConfig(shape=ln_1_bias_shape),
"ln_2_bias": TensorConfig(shape=ln_1_bias_shape),
"ln_2_scale": TensorConfig(shape=ln_1_bias_shape),
},
inputs={
"q_matmul_x": TensorConfig(shape=q_matmul_x_shape),
"qk_add_mask": TensorConfig(shape=qk_add_mask_shape),
},
outputs=["ln_2_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=2,
min_success_num=2,
passes=["multi_encoder_xpu_fuse_pass"],
)
if __name__ == "__main__":
np.random.seed(200)
unittest.main()