diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
index 6236c16d78569f6cd928716770fa808d603a3942..1ea6a69d72559712402a83404711f05b33b8beb0 100644
--- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc
@@ -19,6 +19,7 @@
 #include <vector>
 #include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/platform/errors.h"
 
 namespace paddle {
 namespace framework {
@@ -425,19 +426,285 @@ PDNode* MultiHeadMatmulPattern::operator()(paddle::framework::ir::PDNode* x) {
   return transpose2_2_out_var;
 }
 
+static int BuildFusionV2(Graph* graph, const std::string& name_scope,
+                         Scope* scope) {
+  GraphPatternDetector gpd;
+  auto* pattern = gpd.mutable_pattern();
+
+  // Create pattern.
+  MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
+
+  PDNode* x =
+      pattern->NewNode(patterns::UniqueKey("X"))->assert_var_not_persistable();
+
+  multihead_pattern(x);
+  // Create New OpDesc
+  auto fuse_creater = [&](
+      Node* layer_norm_out, Node* mul0, Node* mul1, Node* mul2, Node* mul0_out,
+      Node* mul1_out, Node* mul2_out, Node* mul0_w, Node* mul1_w, Node* mul2_w,
+      Node* eltadd0_b, Node* eltadd1_b, Node* eltadd2_b, Node* eltadd_qk_b,
+      Node* reshape2, Node* reshape2_qkv_out, Node* scale, Node* scale_out) {
+    auto scale_attr = boost::get<float>(scale->Op()->GetAttr("scale"));
+
+    // mul (B * S * Hidden) x (Hidden * 3 * N * H) = (B * S * 3 * N * H)
+    // bias (B * S * 3 * N * H) + bias (3 * N * H)
+    // Transpose (B * S * 3 * N * H) -> (3 * B * N * S * H)
+    auto* wq_tensor = scope->FindVar(mul0_w->Name())->GetMutable<LoDTensor>();
+    auto* wk_tensor = scope->FindVar(mul1_w->Name())->GetMutable<LoDTensor>();
+    auto* wv_tensor = scope->FindVar(mul2_w->Name())->GetMutable<LoDTensor>();
+
+    auto* bq_tensor =
+        scope->FindVar(eltadd0_b->Name())->GetMutable<LoDTensor>();
+    auto* bk_tensor =
+        scope->FindVar(eltadd1_b->Name())->GetMutable<LoDTensor>();
+    auto* bv_tensor =
+        scope->FindVar(eltadd2_b->Name())->GetMutable<LoDTensor>();
+
+    auto* wq_data = wq_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* wk_data = wk_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* wv_data = wv_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* bq_data = bq_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* bk_data = bk_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* bv_data = bv_tensor->mutable_data<float>(platform::CPUPlace());
+
+    auto combined_w_dims =
+        framework::make_ddim({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
+    auto combined_bias_dims = framework::make_ddim({3, bq_tensor->dims()[0]});
+
+    // create a new var in scope
+    VarDesc combined_w_desc(
+        patterns::PDNodeName(name_scope, "multi_head_combined_weight"));
+    combined_w_desc.SetShape({wq_tensor->dims()[0], 3, wq_tensor->dims()[1]});
+    combined_w_desc.SetDataType(wq_tensor->type());
+    combined_w_desc.SetLoDLevel(mul0_w->Var()->GetLoDLevel());
+    combined_w_desc.SetPersistable(true);
+
+    // create a new var in scope
+    VarDesc combined_bias_desc(
+        patterns::PDNodeName(name_scope, "multi_head_combined_bias"));
+    combined_bias_desc.SetShape({3, bq_tensor->dims()[0]});
+    combined_bias_desc.SetDataType(bq_tensor->type());
+    combined_bias_desc.SetLoDLevel(eltadd0_b->Var()->GetLoDLevel());
+    combined_bias_desc.SetPersistable(true);
+
+    auto* combined_w_node = graph->CreateVarNode(&combined_w_desc);
+    auto* combined_w_tensor =
+        scope->Var(combined_w_node->Name())->GetMutable<LoDTensor>();
+
+    combined_w_tensor->Resize(combined_w_dims);
+    auto* combined_w_data =
+        combined_w_tensor->mutable_data<float>(platform::CPUPlace());
+    std::vector<float*> w_vec = {wq_data, wk_data, wv_data};
+    int dims_h = combined_w_dims[0], dims_w = combined_w_dims[2];
+    // Combine the three fc weights together.
+    for (int i = 0; i < dims_h; i++) {
+      for (int j = 0; j < 3; j++) {
+        for (int k = 0; k < dims_w; k++) {
+          int out_index = i * (3 * dims_w) + j * dims_w + k;
+          int in_index = i * dims_w + k;
+          combined_w_data[out_index] = w_vec[j][in_index];
+        }
+      }
+    }
+    scope->EraseVars({mul0_w->Name(), mul1_w->Name(), mul2_w->Name()});
+    auto* combined_bias_node = graph->CreateVarNode(&combined_bias_desc);
+    auto* combined_bias_tensor =
+        scope->Var(combined_bias_node->Name())->GetMutable<LoDTensor>();
+
+    combined_bias_tensor->Resize(combined_bias_dims);
+    auto* combined_bias_data =
+        combined_bias_tensor->mutable_data<float>(platform::CPUPlace());
+    size_t bias_size = bq_tensor->numel();
+    memcpy(combined_bias_data, bq_data, sizeof(float) * bias_size);
+    memcpy(combined_bias_data + bias_size, bk_data, sizeof(float) * bias_size);
+    memcpy(combined_bias_data + 2 * bias_size, bv_data,
+           sizeof(float) * bias_size);
+
+    scope->EraseVars({eltadd0_b->Name(), eltadd1_b->Name(), eltadd2_b->Name()});
+
+    auto reshape_desc = reshape2->Op();
+    int head_number =
+        boost::get<std::vector<int>>(reshape_desc->GetAttr("shape")).at(2);
+
+    OpDesc multihead_op_desc;
+    multihead_op_desc.SetType("multihead_matmul");
+
+    multihead_op_desc.SetInput("Input", {layer_norm_out->Name()});
+    multihead_op_desc.SetInput("W", {combined_w_node->Name()});
+    multihead_op_desc.SetInput("Bias", {combined_bias_node->Name()});
+    multihead_op_desc.SetInput("BiasQK", {eltadd_qk_b->Name()});
+
+    multihead_op_desc.SetOutput("Out", {reshape2_qkv_out->Name()});
+    multihead_op_desc.SetAttr("alpha", scale_attr);
+    multihead_op_desc.SetAttr("head_number", head_number);
+
+    auto* multihead = graph->CreateOpNode(&multihead_op_desc);
+
+    IR_NODE_LINK_TO(layer_norm_out, multihead);
+    IR_NODE_LINK_TO(combined_w_node, multihead);
+    IR_NODE_LINK_TO(combined_bias_node, multihead);
+    IR_NODE_LINK_TO(eltadd_qk_b, multihead);
+
+    IR_NODE_LINK_TO(multihead, reshape2_qkv_out);
+  };
+
+  int fusion_count{0};
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out,
+                              multihead_pattern);
+
+    GET_IR_NODE_FROM_SUBGRAPH(mul0, mul0, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(mul0_out, mul0_out, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(mul0_w, mul0_w, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0, reshape2_0, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(reshape2_0_out, reshape2_0_out,
+                              multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0, transpose2_0, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(transpose2_0_out, transpose2_0_out,
+                              multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale, scale, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, multihead_pattern);
+
+    GET_IR_NODE_FROM_SUBGRAPH(mul1, mul1, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(mul1_out, mul1_out, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(mul1_w, mul1_w, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1, reshape2_1, multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(reshape2_1_out, reshape2_1_out,
+                              multihead_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(transpose2_1, transpose2_1, multihead_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose2_1_out, transpose2_1_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(mul2, mul2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_out, mul2_out, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(mul2_w, mul2_w, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2, reshape2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_2_out, reshape2_2_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2, transpose2_2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_2_out, transpose2_2_out, + multihead_pattern); + + // nodes need be removed + GET_IR_NODE_FROM_SUBGRAPH(eltadd0, eltadd0, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_b, eltadd0_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd0_out, eltadd0_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd1, eltadd1, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_b, eltadd1_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd1_out, eltadd1_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd2, eltadd2, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_b, eltadd2_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd2_out, eltadd2_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk, matmul_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qk_out, matmul_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk, eltadd_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_b, eltadd_qk_b, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(eltadd_qk_out, eltadd_qk_out, multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk, softmax_qk, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(softmax_qk_out, softmax_qk_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv, matmul_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_qkv_out, matmul_qkv_out, + multihead_pattern); + + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv, reshape2_qkv, multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_qkv_out, reshape2_qkv_out, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv, transpose2_qkv, + multihead_pattern); + GET_IR_NODE_FROM_SUBGRAPH(transpose2_qkv_out, transpose2_qkv_out, + multihead_pattern); + + fuse_creater(layer_norm_out, mul0, mul1, mul2, mul0_out, mul1_out, mul2_out, + mul0_w, mul1_w, mul2_w, eltadd0_b, eltadd1_b, eltadd2_b, + eltadd_qk_b, reshape2_0, reshape2_qkv_out, scale, scale_out); + + std::unordered_set marked_nodes({eltadd0, + eltadd1, + eltadd2, + eltadd0_b, + eltadd1_b, + eltadd2_b, + eltadd0_out, + eltadd1_out, + eltadd2_out, + reshape2_0, + reshape2_1, + reshape2_2, + reshape2_0_out, + reshape2_1_out, + reshape2_2_out, + transpose2_0, + transpose2_1, + transpose2_2, + transpose2_0_out, + transpose2_1_out, + transpose2_2_out, + matmul_qk, + matmul_qk_out, + eltadd_qk, + eltadd_qk_out, + softmax_qk, + softmax_qk_out, + transpose2_qkv, + transpose2_qkv_out, + matmul_qkv, + matmul_qkv_out, + mul0, + mul1, + mul2, + mul0_out, + mul1_out, + mul2_out, + mul0_w, + mul1_w, + mul2_w, + reshape2_qkv, + scale}); + // Remove unneeded nodes. 
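+    // The fused multihead_matmul op reads layer_norm_out as "Input" and
+    // writes reshape2_qkv_out as "Out", so those two vars (plus eltadd_qk_b,
+    // the "BiasQK" input) are intentionally left in the graph; everything
+    // listed above belongs to the replaced Q/K/V subgraph and can be erased.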
+ GraphSafeRemoveNodes(graph, marked_nodes); + ++fusion_count; + }; + gpd(graph, handler); + + return fusion_count; +} + } // namespace patterns void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const { - PADDLE_ENFORCE_NOT_NULL(graph); FusePassBase::Init(name_scope_, graph); int fusion_count = patterns::BuildFusion(graph, name_scope_); AddStatis(fusion_count); } +void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const { + FusePassBase::Init(name_scope_, graph); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::Fatal( + "During the multiheadMatmul pass, The scope should not be null.")); + + patterns::BuildFusionV2(graph, name_scope_, scope); +} + } // namespace ir } // namespace framework } // namespace paddle REGISTER_PASS(multihead_matmul_fuse_pass, paddle::framework::ir::MultiHeadMatmulFusePass); + +REGISTER_PASS(multihead_matmul_fuse_pass_v2, + paddle::framework::ir::MultiHeadMatmulV2FusePass); diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h index ab58d9468e5549c1cc1575778efc4a5f1e25c429..d6299c39c739d7ef191eebd5f09f56aceaa9b9c7 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h @@ -32,8 +32,6 @@ struct MultiHeadMatmulPattern : public PatternBase { PDNode* operator()(PDNode* x); // declare operator node's name - // PATTERN_DECL_NODE(dropout); - // PATTERN_DECL_NODE(dropout_out); PATTERN_DECL_NODE(layer_norm); PATTERN_DECL_NODE(layer_norm_out); PATTERN_DECL_NODE(mul0); @@ -79,8 +77,6 @@ struct MultiHeadMatmulPattern : public PatternBase { PATTERN_DECL_NODE(eltadd_qk_out); PATTERN_DECL_NODE(softmax_qk); PATTERN_DECL_NODE(softmax_qk_out); - // PATTERN_DECL_NODE(dropout_qk); - // PATTERN_DECL_NODE(dropout_qk_out); PATTERN_DECL_NODE(matmul_qkv); PATTERN_DECL_NODE(matmul_qkv_out); @@ -98,6 +94,16 @@ class MultiHeadMatmulFusePass : public FusePassBase { const std::string name_scope_{"multihead_matmul_fuse"}; }; +class MultiHeadMatmulV2FusePass : public FusePassBase { + public: + virtual ~MultiHeadMatmulV2FusePass() {} + + protected: + void ApplyImpl(Graph* graph) const; + + const std::string name_scope_{"multihead_matmul_fuse_v2"}; +}; + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc index d0a5c8c6fe85b393b3d5507e0584e96be7d024ff..d8a06b037bdefbe8776c9b95b36be80afb988393 100644 --- a/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc @@ -17,6 +17,27 @@ namespace paddle { namespace framework { namespace ir { +void AddVarToScope(Scope* param_scope, const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + tensor->mutable_data(platform::CPUPlace()); +} + +Scope* CreateParamScope() { + auto param_scope = new Scope(); + AddVarToScope(param_scope, "weights0", {768, 768}); + AddVarToScope(param_scope, "weights1", {768, 768}); + AddVarToScope(param_scope, "weights2", {768, 768}); + + AddVarToScope(param_scope, "bias_0", {768}); + AddVarToScope(param_scope, "bias_1", {768}); + AddVarToScope(param_scope, "bias_2", {768}); + AddVarToScope(param_scope, "biasqk", {768}); + AddVarToScope(param_scope, "weightsl", {768, 768}); + return param_scope; +} + TEST(MultiHeadMatmulFusePass, basic) { // inputs 
operator output // -------------------------------------------------------------------- @@ -87,7 +108,10 @@ TEST(MultiHeadMatmulFusePass, basic) { layers.mul(reshape_qkv_out, weights_l); std::unique_ptr graph(new ir::Graph(layers.main_program())); - auto pass = PassRegistry::Instance().Get("multihead_matmul_fuse_pass"); + graph->Set("__param_scope__", CreateParamScope()); + + auto pass = PassRegistry::Instance().Get("multihead_matmul_fuse_pass_v2"); + if (pass.get() == nullptr) LOG(INFO) << "asdfasdf"; int num_nodes_before = graph->Nodes().size(); VLOG(3) << DebugString(graph); @@ -96,8 +120,17 @@ TEST(MultiHeadMatmulFusePass, basic) { int num_fused_nodes_after = GetNumOpNodes(graph, "multihead_matmul"); VLOG(3) << DebugString(graph); - PADDLE_ENFORCE_EQ(num_nodes_before, num_nodes_after + 29); - PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1); + PADDLE_ENFORCE_EQ( + num_nodes_before, num_nodes_after + 39, + platform::errors::InvalidArgument( + "After the multihead_matmul pass, The node num in graph " + "should be %d, but the result is %d", + num_nodes_before - 39, num_nodes_after)); + PADDLE_ENFORCE_EQ(num_fused_nodes_after, 1, + platform::errors::InvalidArgument( + "After the multihead_matmul pass, there should be one " + "multihead_matmul op, but the result is %d", + num_fused_nodes_after)); } } // namespace ir @@ -105,3 +138,4 @@ TEST(MultiHeadMatmulFusePass, basic) { } // namespace paddle USE_PASS(multihead_matmul_fuse_pass); +USE_PASS(multihead_matmul_fuse_pass_v2); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 6f964b77d1e6158150b551c343b76aa24e72cc4b..2c615296fb0d08653d130fc35ca3c6a6898870f7 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -107,7 +107,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "conv_eltwiseadd_affine_channel_fuse_pass", // "conv_bn_fuse_pass", // "conv_eltwiseadd_bn_fuse_pass", // - "multihead_matmul_fuse_pass", + "multihead_matmul_fuse_pass_v2", "fc_fuse_pass", // "fc_elementwise_layernorm_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cc b/paddle/fluid/operators/fused/multihead_matmul_op.cc index b82cfb2d81422210addcdd7c3b6955263e769113..ccc90ae368f8a0a8fb976d7595b0bffff0da3292 100644 --- a/paddle/fluid/operators/fused/multihead_matmul_op.cc +++ b/paddle/fluid/operators/fused/multihead_matmul_op.cc @@ -15,126 +15,80 @@ limitations under the License. 
*/ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/platform/errors.h" namespace paddle { namespace operators { -class MultiHeadMatMulOp : public framework::OperatorWithKernel { +class MultiHeadMatMulV2Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; protected: void InferShape(framework::InferShapeContext *context) const override { - PADDLE_ENFORCE_EQ(context->HasInput("Q"), true, - "Input(Q) of MultiheadOp should not be null."); - PADDLE_ENFORCE_EQ(context->HasInput("K"), true, - "Input(K) of MultiheadOp should not be null."); - PADDLE_ENFORCE_EQ(context->HasInput("V"), true, - "Input(V) of MultiheadOp should not be null."); - PADDLE_ENFORCE_EQ(context->HasInput("BiasQ"), true, - "Input(BiasQ) of MultiheadOp should not be null."); - PADDLE_ENFORCE_EQ(context->HasInput("BiasK"), true, - "Input(BiasQ) of MultiheadOp should not be null."); - PADDLE_ENFORCE_EQ(context->HasInput("BiasV"), true, - "Input(BiasQ) of MultiheadOp should not be null."); - PADDLE_ENFORCE_EQ(context->HasInput("BiasQK"), true, - "Input(BiasQK) of MultiheadOp should not be null."); - PADDLE_ENFORCE_EQ(context->HasOutput("Out"), true, - "Output(Out) of MatMulOp should not be null."); - - auto dim_q = context->GetInputDim("Q"); - PADDLE_ENFORCE_GT(dim_q.size(), 2, - "Multihead input should be at least 3-D tensor."); - - auto dim_k = context->GetInputDim("K"); - PADDLE_ENFORCE_GT(dim_q.size(), 2, - "Multihead input should be at least 3-D tensor."); - - auto dim_v = context->GetInputDim("V"); - PADDLE_ENFORCE_GT(dim_q.size(), 2, - "Multihead input should be at least 3-D tensor."); - - PADDLE_ENFORCE_EQ(dim_q[0], dim_k[0], - "Multihead input should have same batch size"); - PADDLE_ENFORCE_EQ(dim_q[0], dim_v[0], - "Multihead input should have same batch size"); - - PADDLE_ENFORCE_EQ(dim_q[1], dim_k[1], - "Multihead input should have same size"); - PADDLE_ENFORCE_EQ(dim_q[1], dim_v[1], - "Multihead input should have same size"); - - PADDLE_ENFORCE_EQ(dim_q[2], dim_k[2], - "Multihead input should have same size"); - PADDLE_ENFORCE_EQ(dim_q[2], dim_v[2], - "Multihead input should have same size"); - - auto dim_bias_q = context->GetInputDim("BiasQ"); - PADDLE_ENFORCE_GT(dim_bias_q.size(), 0, - "Multihead input should be at least 1-D tensor."); - auto dim_bias_k = context->GetInputDim("BiasK"); - PADDLE_ENFORCE_GT(dim_bias_k.size(), 0, - "Multihead input should be at least 1-D tensor."); - auto dim_bias_v = context->GetInputDim("BiasV"); - PADDLE_ENFORCE_GT(dim_bias_v.size(), 0, - "Multihead input should be at least 1-D tensor."); - - PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_k[0], - "Multihead input bias should have same batch size"); - PADDLE_ENFORCE_EQ(dim_bias_q[0], dim_bias_v[0], - "Multihead input bias should have same batch size"); - - auto dim_bias_qk = context->GetInputDim("BiasQK"); - PADDLE_ENFORCE_GT(dim_bias_qk.size(), 3, - "Multihead input bias qk should be at least 4-D tensor."); - - int b_indx = dim_bias_q.size() - 1; - int indx = dim_q.size() - 1; - PADDLE_ENFORCE_EQ( - dim_bias_q[b_indx], dim_q[indx], + context->HasInput("Input"), true, platform::errors::InvalidArgument( - "bias_q's last dim size should equal to" - " q last dim size, but received bias_q's size is:%d q is:%d", - dim_bias_q[b_indx], dim_q[indx])); + "Input(Input) of MultiHeadMatMul should not be null.")); + PADDLE_ENFORCE_EQ(context->HasInput("W"), true, + platform::errors::InvalidArgument( + "Input(W) of 
MultiHeadMatMul should not be null.")); PADDLE_ENFORCE_EQ( - dim_bias_k[b_indx], dim_k[indx], + context->HasInput("Bias"), true, platform::errors::InvalidArgument( - "bias_k's last dim size should equal to" - " k last dim size, but received bias_k's size is:%d k is:%d", - dim_bias_k[b_indx], dim_k[indx])); + "Input(Bias) of MultiHeadMatMul should not be null.")); PADDLE_ENFORCE_EQ( - dim_bias_v[b_indx], dim_v[indx], + context->HasInput("BiasQK"), true, platform::errors::InvalidArgument( - "bias_v's last dim size should equal to" - " v last dim size, but received bias_v's size is:%d v is:%d", - dim_bias_v[b_indx], dim_v[indx])); + "Input(BiasQK) of MultiHeadMatMul should not be null.")); + PADDLE_ENFORCE_EQ( + context->HasOutput("Out"), true, + platform::errors::InvalidArgument( + "Output(Out) of MultiHeadMatMul should not be null.")); - PADDLE_ENFORCE_EQ(dim_q[0], dim_bias_qk[0], - platform::errors::InvalidArgument( - "q should have same batch size" - "with bias_qk, but received q's batch size is:%d " - "bias_qk's batch size is:%d", - dim_q[0], dim_bias_qk[0])); + auto dim_w = context->GetInputDim("W"); + PADDLE_ENFORCE_GT( + dim_w.size(), 2, + platform::errors::InvalidArgument( + "Multihead input is expected at least a 3-D tensor, but " + "it's %d-D tensor now.", + dim_w.size())); - int head_number = context->Attrs().Get("head_number"); - PADDLE_ENFORCE_GT(head_number, 1, - "Multihead input head number should be at least 1."); + auto dim_bias_q = context->GetInputDim("Bias"); + PADDLE_ENFORCE_GT( + dim_bias_q.size(), 1, + platform::errors::InvalidArgument( + "Multihead input should be at least 2-D tensor, but it's " + "%d-D tensor now.", + dim_bias_q.size())); + + auto dim_bias_qk = context->GetInputDim("BiasQK"); + PADDLE_ENFORCE_GT( + dim_bias_qk.size(), 3, + platform::errors::InvalidArgument( + "Multihead input bias qk should be at least 4-D tensor, " + "but it's %d-D tensor now.", + dim_bias_qk.size())); - context->SetOutputDim("Out", dim_q); - context->ShareLoD("Q", /*->*/ "Out"); + int head_number = context->Attrs().Get("head_number"); + PADDLE_ENFORCE_GT( + head_number, 1, + platform::errors::InvalidArgument( + "Multihead input head number should be at least 1, but it %d now.", + head_number)); + // modify this + auto dim_input = context->GetInputDim("Input"); + context->SetOutputDim("Out", dim_input); + context->ShareLoD("Input", /*->*/ "Out"); } }; -class MultiHeadMatMulOpMaker : public framework::OpProtoAndCheckerMaker { +class MultiHeadMatMulV2OpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { - AddInput("Q", "The first input of MultiHeadMatMul op"); - AddInput("K", "The second input of MMultiHeadMatMul op"); - AddInput("V", "The third input of MultiHeadMatMul op"); - AddInput("BiasQ", "The first bias input of MultiHeadMatMul op"); - AddInput("BiasK", "The second bias input of MultiHeadMatMul op"); - AddInput("BiasV", "The third bias input of MultiHeadMatMul op"); + AddInput("Input", "The input of MultiHeadMatMul op"); + AddInput("W", "The weight input of MultiHeadMatMul op"); + AddInput("Bias", "The bias input of MultiHeadMatMul op"); AddInput("BiasQK", "The QK bias input of MultiHeadMatMul op"); AddOutput("Out", "The output of MultiHeadMatMul op"); AddAttr("transpose_Q", @@ -161,10 +115,6 @@ Not suggest to use in other case except has same structure as ernie. 
Example of matrix multiplication with head_number of B - X: [B, M, K], Y: [B, K, N] => Out: [B, M, N] -Both the input `Q` and `K` can carry the LoD (Level of Details) information, -or not. But the output only shares the LoD information with input `Q`, because -they are the same. - )DOC"); } }; @@ -173,5 +123,5 @@ they are the same. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(multihead_matmul, ops::MultiHeadMatMulOp, - ops::MultiHeadMatMulOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(multihead_matmul, ops::MultiHeadMatMulV2Op, + ops::MultiHeadMatMulV2OpMaker); diff --git a/paddle/fluid/operators/fused/multihead_matmul_op.cu b/paddle/fluid/operators/fused/multihead_matmul_op.cu index 9648b62423e94528f5c73069baee9724c4babfe9..2500f66c672733c29d2200e2bdf97597a7cadad4 100644 --- a/paddle/fluid/operators/fused/multihead_matmul_op.cu +++ b/paddle/fluid/operators/fused/multihead_matmul_op.cu @@ -300,7 +300,7 @@ template void MatMulWithHeadQKV(const platform::CUDADeviceContext &context, int head_num, int seq_len, int size_per_head, int batch_size, bool qk_trans, bool v_trans, T *v_buf_, const T *qk_buf_, - T *dst, T *out, T alpha, T beta) { + T *dst, T alpha, T beta) { int m = batch_size * seq_len; int k = head_num * size_per_head; @@ -312,96 +312,199 @@ void MatMulWithHeadQKV(const platform::CUDADeviceContext &context, int head_num, blas.BatchedGEMM(transA, transB, seq_len, size_per_head, seq_len, alpha, qk_buf_, v_buf_, beta, dst, batch_size * head_num, seq_len * seq_len, seq_len * size_per_head); +} - int grid = batch_size * head_num * seq_len; - int block = size_per_head; - transpose<<>>(dst, out, batch_size, seq_len, - head_num, size_per_head); +template +inline __device__ T add_func(T a, T b); + +template <> +__device__ float add_func(float a, float b) { + return a + b; +} + +template <> +__device__ float2 add_func(float2 a, float2 b) { + float2 c; + c.x = a.x + b.x; + c.y = a.y + b.y; + return c; +} + +template <> +__device__ float4 add_func(float4 a, float4 b) { + float4 c; + c.x = a.x + b.x; + c.y = a.y + b.y; + c.z = a.z + b.z; + c.w = a.w + b.w; + return c; } template -void MultiHeadGPUCompute(const platform::CUDADeviceContext &dev_ctx, - int head_num, const framework::DDim &mat_q, - const framework::DDim &mat_k, - const framework::DDim &mat_v, const T *Q, const T *K, - const T *V, const T *bias_q, const T *bias_k, - const T *bias_v, const T *bias_qk, T *out, T alpha, - T beta, bool trans_q, bool trans_k, bool trans_v) { - int seq_len = mat_q[1]; - int size_per_head = (mat_q[2] / head_num); - int batch_size = mat_q[0]; - int buf_size = batch_size * head_num * seq_len * size_per_head; - int qk_buf_size = batch_size * head_num * seq_len * seq_len; - - auto alloc_buf = - memory::Alloc(dev_ctx, (buf_size * 4 + qk_buf_size) * sizeof(T)); - - T *buf = reinterpret_cast(alloc_buf->ptr()); - T *q_buf = buf; - T *k_buf = buf + buf_size; - T *v_buf = buf + 2 * buf_size; - T *qk_buf = buf + 3 * buf_size; - T *dst_buf = buf + 3 * buf_size + qk_buf_size; +__global__ void transpose_qkv_kernel(const int H, const T *input, const T *bias, + T *output) { + // Input: BxSx3xNxH + // Bias: 3xSxB + // Output: 3xBxNxSxH + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; + + const int N = blockDim.y; + const int S = gridDim.x; + const int B = gridDim.y; + + const int NH = N * H; + const int NHS = NH * S; + const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3; + const int bias_offset = m * NH + n * H; + const int out_offset = s 
* H + n * S * H + b * NHS + m * NHS * B; + + const int i = threadIdx.x; + output[out_offset + i] = + add_func(input[in_offset + i], bias[bias_offset + i]); +} - int m = batch_size * seq_len; - int k = head_num * size_per_head; +void TransQKVWithBias(const int batch, const int seq_len, const int head_size, + const int head_num, const float *input, const float *bias, + float *output, cudaStream_t stream) { + // BxSx3xNxH + 3xNxH -> 3xBxNxSxH + const dim3 grid(seq_len, batch, 3); + if (head_size % 4 == 0) { + const int h = head_size / 4; + const float4 *input4 = reinterpret_cast(input); + const float4 *bias4 = reinterpret_cast(bias); + float4 *output4 = reinterpret_cast(output); + const dim3 block(h, head_num, 1); + + // limit h * head_num to max block size(1024). + PADDLE_ENFORCE_LE(h * head_num, 1024, + platform::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, head_size, 1024 * 4)); + transpose_qkv_kernel<<>>(h, input4, bias4, + output4); + } else if (head_size % 2 == 0) { + const int h = head_size / 2; + const float2 *input2 = reinterpret_cast(input); + const float2 *bias2 = reinterpret_cast(bias); + float2 *output2 = reinterpret_cast(output); + const dim3 block(h, head_num, 1); + // limit h * head_num to max block size(1024). + PADDLE_ENFORCE_LE(h * head_num, 1024, + platform::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, head_size, 1024 * 2)); + transpose_qkv_kernel<<>>(h, input2, bias2, + output2); + } else { + const dim3 block(head_size, head_num, 1); + // limit head_size * head_num to max block size(1024). + PADDLE_ENFORCE_LE(head_size * head_num, 1024, + platform::errors::InvalidArgument( + "head_num (%d) * head_size (%d) should <= %d", + head_num, head_size, 1024)); + transpose_qkv_kernel<<>>(head_size, input, + bias, output); + } +} - // Each block process head*size-per_head element, - // have m lines. bias is m lines - auto blas = math::GetBlas(dev_ctx); +template +void MultiHeadGPUComputeV2(const platform::CUDADeviceContext &dev_ctx, + int batch, int seq_len, int head_num, int head_size, + T *qkptr, const T *bias_qk_ptr, T *tptr, T alpha, + T beta) { auto stream = dev_ctx.stream(); - - int grid = m; - PADDLE_ENFORCE_LE(k, 1024, - "Input head_number * size_per_head should <= 1024"); - int block = k <= 1024 ? k : 1024; - add_QKV<<>>(Q, K, V, q_buf, k_buf, v_buf, bias_q, - bias_k, bias_v, batch_size, seq_len, - head_num, size_per_head); - - MatMulWithHeadQK(dev_ctx, head_num, seq_len, size_per_head, batch_size, - trans_q, trans_k, q_buf, k_buf, qk_buf, bias_qk, alpha, - beta); - MatMulWithHeadQKV(dev_ctx, head_num, seq_len, size_per_head, batch_size, - false, trans_v, v_buf, qk_buf, dst_buf, out, T(1.0), - beta); + const int tsize = batch * head_num * seq_len * head_size; + + T *qptr = tptr; + T *kptr = qptr + tsize; + T *vptr = kptr + tsize; + // batch gemm stride, softmaxwithscale. + MatMulWithHeadQK(dev_ctx, head_num, seq_len, head_size, batch, false, true, + qptr, kptr, qkptr, bias_qk_ptr, alpha, beta); + // batch gemm stride, transpose. 
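+  // tptr holds Q, K and V back to back (each batch * head_num * seq_len *
+  // head_size elements, see qptr/kptr/vptr above), while qkptr holds the
+  // batch * head_num * seq_len * seq_len attention scores; the final QKV
+  // product is written back into tptr, reusing the slot Q no longer needs.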
+  MatMulWithHeadQKV(dev_ctx, head_num, seq_len, head_size, batch, false,
+                    false, vptr, qkptr, tptr, T(1.0), beta);
 }
 
 template <typename DeviceContext, typename T>
-class MultiHeadMatMulKernel : public framework::OpKernel<T> {
+class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    auto *q = context.Input<framework::Tensor>("Q");
-    auto *k = context.Input<framework::Tensor>("K");
-    auto *v = context.Input<framework::Tensor>("V");
-
-    auto &bias_q = detail::Ref(context.Input<framework::Tensor>("BiasQ"),
-                               "Cannot find BiasQ");
-    auto &bias_k = detail::Ref(context.Input<framework::Tensor>("BiasK"),
-                               "Cannot find BiasK");
-    auto &bias_v = detail::Ref(context.Input<framework::Tensor>("BiasV"),
-                               "Cannot find BiasV");
+    using Tensor = framework::Tensor;
+    auto *input = context.Input<framework::Tensor>("Input");
+    auto *w = context.Input<framework::Tensor>("W");
+    auto *bias = context.Input<framework::Tensor>("Bias");
 
     auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"),
                                 "Cannot find QK");
-
     auto *out = context.Output<framework::Tensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
+    auto *input_d = input->data<T>();
+    auto *w_d = w->data<T>();
+    auto *bias_d = bias->data<T>();
+    auto *bias_qk_d = bias_qk.data<T>();
+    auto *output_d = out->mutable_data<T>(context.GetPlace());
     T scale = static_cast<T>(context.Attr<float>("alpha"));
-    bool transpose_q = context.Attr<bool>("transpose_Q");
-    bool transpose_k = context.Attr<bool>("transpose_K");
-    bool transpose_v = context.Attr<bool>("transpose_V");
     int head_number = context.Attr<int>("head_number");
     // compute q*k with eltadd
     auto &device_ctx = context.template device_context<DeviceContext>();
-
-    MultiHeadGPUCompute<T>(device_ctx, head_number, q->dims(), k->dims(),
-                           v->dims(), q->data<T>(), k->data<T>(), v->data<T>(),
-                           bias_q.data<T>(), bias_k.data<T>(), bias_v.data<T>(),
-                           bias_qk.data<T>(), out->data<T>(), scale, T(0.0),
-                           transpose_q, transpose_k, transpose_v);
+    // should be (B * S * hidden)
+    auto input_dims = input->dims();
+    // should be (hidden * 3 * all_head_size)
+    auto w_dims = w->dims();
+    int batch = input_dims[0];
+    int seq_len = input_dims[1];
+    int hidden = input_dims[2];
+
+    int all_head_size = w_dims[2];
+    int head_size = all_head_size / head_number;
+
+    // (B*S, hidden)
+    const Tensor input_matrix =
+        framework::ReshapeToMatrix(*input, 2 /*x_num_col_dims */);
+    // (hidden, 3 * all_head_size)
+    const Tensor w_matrix =
+        framework::ReshapeToMatrix(*w, 1 /*y_num_col_dims*/);
+
+    Tensor temp_out_tensor;
+    auto temp_out_dims =
+        framework::make_ddim({batch, seq_len, 3, head_number, head_size});
+    temp_out_tensor.Resize({batch * seq_len, framework::product(temp_out_dims) /
+                                                 (batch * seq_len)});
+    auto *temp_out_data = temp_out_tensor.mutable_data<T>(context.GetPlace());
+
+    // (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
+    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(device_ctx);
+    blas.MatMul(input_matrix, w_matrix, &temp_out_tensor);
+
+    // temp_out_tensor.Resize(temp_out_dims);
+
+    Tensor multihead_temp_tensor;
+    // B * head_number * S * S * 1 + B * S * 3 * N * H
+    int scratch_size = batch * head_number * seq_len * seq_len * 1;
+    multihead_temp_tensor.Resize({scratch_size + temp_out_tensor.numel()});
+    auto *multihead_temp_data =
+        multihead_temp_tensor.mutable_data<T>(context.GetPlace());
+    auto *qkptr = multihead_temp_data;
+    auto *tptr = multihead_temp_data + scratch_size;
+
+    auto stream = device_ctx.stream();
+    // Do the transpose with bias.
+    // BxSx3xNxH => tptr: 3xBxNxSxH.
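+    // Index mapping performed by the kernel (b = batch, s = sequence,
+    // m = q/k/v slot, n = head, h = element within a head):
+    //   output[m][b][n][s][h] = input[b][s][m][n][h] + bias[m][n][h]
+    // Fusing the bias add into the transpose gives one pass over memory and
+    // leaves Q, K and V contiguous for the strided batched GEMMs.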
+    TransQKVWithBias(batch, seq_len, head_size, head_number, temp_out_data,
+                     bias_d, tptr, stream);
+
+    MultiHeadGPUComputeV2<T>(device_ctx, batch, seq_len, head_number, head_size,
+                             qkptr, bias_qk_d, tptr, scale, T(0.0));
+
+    int grid = batch * head_number * seq_len;
+    int block = head_size;
+    transpose<T><<<grid, block, 0, stream>>>(tptr, output_d, batch, seq_len,
+                                             head_number, head_size);
   }
 };
 
@@ -411,5 +514,4 @@ class MultiHeadMatMulKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     multihead_matmul,
-    ops::MultiHeadMatMulKernel<paddle::platform::CUDADeviceContext, float>,
-    ops::MultiHeadMatMulKernel<paddle::platform::CUDADeviceContext, double>);
+    ops::MultiHeadMatMulV2Kernel<paddle::platform::CUDADeviceContext, float>);
diff --git a/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py b/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
index 9890cbb12220a352b0d626a01587fd3497745543..d78e929fb60a1c8102ce378ef60d5d18b0e7e879 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py
@@ -47,12 +47,21 @@ class TestFusedMultiheadMatmulOp(OpTest):
         self.config()
         h = self.seq_len
         w = self.head_number * self.size_per_head
-        self.Q = np.random.random((self.batch_size, h, w)).astype("float32")
-        self.K = np.random.random((self.batch_size, h, w)).astype("float32")
-        self.V = np.random.random((self.batch_size, h, w)).astype("float32")
+        self.Input = np.random.random(
+            (self.batch_size, h, w)).astype("float32") - 0.5
+        self.WQ = np.random.random((w, w)).astype("float32")
+        self.KQ = np.random.random((w, w)).astype("float32")
+        self.VQ = np.random.random((w, w)).astype("float32")
+        self.CombinedW = np.hstack((self.WQ, self.KQ, self.VQ)).reshape(
+            (w, 3, w))
+        self.Q = np.dot(self.Input, self.WQ)
+        self.K = np.dot(self.Input, self.KQ)
+        self.V = np.dot(self.Input, self.VQ)
+
         self.BiasQ = np.random.random((1, w)).astype("float32")
         self.BiasK = np.random.random((1, w)).astype("float32")
         self.BiasV = np.random.random((1, w)).astype("float32")
+        self.CombinedB = np.vstack((self.BiasQ, self.BiasK, self.BiasV))
         self.BiasQK = np.random.random(
             (self.batch_size, self.head_number, self.seq_len,
              self.seq_len)).astype("float32")
@@ -84,12 +93,9 @@ class TestFusedMultiheadMatmulOp(OpTest):
         reshape_qkv = np.reshape(transpose_qkv, (self.batch_size, h, w))
 
         self.inputs = {
-            "Q": self.Q,
-            "K": self.K,
-            "V": self.V,
-            "BiasQ": self.BiasQ,
-            "BiasK": self.BiasK,
-            "BiasV": self.BiasV,
+            "Input": self.Input,
+            "W": self.CombinedW,
+            "Bias": self.CombinedB,
             "BiasQK": self.BiasQK
         }
         self.attrs = {