Unverified · Commit a365024c authored by minghaoBD, committed by GitHub

fuse-mt passes compatible with structured pruning (#48585)

* fuse-mt passes compatible with structured pruning
Parent 310f4320
@@ -1325,17 +1325,6 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
                         Node* ffn_eltadd0_b,
                         Node* ffn_eltadd1_b,
                         Node* ffn_output) {
-    auto reshape_desc = reshape2_0->Op();
-    int num_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(2);
-    int dim_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(3);
-    auto* layer_norm_bias_tensor =
-        scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
-    int dim_embed = layer_norm_bias_tensor->dims()[0];
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
     auto* ffn_matmul_0_op = ffn_matmul0->Op();
@@ -1364,6 +1353,20 @@ int FusedMultiTransformerEncoderPass::BuildFusion(Graph* graph,
     auto* bv_tensor =
         scope->FindVar(eltadd2_b->Name())->GetMutable<phi::DenseTensor>();
+    // NOTE(minghaoBD): to make it compatible with structured pruning on
+    // num_head dimension:
+    // 1. get dim_head from reshape.shape[3], dim_embed from
+    //    layer_norm_bias.shape[0]
+    // 2. calculate num_head according to wq_tensor.shape[1] and dim_head
+    auto reshape_desc = reshape2_0->Op();
+    int dim_head =
+        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
+            .at(3);
+    auto* layer_norm_bias_tensor =
+        scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
+    int dim_embed = layer_norm_bias_tensor->dims()[0];
+    int num_head = wq_tensor->dims()[1] / dim_head;
     QKVWeightsBiasProcess(wq_tensor,
                           wk_tensor,
                           wv_tensor,
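The idea behind this change: after structured pruning removes attention heads, the reshape attribute may still record the original num_head, while the Q/K/V weight shapes reflect the pruned model, so num_head is now derived from the weight's second dimension instead of reshape.shape[2]. A minimal standalone sketch of that arithmetic, with hypothetical dimension values that are not taken from the pass:

// Sketch of the num_head derivation used above (illustrative values only).
// wq is [dim_embed, num_head * dim_head]; pruning heads shrinks dim 1 while
// dim_embed (layer_norm_bias.shape[0]) and dim_head (reshape.shape[3]) stay fixed.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> wq_dims = {768, 576};  // 9 of 12 heads kept, 9 * 64
  const int dim_head = 64;    // would come from reshape2_0 attr shape[3]
  const int dim_embed = 768;  // would come from layer_norm_bias dims()[0]
  const int num_head = static_cast<int>(wq_dims[1]) / dim_head;  // 9, not 12
  assert(num_head * dim_head == static_cast<int>(wq_dims[1]));
  std::cout << "num_head=" << num_head << " dim_embed=" << dim_embed << "\n";
  return 0;
}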
@@ -2195,18 +2198,6 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                         Node* ffn_eltadd0_b,
                         Node* ffn_eltadd1_b,
                         Node* ffn_output) {
-    auto reshape_desc = reshape2_0->Op();
-    int num_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(2);
-    int dim_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(3) /
-        3;  // 3 for qkv
-    auto* layer_norm_bias_tensor =
-        scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
-    int dim_embed = layer_norm_bias_tensor->dims()[0];
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
     auto* ffn_matmul_0_op = ffn_matmul0->Op();
@@ -2226,6 +2217,21 @@ int FusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
     auto* qkv_b_tensor =
         scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
+    // NOTE(minghaoBD): to make it compatible with structured pruning on
+    // num_head dimension:
+    // 1. get dim_head from reshape.shape[3], dim_embed from
+    //    layer_norm_bias.shape[0]
+    // 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head
+    auto reshape_desc = reshape2_0->Op();
+    int dim_head =
+        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
+            .at(3) /
+        3;  // 3 for qkv
+    auto* layer_norm_bias_tensor =
+        scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
+    int dim_embed = layer_norm_bias_tensor->dims()[0];
+    int num_head = qkv_w_tensor->dims()[1] / 3 / dim_head;
     QKVWeightsBiasProcessFuseQKV(
         qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
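In the fused-QKV variants (this pass and the multi-devices pass below), the reshape output carries Q, K, and V together, so both reshape.shape[3] and the fused weight's second dimension are divided by 3 before recovering num_head. A small sketch under assumed shapes; the dimension values are illustrative, not taken from the pass:

// Sketch of the fused-QKV derivation (illustrative values only). The fused
// weight is [dim_embed, 3 * num_head * dim_head] and reshape.shape[3] is
// 3 * dim_head, hence the extra division by 3 in both places.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> qkv_w_dims = {768, 3 * 576};  // 3 * (9 heads * 64)
  const int reshape_dim3 = 3 * 64;        // reshape2_0 attr shape[3], "3 for qkv"
  const int dim_head = reshape_dim3 / 3;  // 64
  const int num_head = static_cast<int>(qkv_w_dims[1]) / 3 / dim_head;  // 9
  assert(3 * num_head * dim_head == static_cast<int>(qkv_w_dims[1]));
  std::cout << "num_head=" << num_head << " dim_head=" << dim_head << "\n";
  return 0;
}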
@@ -2995,15 +3001,6 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
                         Node* ffn_eltadd0_b,
                         Node* ffn_eltadd1_b,
                         Node* ffn_output) {
-    auto reshape_desc = reshape2_0->Op();
-    int num_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(2);
-    int dim_head =
-        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
-            .at(3) /
-        3;  // 3 for qkv
     auto* matmul0_op = matmul0->Op();
     auto* matmul_linear_op = matmul_linear->Op();
     auto* ffn_matmul_0_op = ffn_matmul0->Op();
@@ -3023,9 +3020,20 @@ int MultiDevicesFusedMultiTransformerEncoderFuseQKVPass::BuildFusion(
     auto* qkv_b_tensor =
         scope->FindVar(eltadd0_b->Name())->GetMutable<phi::DenseTensor>();
+    // NOTE(minghaoBD): to make it compatible with structured pruning on
+    // num_head dimension:
+    // 1. get dim_head from reshape.shape[3], dim_embed from
+    //    layer_norm_bias.shape[0]
+    // 2. calculate num_head according to wqkv_tensor.shape[1]/3 and dim_head
     auto* layer_norm_bias_tensor =
         scope->FindVar(layer_norm_bias->Name())->GetMutable<phi::DenseTensor>();
     int dim_embed = layer_norm_bias_tensor->dims()[0];
+    auto reshape_desc = reshape2_0->Op();
+    int dim_head =
+        PADDLE_GET_CONST(std::vector<int>, reshape_desc->GetAttr("shape"))
+            .at(3) /
+        3;  // 3 for qkv
+    int num_head = qkv_w_tensor->dims()[1] / 3 / dim_head;
     QKVWeightsBiasProcessFuseQKV(
         qkv_w_tensor, qkv_b_tensor, num_head, dim_head, dim_embed);
......
@@ -93,27 +93,6 @@ class FusedMultiTransformerOp : public framework::OperatorWithKernel {
                           x_dim,
                           y_dim));
-    if (ctx->Attrs().Get<int>("ring_id") == -1) {
-      if (trans_qkvw) {
-        PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2],
-                          y_dim[3],
-                          platform::errors::InvalidArgument(
-                              "The dimensions of qkv_weight must be 4"
-                              "(3, num_head, dim_head, dim_embed),"
-                              "and must satisfy the limitations: "
-                              "(num_head * dim_head == dim_embed)"));
-      } else {
-        PADDLE_ENFORCE_EQ(y_dim[2] * y_dim[3],
-                          y_dim[0],
-                          platform::errors::InvalidArgument(
-                              "The dimensions of qkv_weight must be 4"
-                              "(dim_embed, 3, num_head, dim_head),"
-                              "and must satisfy the limitations: "
-                              "(num_head * dim_head == dim_embed)"));
-      }
-    }
     if (ctx->HasInputs("CacheKV")) {
       // [2, batch_size, num_head, max_seq_len, head_size]
       const auto &c_dims = ctx->GetInputsDim("CacheKV");
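The check removed in this hunk (applied only when ring_id == -1) required num_head * dim_head == dim_embed on qkv_weight. After structured pruning on the num_head dimension that equality no longer holds, since pruning shrinks num_head * dim_head while dim_embed is unchanged, so the check would reject otherwise valid pruned weights. A minimal illustration with assumed dimensions:

// Why the removed check conflicts with structured pruning (assumed dims).
// With trans_qkvw, qkv_weight is [3, num_head, dim_head, dim_embed]; pruning
// heads reduces num_head, so num_head * dim_head can fall below dim_embed.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> y_dim = {3, 9, 64, 768};  // 3 of 12 heads pruned
  const bool equal = (y_dim[1] * y_dim[2] == y_dim[3]);  // 576 != 768
  std::cout << (equal ? "check passes" : "check would reject pruned weights")
            << "\n";
  return 0;
}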
......