From a365024c8190fc3f6199a9b9c6b26032a36f8efa Mon Sep 17 00:00:00 2001 From: minghaoBD <79566150+minghaoBD@users.noreply.github.com> Date: Thu, 1 Dec 2022 15:19:25 +0800 Subject: [PATCH] fuse-mt passes compatible with structured pruning (#48585) * fuse-mt passes compatible with structured pruning --- .../fused_multi_transformer_encoder_pass.cc | 72 ++++++++++--------- .../fused/fused_multi_transformer_op.cc | 21 ------ 2 files changed, 40 insertions(+), 53 deletions(-) diff --git a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc index 3635613f8c..6f0ef5b755 100644 --- a/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc +++ b/paddle/fluid/framework/ir/fused_multi_transformer_encoder_pass.cc @@ -1325,17 +1325,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, Node* ffn_output) { - auto reshape_desc = reshape2_0->Op(); - int num_head = - PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) - .at(2); - int dim_head = - PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) - .at(3); - auto* layer_norm_bias_tensor = - scope->FindVar(layer_norm_bias->Name())->GetMutable(); - int dim_embed = layer_norm_bias_tensor->dims()[0]; - auto* matmul0_op = matmul0->Op(); auto* matmul_linear_op = matmul_linear->Op(); auto* ffn_matmul_0_op = ffn_matmul0->Op(); @@ -1364,6 +1353,20 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph, auto* bv_tensor = scope->FindVar(eltadd2_b->Name())->GetMutable(); + // NOTE(minghaoBD): to make it compatible with strucutured pruning on + // num_head dimension: + // 1. get dim_head from reshape.shape[3], dim_embed from + // layer_norm_bias.shape[0] + // 2. calculate num_head according to wq_tensor.shape[1] and dim_head + auto reshape_desc = reshape2_0->Op(); + int dim_head = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(3); + auto* layer_norm_bias_tensor = + scope->FindVar(layer_norm_bias->Name())->GetMutable(); + int dim_embed = layer_norm_bias_tensor->dims()[0]; + int num_head = wq_tensor->dims()[1] / dim_head; + QKVWeightsBiasProcess(wq_tensor, wk_tensor, wv_tensor, @@ -2195,18 +2198,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, Node* ffn_output) { - auto reshape_desc = reshape2_0->Op(); - int num_head = - PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) - .at(2); - int dim_head = - PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) - .at(3) / - 3; // 3 for qkv - auto* layer_norm_bias_tensor = - scope->FindVar(layer_norm_bias->Name())->GetMutable(); - int dim_embed = layer_norm_bias_tensor->dims()[0]; - auto* matmul0_op = matmul0->Op(); auto* matmul_linear_op = matmul_linear->Op(); auto* ffn_matmul_0_op = ffn_matmul0->Op(); @@ -2226,6 +2217,21 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion( auto* qkv_b_tensor = scope->FindVar(eltadd0_b->Name())->GetMutable(); + // NOTE(minghaoBD): to make it compatible with strucutured pruning on + // num_head dimension: + // 1. get dim_head from reshape.shape[3], dim_embed from + // layer_norm_bias.shape[0] + // 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head + auto reshape_desc = reshape2_0->Op(); + int dim_head = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(3) / + 3; // 3 for qkv + auto* layer_norm_bias_tensor = + scope->FindVar(layer_norm_bias->Name())->GetMutable(); + int dim_embed = layer_norm_bias_tensor->dims()[0]; + int num_head = qkv_w_tensor->dims()[1] / 3 / dim_head; + QKVWeightsBiasProcessFuseQKV( qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); @@ -2995,15 +3001,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( Node* ffn_eltadd0_b, Node* ffn_eltadd1_b, Node* ffn_output) { - auto reshape_desc = reshape2_0->Op(); - int num_head = - PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) - .at(2); - int dim_head = - PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) - .at(3) / - 3; // 3 for qkv - auto* matmul0_op = matmul0->Op(); auto* matmul_linear_op = matmul_linear->Op(); auto* ffn_matmul_0_op = ffn_matmul0->Op(); @@ -3023,9 +3020,20 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion( auto* qkv_b_tensor = scope->FindVar(eltadd0_b->Name())->GetMutable(); + // NOTE(minghaoBD): to make it compatible with strucutured pruning on + // num_head dimension: + // 1. get dim_head from reshape.shape[3], dim_embed from + // layer_norm_bias.shape[0] + // 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head auto* layer_norm_bias_tensor = scope->FindVar(layer_norm_bias->Name())->GetMutable(); int dim_embed = layer_norm_bias_tensor->dims()[0]; + auto reshape_desc = reshape2_0->Op(); + int dim_head = + PADDLE_GET_CONST(std::vector, reshape_desc->GetAttr("shape")) + .at(3) / + 3; // 3 for qkv + int num_head = qkv_w_tensor->dims()[1] / 3 / dim_head; QKVWeightsBiasProcessFuseQKV( qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed); diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc index 6a4c3890e5..94a89338a6 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cc +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cc @@ -93,27 +93,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel { x_dim, y_dim)); - if (ctx->Attrs().Get("ring_id") == -1) { - if (trans_qkvw) { - PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], - y_dim[3], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(3, num_head, dim_head, dim_embed)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); - - } else { - PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3], - y_dim[0], - platform::errors::InvalidArgument( - "The dimensions of qkv_weight must be 4" - "(dim_embed, 3, num_head, dim_head)," - "and must satisfy the limitations: " - "(num_head * dim_head == dim_embed)")); - } - } - if (ctx->HasInputs("CacheKV")) { // [2, batch_size, num_head, max_seq_len, head_size] const auto &c_dims = ctx->GetInputsDim("CacheKV"); -- GitLab