未验证 提交 8acd745c 编写于 作者: Z Zhaolong Xing 提交者: GitHub

[Ernie GPU Optim]: Fuse three fc to multihtead matmul (#22486)

* 1. optim multihead matmul: fuse three fc to multihtead matmul

test=develop

* fix conflict
test=develop

* fix comments
test=develop
上级 a8dd425a
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -425,19 +426,285 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) { ...@@ -425,19 +426,285 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
return transpose2_2_out_var; return transpose2_2_out_var;
} }
static int BuildFusionV2(Graph* graph, const std::string& name_scope,
Scope* scope) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
PDNode* x =
pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
multihead_pattern(x);
// Create New OpDesc
auto fuse_creater = [&](
Node* layer_norm_out, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
auto scale_attr = boost::get<float>(scale->Op()->GetAttr("scale"));
// mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
// bias (B * S * 3 * N * H) + bias (3 * N * H)
// Transpose (B * S * 3 * N * H) -> (3 * B * N * S * H)
auto* wq_tensor = scope->FindVar(mul0_w->Name())->GetMutable<LoDTensor>();
auto* wk_tensor = scope->FindVar(mul1_w->Name())->GetMutable<LoDTensor>();
auto* wv_tensor = scope->FindVar(mul2_w->Name())->GetMutable<LoDTensor>();
auto* bq_tensor =
scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
auto* bk_tensor =
scope->FindVar(eltadd1_b->Name())->GetMutable<LoDTensor>();
auto* bv_tensor =
scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();
auto* wq_data = wq_tensor->mutable_data<float>(platform::CPUPlace());
auto* wk_data = wk_tensor->mutable_data<float>(platform::CPUPlace());
auto* wv_data = wv_tensor->mutable_data<float>(platform::CPUPlace());
auto* bq_data = bq_tensor->mutable_data<float>(platform::CPUPlace());
auto* bk_data = bk_tensor->mutable_data<float>(platform::CPUPlace());
auto* bv_data = bv_tensor->mutable_data<float>(platform::CPUPlace());
auto combined_w_dims =
framework::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
auto combined_bias_dims = framework::make_ddim({3, bq_tensor->dims()[0]});
// create a new var in scope
VarDesc combined_w_desc(
patterns::PDNodeName(name_scope, "multi_head_combined_weight"));
combined_w_desc.SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
combined_w_desc.SetDataType(wq_tensor->type());
combined_w_desc.SetLoDLevel(mul0_w->Var()->GetLoDLevel());
combined_w_desc.SetPersistable(true);
// create a new var in scope
VarDesc combined_bias_desc(
patterns::PDNodeName(name_scope, "multi_head_combined_bias"));
combined_bias_desc.SetShape({3, bq_tensor->dims()[0]});
combined_bias_desc.SetDataType(bq_tensor->type());
combined_bias_desc.SetLoDLevel(eltadd0_b->Var()->GetLoDLevel());
combined_bias_desc.SetPersistable(true);
auto* combined_w_node = graph->CreateVarNode(&combined_w_desc);
auto* combined_w_tensor =
scope->Var(combined_w_node->Name())->GetMutable<LoDTensor>();
combined_w_tensor->Resize(combined_w_dims);
auto* combined_w_data =
combined_w_tensor->mutable_data<float>(platform::CPUPlace());
std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
// Combine the three fc weights together.
for (int i = 0; i < dims_h; i++) {
for (int j = 0; j < 3; j++) {
for (int k = 0; k < dims_w; k++) {
int out_index = i * (3 * dims_w) + j * dims_w + k;
int in_index = i * dims_w + k;
combined_w_data[out_index] = w_vec[j][in_index];
}
}
}
scope->EraseVars({mul0_w->Name(), mul1_w->Name(), mul2_w->Name()});
auto* combined_bias_node = graph->CreateVarNode(&combined_bias_desc);
auto* combined_bias_tensor =
scope->Var(combined_bias_node->Name())->GetMutable<LoDTensor>();
combined_bias_tensor->Resize(combined_bias_dims);
auto* combined_bias_data =
combined_bias_tensor->mutable_data<float>(platform::CPUPlace());
size_t bias_size = bq_tensor->numel();
memcpy(combined_bias_data, bq_data, sizeof(float) * bias_size);
memcpy(combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
memcpy(combined_bias_data + 2 * bias_size, bv_data,
sizeof(float) * bias_size);
scope->EraseVars({eltadd0_b->Name(), eltadd1_b->Name(), eltadd2_b->Name()});
auto reshape_desc = reshape2->Op();
int head_number =
boost::get<std::vector<int>>(reshape_desc->GetAttr("shape")).at(2);
OpDesc multihead_op_desc;
multihead_op_desc.SetType("multihead_matmul");
multihead_op_desc.SetInput("Input", {layer_norm_out->Name()});
multihead_op_desc.SetInput("W", {combined_w_node->Name()});
multihead_op_desc.SetInput("Bias", {combined_bias_node->Name()});
multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});
multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
multihead_op_desc.SetAttr("alpha", scale_attr);
multihead_op_desc.SetAttr("head_number", head_number);
auto* multihead = graph->CreateOpNode(&multihead_op_desc);
IR_NODE_LINK_TO(layer_norm_out, multihead);
IR_NODE_LINK_TO(combined_w_node, multihead);
IR_NODE_LINK_TO(combined_bias_node, multihead);
IR_NODE_LINK_TO(eltadd_qk_b, multihead);
IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
};
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
// GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out,
multihead_pattern);
// nodes need be removed
GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv,
multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out,
multihead_pattern);
fuse_creater(layer_norm_out, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out,
mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b,
eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out);
std::unordered_set<const Node*> marked_nodes({eltadd0,
eltadd1,
eltadd2,
eltadd0_b,
eltadd1_b,
eltadd2_b,
eltadd0_out,
eltadd1_out,
eltadd2_out,
reshape2_0,
reshape2_1,
reshape2_2,
reshape2_0_out,
reshape2_1_out,
reshape2_2_out,
transpose2_0,
transpose2_1,
transpose2_2,
transpose2_0_out,
transpose2_1_out,
transpose2_2_out,
matmul_qk,
matmul_qk_out,
eltadd_qk,
eltadd_qk_out,
softmax_qk,
softmax_qk_out,
transpose2_qkv,
transpose2_qkv_out,
matmul_qkv,
matmul_qkv_out,
mul0,
mul1,
mul2,
mul0_out,
mul1_out,
mul2_out,
mul0_w,
mul1_w,
mul2_w,
reshape2_qkv,
scale});
// Remove unneeded nodes.
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
return fusion_count;
}
} // namespace patterns } // namespace patterns
void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const { void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph);
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
int fusion_count = patterns::BuildFusion(graph, name_scope_); int fusion_count = patterns::BuildFusion(graph, name_scope_);
AddStatis(fusion_count); AddStatis(fusion_count);
} }
void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::Fatal(
"During the multiheadMatmul pass, The scope should not be null."));
patterns::BuildFusionV2(graph, name_scope_, scope);
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(multihead_matmul_fuse_pass, REGISTER_PASS(multihead_matmul_fuse_pass,
paddle::framework::ir::MultiHeadMatmulFusePass); paddle::framework::ir::MultiHeadMatmulFusePass);
REGISTER_PASS(multihead_matmul_fuse_pass_v2,
paddle::framework::ir::MultiHeadMatmulV2FusePass);
...@@ -32,8 +32,6 @@ struct MultiHeadMatmulPattern : public PatternBase { ...@@ -32,8 +32,6 @@ struct MultiHeadMatmulPattern : public PatternBase {
PDNode* operator()(PDNode* x); PDNode* operator()(PDNode* x);
// declare operator node's name // declare operator node's name
// PATTERN_DECL_NODE(dropout);
// PATTERN_DECL_NODE(dropout_out);
PATTERN_DECL_NODE(layer_norm); PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(layer_norm_out); PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(mul0); PATTERN_DECL_NODE(mul0);
...@@ -79,8 +77,6 @@ struct MultiHeadMatmulPattern : public PatternBase { ...@@ -79,8 +77,6 @@ struct MultiHeadMatmulPattern : public PatternBase {
PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(eltadd_qk_out);
PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk);
PATTERN_DECL_NODE(softmax_qk_out); PATTERN_DECL_NODE(softmax_qk_out);
// PATTERN_DECL_NODE(dropout_qk);
// PATTERN_DECL_NODE(dropout_qk_out);
PATTERN_DECL_NODE(matmul_qkv); PATTERN_DECL_NODE(matmul_qkv);
PATTERN_DECL_NODE(matmul_qkv_out); PATTERN_DECL_NODE(matmul_qkv_out);
...@@ -98,6 +94,16 @@ class MultiHeadMatmulFusePass : public FusePassBase { ...@@ -98,6 +94,16 @@ class MultiHeadMatmulFusePass : public FusePassBase {
const std::string name_scope_{"multihead_matmul_fuse"}; const std::string name_scope_{"multihead_matmul_fuse"};
}; };
class MultiHeadMatmulV2FusePass : public FusePassBase {
public:
virtual ~MultiHeadMatmulV2FusePass() {}
protected:
void ApplyImpl(Graph* graph) const;
const std::string name_scope_{"multihead_matmul_fuse_v2"};
};
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -17,6 +17,27 @@ namespace paddle { ...@@ -17,6 +17,27 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
void AddVarToScope(Scope* param_scope, const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "weights0", {768, 768});
AddVarToScope(param_scope, "weights1", {768, 768});
AddVarToScope(param_scope, "weights2", {768, 768});
AddVarToScope(param_scope, "bias_0", {768});
AddVarToScope(param_scope, "bias_1", {768});
AddVarToScope(param_scope, "bias_2", {768});
AddVarToScope(param_scope, "biasqk", {768});
AddVarToScope(param_scope, "weightsl", {768, 768});
return param_scope;
}
TEST(MultiHeadMatmulFusePass, basic) { TEST(MultiHeadMatmulFusePass, basic) {
// inputs operator output // inputs operator output
// -------------------------------------------------------------------- // --------------------------------------------------------------------
...@@ -87,7 +108,10 @@ TEST(MultiHeadMatmulFusePass, basic) { ...@@ -87,7 +108,10 @@ TEST(MultiHeadMatmulFusePass, basic) {
layers.mul(reshape_qkv_out, weights_l); layers.mul(reshape_qkv_out, weights_l);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("multihead_matmul_fuse_pass"); graph->Set("__param_scope__", CreateParamScope());
auto pass = PassRegistry::Instance().Get("multihead_matmul_fuse_pass_v2");
if (pass.get() == nullptr) LOG(INFO) << "asdfasdf";
int num_nodes_before = graph->Nodes().size(); int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
...@@ -96,8 +120,17 @@ TEST(MultiHeadMatmulFusePass, basic) { ...@@ -96,8 +120,17 @@ TEST(MultiHeadMatmulFusePass, basic) {
int num_fused_nodes_after = GetNumOpNodes(graph, "multihead_matmul"); int num_fused_nodes_after = GetNumOpNodes(graph, "multihead_matmul");
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 29); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1); num_nodes_before, num_nodes_after + 39,
platform::errors::InvalidArgument(
"After the multihead_matmul pass, The node num in graph "
"should be %d, but the result is %d",
num_nodes_before - 39, num_nodes_after));
PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1,
platform::errors::InvalidArgument(
"After the multihead_matmul pass, there should be one "
"multihead_matmul op, but the result is %d",
num_fused_nodes_after));
} }
} // namespace ir } // namespace ir
...@@ -105,3 +138,4 @@ TEST(MultiHeadMatmulFusePass, basic) { ...@@ -105,3 +138,4 @@ TEST(MultiHeadMatmulFusePass, basic) {
} // namespace paddle } // namespace paddle
USE_PASS(multihead_matmul_fuse_pass); USE_PASS(multihead_matmul_fuse_pass);
USE_PASS(multihead_matmul_fuse_pass_v2);
...@@ -107,7 +107,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { ...@@ -107,7 +107,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_eltwiseadd_affine_channel_fuse_pass", // "conv_eltwiseadd_affine_channel_fuse_pass", //
"conv_bn_fuse_pass", // "conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", //
"multihead_matmul_fuse_pass", "multihead_matmul_fuse_pass_v2",
"fc_fuse_pass", // "fc_fuse_pass", //
"fc_elementwise_layernorm_fuse_pass", // "fc_elementwise_layernorm_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
......
...@@ -15,126 +15,80 @@ limitations under the License. */ ...@@ -15,126 +15,80 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/errors.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class MultiHeadMatMulOp : public framework::OperatorWithKernel { class MultiHeadMatMulV2Op : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext *context) const override { void InferShape(framework::InferShapeContext *context) const override {
PADDLE_ENFORCE_EQ(context->HasInput("Q"), true,
"Input(Q) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasInput("K"), true,
"Input(K) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasInput("V"), true,
"Input(V) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasInput("BiasQ"), true,
"Input(BiasQ) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasInput("BiasK"), true,
"Input(BiasQ) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasInput("BiasV"), true,
"Input(BiasQ) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasInput("BiasQK"), true,
"Input(BiasQK) of MultiheadOp should not be null.");
PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true,
"Output(Out) of MatMulOp should not be null.");
auto dim_q = context->GetInputDim("Q");
PADDLE_ENFORCE_GT(dim_q.size(), 2,
"Multihead input should be at least 3-D tensor.");
auto dim_k = context->GetInputDim("K");
PADDLE_ENFORCE_GT(dim_q.size(), 2,
"Multihead input should be at least 3-D tensor.");
auto dim_v = context->GetInputDim("V");
PADDLE_ENFORCE_GT(dim_q.size(), 2,
"Multihead input should be at least 3-D tensor.");
PADDLE_ENFORCE_EQ(dim_q[0], dim_k[0],
"Multihead input should have same batch size");
PADDLE_ENFORCE_EQ(dim_q[0], dim_v[0],
"Multihead input should have same batch size");
PADDLE_ENFORCE_EQ(dim_q[1], dim_k[1],
"Multihead input should have same size");
PADDLE_ENFORCE_EQ(dim_q[1], dim_v[1],
"Multihead input should have same size");
PADDLE_ENFORCE_EQ(dim_q[2], dim_k[2],
"Multihead input should have same size");
PADDLE_ENFORCE_EQ(dim_q[2], dim_v[2],
"Multihead input should have same size");
auto dim_bias_q = context->GetInputDim("BiasQ");
PADDLE_ENFORCE_GT(dim_bias_q.size(), 0,
"Multihead input should be at least 1-D tensor.");
auto dim_bias_k = context->GetInputDim("BiasK");
PADDLE_ENFORCE_GT(dim_bias_k.size(), 0,
"Multihead input should be at least 1-D tensor.");
auto dim_bias_v = context->GetInputDim("BiasV");
PADDLE_ENFORCE_GT(dim_bias_v.size(), 0,
"Multihead input should be at least 1-D tensor.");
PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_k[0],
"Multihead input bias should have same batch size");
PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_v[0],
"Multihead input bias should have same batch size");
auto dim_bias_qk = context->GetInputDim("BiasQK");
PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3,
"Multihead input bias qk should be at least 4-D tensor.");
int b_indx = dim_bias_q.size() - 1;
int indx = dim_q.size() - 1;
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dim_bias_q[b_indx], dim_q[indx], context->HasInput("Input"), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"bias_q's last dim size should equal to" "Input(Input) of MultiHeadMatMul should not be null."));
" q last dim size, but received bias_q's size is:%d q is:%d", PADDLE_ENFORCE_EQ(context->HasInput("W"), true,
dim_bias_q[b_indx], dim_q[indx])); platform::errors::InvalidArgument(
"Input(W) of MultiHeadMatMul should not be null."));
PADDLE_ENFORCE_EQ(
context->HasInput("Bias"), true,
platform::errors::InvalidArgument(
"Input(Bias) of MultiHeadMatMul should not be null."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dim_bias_k[b_indx], dim_k[indx], context->HasInput("BiasQK"), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"bias_k's last dim size should equal to" "Input(BiasQK) of MultiHeadMatMul should not be null."));
" k last dim size, but received bias_k's size is:%d k is:%d",
dim_bias_k[b_indx], dim_k[indx]));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dim_bias_v[b_indx], dim_v[indx], context->HasOutput("Out"), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"bias_v's last dim size should equal to" "Output(Out) of MultiHeadMatMul should not be null."));
" v last dim size, but received bias_v's size is:%d v is:%d",
dim_bias_v[b_indx], dim_v[indx]));
PADDLE_ENFORCE_EQ(dim_q[0], dim_bias_qk[0], auto dim_w = context->GetInputDim("W");
PADDLE_ENFORCE_GT(
dim_w.size(), 2,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"q should have same batch size" "Multihead input is expected at least a 3-D tensor, but "
"with bias_qk, but received q's batch size is:%d " "it's %d-D tensor now.",
"bias_qk's batch size is:%d", dim_w.size()));
dim_q[0], dim_bias_qk[0]));
int head_number = context->Attrs().Get<int>("head_number"); auto dim_bias_q = context->GetInputDim("Bias");
PADDLE_ENFORCE_GT(head_number, 1, PADDLE_ENFORCE_GT(
"Multihead input head number should be at least 1."); dim_bias_q.size(), 1,
platform::errors::InvalidArgument(
"Multihead input should be at least 2-D tensor, but it's "
"%d-D tensor now.",
dim_bias_q.size()));
auto dim_bias_qk = context->GetInputDim("BiasQK");
PADDLE_ENFORCE_GT(
dim_bias_qk.size(), 3,
platform::errors::InvalidArgument(
"Multihead input bias qk should be at least 4-D tensor, "
"but it's %d-D tensor now.",
dim_bias_qk.size()));
context->SetOutputDim("Out", dim_q); int head_number = context->Attrs().Get<int>("head_number");
context->ShareLoD("Q", /*->*/ "Out"); PADDLE_ENFORCE_GT(
head_number, 1,
platform::errors::InvalidArgument(
"Multihead input head number should be at least 1, but it %d now.",
head_number));
// modify this
auto dim_input = context->GetInputDim("Input");
context->SetOutputDim("Out", dim_input);
context->ShareLoD("Input", /*->*/ "Out");
} }
}; };
class MultiHeadMatMulOpMaker : public framework::OpProtoAndCheckerMaker { class MultiHeadMatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("Q", "The first input of MultiHeadMatMul op"); AddInput("Input", "The input of MultiHeadMatMul op");
AddInput("K", "The second input of MMultiHeadMatMul op"); AddInput("W", "The weight input of MultiHeadMatMul op");
AddInput("V", "The third input of MultiHeadMatMul op"); AddInput("Bias", "The bias input of MultiHeadMatMul op");
AddInput("BiasQ", "The first bias input of MultiHeadMatMul op");
AddInput("BiasK", "The second bias input of MultiHeadMatMul op");
AddInput("BiasV", "The third bias input of MultiHeadMatMul op");
AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op"); AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op");
AddOutput("Out", "The output of MultiHeadMatMul op"); AddOutput("Out", "The output of MultiHeadMatMul op");
AddAttr<bool>("transpose_Q", AddAttr<bool>("transpose_Q",
...@@ -161,10 +115,6 @@ Not suggest to use in other case except has same structure as ernie. ...@@ -161,10 +115,6 @@ Not suggest to use in other case except has same structure as ernie.
Example of matrix multiplication with head_number of B Example of matrix multiplication with head_number of B
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N] - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
Both the input `Q` and `K` can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD information with input `Q`, because
they are the same.
)DOC"); )DOC");
} }
}; };
...@@ -173,5 +123,5 @@ they are the same. ...@@ -173,5 +123,5 @@ they are the same.
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(multihead_matmul, ops::MultiHeadMatMulOp, REGISTER_OP_WITHOUT_GRADIENT(multihead_matmul, ops::MultiHeadMatMulV2Op,
ops::MultiHeadMatMulOpMaker); ops::MultiHeadMatMulV2OpMaker);
...@@ -300,7 +300,7 @@ template <typename T> ...@@ -300,7 +300,7 @@ template <typename T>
void MatMulWithHeadQKV(const platform::CUDADeviceContext &context, int head_num, void MatMulWithHeadQKV(const platform::CUDADeviceContext &context, int head_num,
int seq_len, int size_per_head, int batch_size, int seq_len, int size_per_head, int batch_size,
bool qk_trans, bool v_trans, T *v_buf_, const T *qk_buf_, bool qk_trans, bool v_trans, T *v_buf_, const T *qk_buf_,
T *dst, T *out, T alpha, T beta) { T *dst, T alpha, T beta) {
int m = batch_size * seq_len; int m = batch_size * seq_len;
int k = head_num * size_per_head; int k = head_num * size_per_head;
...@@ -312,96 +312,199 @@ void MatMulWithHeadQKV(const platform::CUDADeviceContext &context, int head_num, ...@@ -312,96 +312,199 @@ void MatMulWithHeadQKV(const platform::CUDADeviceContext &context, int head_num,
blas.BatchedGEMM(transA, transB, seq_len, size_per_head, seq_len, alpha, blas.BatchedGEMM(transA, transB, seq_len, size_per_head, seq_len, alpha,
qk_buf_, v_buf_, beta, dst, batch_size * head_num, qk_buf_, v_buf_, beta, dst, batch_size * head_num,
seq_len * seq_len, seq_len * size_per_head); seq_len * seq_len, seq_len * size_per_head);
}
int grid = batch_size * head_num * seq_len; template <typename T>
int block = size_per_head; inline __device__ T add_func(T a, T b);
transpose<T><<<grid, block, 0, stream>>>(dst, out, batch_size, seq_len,
head_num, size_per_head); template <>
__device__ float add_func<float>(float a, float b) {
return a + b;
}
template <>
__device__ float2 add_func<float2>(float2 a, float2 b) {
float2 c;
c.x = a.x + b.x;
c.y = a.y + b.y;
return c;
}
template <>
__device__ float4 add_func<float4>(float4 a, float4 b) {
float4 c;
c.x = a.x + b.x;
c.y = a.y + b.y;
c.z = a.z + b.z;
c.w = a.w + b.w;
return c;
} }
template <typename T> template <typename T>
void MultiHeadGPUCompute(const platform::CUDADeviceContext &dev_ctx, __global__ void transpose_qkv_kernel(const int H, const T *input, const T *bias,
int head_num, const framework::DDim &mat_q, T *output) {
const framework::DDim &mat_k, // Input: BxSx3xNxH
const framework::DDim &mat_v, const T *Q, const T *K, // Bias: 3xSxB
const T *V, const T *bias_q, const T *bias_k, // Output: 3xBxNxSxH
const T *bias_v, const T *bias_qk, T *out, T alpha, int n = threadIdx.y;
T beta, bool trans_q, bool trans_k, bool trans_v) { int s = blockIdx.x;
int seq_len = mat_q[1]; int b = blockIdx.y;
int size_per_head = (mat_q[2] / head_num); int m = blockIdx.z;
int batch_size = mat_q[0];
int buf_size = batch_size * head_num * seq_len * size_per_head; const int N = blockDim.y;
int qk_buf_size = batch_size * head_num * seq_len * seq_len; const int S = gridDim.x;
const int B = gridDim.y;
auto alloc_buf =
memory::Alloc(dev_ctx, (buf_size * 4 + qk_buf_size) * sizeof(T)); const int NH = N * H;
const int NHS = NH * S;
T *buf = reinterpret_cast<T *>(alloc_buf->ptr()); const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3;
T *q_buf = buf; const int bias_offset = m * NH + n * H;
T *k_buf = buf + buf_size; const int out_offset = s * H + n * S * H + b * NHS + m * NHS * B;
T *v_buf = buf + 2 * buf_size;
T *qk_buf = buf + 3 * buf_size; const int i = threadIdx.x;
T *dst_buf = buf + 3 * buf_size + qk_buf_size; output[out_offset + i] =
add_func(input[in_offset + i], bias[bias_offset + i]);
}
int m = batch_size * seq_len; void TransQKVWithBias(const int batch, const int seq_len, const int head_size,
int k = head_num * size_per_head; const int head_num, const float *input, const float *bias,
float *output, cudaStream_t stream) {
// BxSx3xNxH + 3xNxH -> 3xBxNxSxH
const dim3 grid(seq_len, batch, 3);
if (head_size % 4 == 0) {
const int h = head_size / 4;
const float4 *input4 = reinterpret_cast<const float4 *>(input);
const float4 *bias4 = reinterpret_cast<const float4 *>(bias);
float4 *output4 = reinterpret_cast<float4 *>(output);
const dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num, 1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num, head_size, 1024 * 4));
transpose_qkv_kernel<float4><<<grid, block, 0, stream>>>(h, input4, bias4,
output4);
} else if (head_size % 2 == 0) {
const int h = head_size / 2;
const float2 *input2 = reinterpret_cast<const float2 *>(input);
const float2 *bias2 = reinterpret_cast<const float2 *>(bias);
float2 *output2 = reinterpret_cast<float2 *>(output);
const dim3 block(h, head_num, 1);
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE(h * head_num, 1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num, head_size, 1024 * 2));
transpose_qkv_kernel<float2><<<grid, block, 0, stream>>>(h, input2, bias2,
output2);
} else {
const dim3 block(head_size, head_num, 1);
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE(head_size * head_num, 1024,
platform::errors::InvalidArgument(
"head_num (%d) * head_size (%d) should <= %d",
head_num, head_size, 1024));
transpose_qkv_kernel<float><<<grid, block, 0, stream>>>(head_size, input,
bias, output);
}
}
// Each block process head*size-per_head element, template <typename T>
// have m lines. bias is m lines void MultiHeadGPUComputeV2(const platform::CUDADeviceContext &dev_ctx,
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx); int batch, int seq_len, int head_num, int head_size,
T *qkptr, const T *bias_qk_ptr, T *tptr, T alpha,
T beta) {
auto stream = dev_ctx.stream(); auto stream = dev_ctx.stream();
const int tsize = batch * head_num * seq_len * head_size;
int grid = m;
PADDLE_ENFORCE_LE(k, 1024, T *qptr = tptr;
"Input head_number * size_per_head should <= 1024"); T *kptr = qptr + tsize;
int block = k <= 1024 ? k : 1024; T *vptr = kptr + tsize;
add_QKV<T><<<grid, block, 0, stream>>>(Q, K, V, q_buf, k_buf, v_buf, bias_q, // batch gemm stride, softmaxwithscale.
bias_k, bias_v, batch_size, seq_len, MatMulWithHeadQK<T>(dev_ctx, head_num, seq_len, head_size, batch, false, true,
head_num, size_per_head); qptr, kptr, qkptr, bias_qk_ptr, alpha, beta);
// batch gemm stride, transpose.
MatMulWithHeadQK<T>(dev_ctx, head_num, seq_len, size_per_head, batch_size, MatMulWithHeadQKV<T>(dev_ctx, head_num, seq_len, head_size, batch, false,
trans_q, trans_k, q_buf, k_buf, qk_buf, bias_qk, alpha, false, vptr, qkptr, tptr, T(1.0), beta);
beta);
MatMulWithHeadQKV<T>(dev_ctx, head_num, seq_len, size_per_head, batch_size,
false, trans_v, v_buf, qk_buf, dst_buf, out, T(1.0),
beta);
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MultiHeadMatMulKernel : public framework::OpKernel<T> { class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto *q = context.Input<framework::Tensor>("Q"); using Tensor = framework::Tensor;
auto *k = context.Input<framework::Tensor>("K"); auto *input = context.Input<framework::Tensor>("Input");
auto *v = context.Input<framework::Tensor>("V"); auto *w = context.Input<framework::Tensor>("W");
auto *bias = context.Input<framework::Tensor>("Bias");
auto &bias_q = detail::Ref(context.Input<framework::Tensor>("BiasQ"),
"Cannot find BiasQ");
auto &bias_k = detail::Ref(context.Input<framework::Tensor>("BiasK"),
"Cannot find BiasK");
auto &bias_v = detail::Ref(context.Input<framework::Tensor>("BiasV"),
"Cannot find BiasV");
auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"), auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"),
"Cannot find QK"); "Cannot find QK");
auto *out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto *input_d = input->data<T>();
auto *w_d = w->data<T>();
auto *bias_d = bias->data<T>();
auto *bias_qk_d = bias_qk.data<T>();
auto *output_d = out->mutable_data<T>(context.GetPlace());
T scale = static_cast<T>(context.Attr<float>("alpha")); T scale = static_cast<T>(context.Attr<float>("alpha"));
bool transpose_q = context.Attr<bool>("transpose_Q");
bool transpose_k = context.Attr<bool>("transpose_K");
bool transpose_v = context.Attr<bool>("transpose_V");
int head_number = context.Attr<int>("head_number"); int head_number = context.Attr<int>("head_number");
// compute q*k with eltadd // compute q*k with eltadd
auto &device_ctx = context.template device_context<DeviceContext>(); auto &device_ctx = context.template device_context<DeviceContext>();
// should be (B * S * hidden)
MultiHeadGPUCompute<T>(device_ctx, head_number, q->dims(), k->dims(), auto input_dims = input->dims();
v->dims(), q->data<T>(), k->data<T>(), v->data<T>(), // shouble be (hidden * 3 * all_head_size)
bias_q.data<T>(), bias_k.data<T>(), bias_v.data<T>(), auto w_dims = w->dims();
bias_qk.data<T>(), out->data<T>(), scale, T(0.0), int batch = input_dims[0];
transpose_q, transpose_k, transpose_v); int seq_len = input_dims[1];
int hidden = input_dims[2];
int all_head_size = w_dims[2];
int head_size = all_head_size / head_number;
// (B*S, hidden)
const Tensor input_matrix =
framework::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);
// (hidden, 3 * all_head_size)
const Tensor w_matrix =
framework::ReshapeToMatrix(*w, 1 /*y_num_col_dims*/);
Tensor temp_out_tensor;
auto temp_out_dims =
framework::make_ddim({batch, seq_len, 3, head_number, head_size});
temp_out_tensor.Resize({batch * seq_len, framework::product(temp_out_dims) /
(batch * seq_len)});
auto *temp_out_data = temp_out_tensor.mutable_data<T>(context.GetPlace());
// (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(device_ctx);
blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);
// temp_out_tensor.Resize(temp_out_dims);
Tensor multihead_temp_tensor;
// B * head_number * S * S * 1 + B * S * 3 * N * H
int scratch_size = batch * head_number * seq_len * seq_len * 1;
multihead_temp_tensor.Resize({scratch_size + temp_out_tensor.numel()});
auto *multihead_temp_data =
multihead_temp_tensor.mutable_data<T>(context.GetPlace());
auto *qkptr = multihead_temp_data;
auto *tptr = multihead_temp_data + scratch_size;
auto stream = device_ctx.stream();
// Do the transpose with bias.
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransQKVWithBias(batch, seq_len, head_size, head_number, temp_out_data,
bias_d, tptr, stream);
MultiHeadGPUComputeV2<T>(device_ctx, batch, seq_len, head_number, head_size,
qkptr, bias_qk_d, tptr, scale, T(0.0));
int grid = batch * head_number * seq_len;
int block = head_size;
transpose<T><<<grid, block, 0, stream>>>(tptr, output_d, batch, seq_len,
head_number, head_size);
} }
}; };
...@@ -411,5 +514,4 @@ class MultiHeadMatMulKernel : public framework::OpKernel<T> { ...@@ -411,5 +514,4 @@ class MultiHeadMatMulKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
multihead_matmul, multihead_matmul,
ops::MultiHeadMatMulKernel<paddle::platform::CUDADeviceContext, float>, ops::MultiHeadMatMulV2Kernel<paddle::platform::CUDADeviceContext, float>);
ops::MultiHeadMatMulKernel<paddle::platform::CUDADeviceContext, double>);
...@@ -47,12 +47,21 @@ class TestFusedMultiheadMatmulOp(OpTest): ...@@ -47,12 +47,21 @@ class TestFusedMultiheadMatmulOp(OpTest):
self.config() self.config()
h = self.seq_len h = self.seq_len
w = self.head_number * self.size_per_head w = self.head_number * self.size_per_head
self.Q = np.random.random((self.batch_size, h, w)).astype("float32") self.Input = np.random.random(
self.K = np.random.random((self.batch_size, h, w)).astype("float32") (self.batch_size, h, w)).astype("float32") - 0.5
self.V = np.random.random((self.batch_size, h, w)).astype("float32") self.WQ = np.random.random((w, w)).astype("float32")
self.KQ = np.random.random((w, w)).astype("float32")
self.VQ = np.random.random((w, w)).astype("float32")
self.CombinedW = np.hstack((self.WQ, self.KQ, self.VQ)).reshape(
(w, 3, w))
self.Q = np.dot(self.Input, self.WQ)
self.K = np.dot(self.Input, self.KQ)
self.V = np.dot(self.Input, self.VQ)
self.BiasQ = np.random.random((1, w)).astype("float32") self.BiasQ = np.random.random((1, w)).astype("float32")
self.BiasK = np.random.random((1, w)).astype("float32") self.BiasK = np.random.random((1, w)).astype("float32")
self.BiasV = np.random.random((1, w)).astype("float32") self.BiasV = np.random.random((1, w)).astype("float32")
self.CombinedB = np.vstack((self.BiasQ, self.BiasK, self.BiasV))
self.BiasQK = np.random.random( self.BiasQK = np.random.random(
(self.batch_size, self.head_number, self.seq_len, (self.batch_size, self.head_number, self.seq_len,
self.seq_len)).astype("float32") self.seq_len)).astype("float32")
...@@ -84,12 +93,9 @@ class TestFusedMultiheadMatmulOp(OpTest): ...@@ -84,12 +93,9 @@ class TestFusedMultiheadMatmulOp(OpTest):
reshape_qkv = np.reshape(transpose_qkv, (self.batch_size, h, w)) reshape_qkv = np.reshape(transpose_qkv, (self.batch_size, h, w))
self.inputs = { self.inputs = {
"Q": self.Q, "Input": self.Input,
"K": self.K, "W": self.CombinedW,
"V": self.V, "Bias": self.CombinedB,
"BiasQ": self.BiasQ,
"BiasK": self.BiasK,
"BiasV": self.BiasV,
"BiasQK": self.BiasQK "BiasQK": self.BiasQK
} }
self.attrs = { self.attrs = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册