From 643c94e4926c8d4f457a31243c5ba3dabc78ec49 Mon Sep 17 00:00:00 2001
From: carryyu <569782149@qq.com>
Date: Fri, 5 Aug 2022 14:46:46 +0800
Subject: [PATCH] enhance fused_multi_transformer_op(post_layer_norm) (#44789)

* add fused_multi_transformer post_layer_norm

* add test post_layer_norm
---
 .../fused/fused_multi_transformer_op.cu       | 102 ++++++++++++++----
 .../test_fused_multi_transformer_op.py        |  55 ++++++++++
 2 files changed, 134 insertions(+), 23 deletions(-)

diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
index a858b31e23..6414954667 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
@@ -1279,9 +1279,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
     auto ffn_ln_scales = ctx.MultiInput<Tensor>("FFNLnScale");
     auto ffn_ln_biases = ctx.MultiInput<Tensor>("FFNLnBias");
     Tensor bias_dropout_residual_out, dropout_mask_out;
-    auto *bias_dropout_residual_out_data =
-        bias_dropout_residual_out.mutable_data<T>({bsz, seq_len, dim_embed},
-                                                  place);
+    T *bias_dropout_residual_out_data = nullptr;
+    if (pre_layer_norm) {
+      bias_dropout_residual_out_data =
+          bias_dropout_residual_out.mutable_data<T>({bsz, seq_len, dim_embed},
+                                                    place);
+    }
     auto *dropout_mask_out_data = dropout_mask_out.mutable_data<uint8_t>(
         {bsz, seq_len, dim_embed}, place);
@@ -1333,14 +1336,19 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
     // step1: buf1 --> buf0
     // step2: buf0 --> buf1
     int layers = qkv_weights.size();
-    if (layers & 1) {
-      // odd, set buf1 as out
+    if (pre_layer_norm) {
+      if (layers & 1) {
+        // odd, set buf1 as out
+        buf0 = &tmp_out;
+        buf1 = out;
+      } else {
+        // even, set buf0 as out
+        buf0 = out;
+        buf1 = &tmp_out;
+      }
+    } else {
       buf0 = &tmp_out;
       buf1 = out;
-    } else {
-      // even, set buf0 as out
-      buf0 = out;
-      buf1 = &tmp_out;
-    }
+    }
 
     for (int i = 0; i < layers; ++i) {
@@ -1355,9 +1363,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                                   buf1->data<T>(),
                                   ln_mean_data,
                                   ln_var_data);
-      } else if (!pre_layer_norm) {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Unimplemented post_layer_norm for now."));
       }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step1";
@@ -1367,8 +1372,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
 #endif
       // step2. qkv
       const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr;
       // NOTE: in decoder stage, bias is fused in fmha
       const Tensor *bias = time_step ? nullptr : qkv_bias;
-      qkv_compute.ComputeForward(
-          qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
+      if (!pre_layer_norm && i == 0) {
+        qkv_compute.ComputeForward(
+            qkv_weights[i], input_x, bias, &qkv_out, &qkv_out);
+      } else {
+        qkv_compute.ComputeForward(
+            qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
+      }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step2";
 #endif
@@ -1451,10 +1461,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
       VLOG(0) << "step3";
 #endif
 
-      // step4. out_linear
-      out_linear_compute.ComputeForward(
-          out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr);
-      AllReduce<T>(*buf1, ring_id, dev_ctx);
+      if (pre_layer_norm) {
+        out_linear_compute.ComputeForward(
+            out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr);
+        AllReduce<T>(*buf1, ring_id, dev_ctx);
+      } else {
+        out_linear_compute.ComputeForward(
+            out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr);
+        AllReduce<T>(*buf0, ring_id, dev_ctx);
+      }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step4";
 #endif
@@ -1479,6 +1494,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
             ln_mean_data,
             ln_var_data);
       } else {
+        auto *ln_scale_data = ln_scales[i]->data<U>();
+        auto *ln_bias_data = ln_biases[i]->data<U>();
+        auto *out_linear_bias_data = out_linear_biases[i]->data<T>();
+        auto *residual_data = (i == 0 ? x_data : buf1->data<T>());
+        fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
+            dev_ctx,
+            buf0->data<T>(),
+            residual_data,
+            out_linear_bias_data,
+            ln_scale_data,
+            ln_bias_data,
+            buf0->data<T>(),
+            dropout_mask_out_data,
+            buf1->data<T>(),
+            ln_mean_data,
+            ln_var_data);
       }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step5";
 #endif
@@ -1504,13 +1535,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
 #endif
 
       // step8. ffn matmul2
-      ffn2_linear_compute.ComputeForward(
-          ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr);
+      if (pre_layer_norm) {
+        ffn2_linear_compute.ComputeForward(
+            ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr);
+      } else {
+        ffn2_linear_compute.ComputeForward(
+            ffn2_weights[i], &ffn1_dropout_out, nullptr, buf0, nullptr);
+      }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step8.0";
 #endif
 
-      AllReduce<T>(*buf1, ring_id, dev_ctx);
+      if (pre_layer_norm) {
+        AllReduce<T>(*buf1, ring_id, dev_ctx);
+      } else {
+        AllReduce<T>(*buf0, ring_id, dev_ctx);
+      }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step8.1";
 #endif
@@ -1543,12 +1583,28 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
               dropout_mask_out_data);
         }
       } else {
+        auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
+        auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
+        ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
+            dev_ctx,
+            buf0->data<T>(),
+            buf1->data<T>(),
+            ffn2_biases[i]->data<T>(),
+            ln_scale_data,
+            ln_bias_data,
+            buf0->data<T>(),
+            dropout_mask_out_data,
+            buf1->data<T>(),
+            ln_mean_data,
+            ln_var_data);
       }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step9";
 #endif
-      x_data = buf1->data<T>();
-      std::swap(buf0, buf1);
+      if (pre_layer_norm) {
+        x_data = buf1->data<T>();
+        std::swap(buf0, buf1);
+      }
     }
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
index ecfc8a5bc2..65276e9c92 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
@@ -548,5 +548,60 @@ class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp):
         self.layers = 3  # odd layers
 
 
+class TestFusedMultiTransformerOpPostLayerNormFp16(TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.x_type = np.float16
+        self.layers = 3  # odd layers
+        self.pre_layer_norm = False
+
+
+class TestFusedMultiTransformerOpCacheKVPostLayerNorm(
+        TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_cache_kv = True
+        self.query_length = 1
+        self.key_length, self.value_length = 1, 1
+        self.layers = 3  # odd layers
+        self.pre_layer_norm = False
+
+
+class TestFusedMultiTransformerOpCacheKVPostLayerNormFp16(
+        TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_cache_kv = True
+        self.query_length = 1
+        self.key_length, self.value_length = 1, 1
+        self.x_type = np.float16
+        self.pre_layer_norm = False
+
+
+class TestFusedMultiTransformerOpGenCacheKVPostLayerNorm(
+        TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_cache_kv = True
+        self.gen_cache_kv = True
+        self.pre_layer_norm = False
+
+
+class TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16(
+        TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_cache_kv = True
+        self.gen_cache_kv = True
+        self.x_type = np.float16
+        self.layers = 3  # odd layers
+        self.pre_layer_norm = False
+
+
 if __name__ == "__main__":
     unittest.main()
--
GitLab
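Note on the change: the kernel's new !pre_layer_norm branches implement the classic post-layer-norm residual ordering (LayerNorm applied to residual + sub-layer output + bias, via LayernormResidualDropoutBias), while the existing path normalizes the sub-layer input. The added test configs simply flip self.pre_layer_norm to False so the fused op is compared against that ordering. The following is a minimal NumPy sketch of the two orderings, for illustration only; the helper names layer_norm, post_ln_residual and pre_ln_residual are not part of the Paddle API, and dropout is omitted.

# Illustrative sketch only -- not the fused_multi_transformer API.
# post_layer_norm: LayerNorm of (residual + sub-layer output + bias);
# pre_layer_norm: LayerNorm of the sub-layer input, residual added afterwards.
import numpy as np


def layer_norm(x, scale, bias, epsilon=1e-5):
    # Normalize over the embedding (last) dimension.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * scale + bias


def post_ln_residual(x, sublayer_out, linear_bias, ln_scale, ln_bias):
    # post_layer_norm: out = LayerNorm(x + sublayer(x) + bias)
    return layer_norm(x + sublayer_out + linear_bias, ln_scale, ln_bias)


def pre_ln_residual(x, sublayer, linear_bias, ln_scale, ln_bias):
    # pre_layer_norm: out = x + sublayer(LayerNorm(x)) + bias
    return x + sublayer(layer_norm(x, ln_scale, ln_bias)) + linear_bias


if __name__ == "__main__":
    bsz, seq_len, dim_embed = 2, 4, 8
    rng = np.random.default_rng(2022)
    x = rng.standard_normal((bsz, seq_len, dim_embed)).astype("float32")
    scale = np.ones(dim_embed, dtype="float32")
    bias = np.zeros(dim_embed, dtype="float32")
    # Use an identity "sub-layer" just to show that the two orderings differ.
    post = post_ln_residual(x, x, bias, scale, bias)
    pre = pre_ln_residual(x, lambda h: h, bias, scale, bias)
    # The post-LN output is normalized per token; pre-LN keeps the raw residual.
    print("max |per-token mean|, post-LN:", float(np.abs(post.mean(-1)).max()))
    print("max |per-token mean|, pre-LN :", float(np.abs(pre.mean(-1)).max()))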