diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
index a858b31e23c8ada64e3d74973ea3197d2c403347..6414954667bfeaf65ad8405e240230837b3aab0e 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
@@ -1279,9 +1279,12 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
     auto ffn_ln_scales = ctx.MultiInput<Tensor>("FFNLnScale");
     auto ffn_ln_biases = ctx.MultiInput<Tensor>("FFNLnBias");
     Tensor bias_dropout_residual_out, dropout_mask_out;
-    auto *bias_dropout_residual_out_data =
-        bias_dropout_residual_out.mutable_data<T>({bsz, seq_len, dim_embed},
-                                                  place);
+    T *bias_dropout_residual_out_data = nullptr;
+    if (pre_layer_norm) {
+      bias_dropout_residual_out_data =
+          bias_dropout_residual_out.mutable_data<T>({bsz, seq_len, dim_embed},
+                                                    place);
+    }
     auto *dropout_mask_out_data = dropout_mask_out.mutable_data<uint8_t>(
         {bsz, seq_len, dim_embed}, place);
 
@@ -1333,14 +1336,19 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
     // step1: buf1 --> buf0
     // step2: buf0 --> buf1
     int layers = qkv_weights.size();
-    if (layers & 1) {
-      // odd, set buf1 as out
+    if (pre_layer_norm) {
+      if (layers & 1) {
+        // odd, set buf1 as out
+        buf0 = &tmp_out;
+        buf1 = out;
+      } else {
+        // even, set buf0 as out
+        buf0 = out;
+        buf1 = &tmp_out;
+      }
+    } else {
       buf0 = &tmp_out;
       buf1 = out;
-    } else {
-      // even, set buf0 as out
-      buf0 = out;
-      buf1 = &tmp_out;
     }
 
     for (int i = 0; i < layers; ++i) {
@@ -1355,9 +1363,6 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
                                   buf1->data<T>(),
                                   ln_mean_data,
                                   ln_var_data);
-      } else if (!pre_layer_norm) {
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Unimplemented post_layer_norm for now."));
       }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step1";
 #endif
@@ -1367,8 +1372,13 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
       const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr;
       // NOTE: in decoder stage, bias is fused in fmha
       const Tensor *bias = time_step ? nullptr : qkv_bias;
-      qkv_compute.ComputeForward(
-          qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
+      if (!pre_layer_norm && i == 0) {
+        qkv_compute.ComputeForward(
+            qkv_weights[i], input_x, bias, &qkv_out, &qkv_out);
+      } else {
+        qkv_compute.ComputeForward(
+            qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
+      }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step2";
 #endif
@@ -1451,10 +1461,15 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
       VLOG(0) << "step3";
 #endif
 
-      // step4. out_linear
-      out_linear_compute.ComputeForward(
-          out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr);
-      AllReduce<T>(*buf1, ring_id, dev_ctx);
+      if (pre_layer_norm) {
+        out_linear_compute.ComputeForward(
+            out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr);
+        AllReduce<T>(*buf1, ring_id, dev_ctx);
+      } else {
+        out_linear_compute.ComputeForward(
+            out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr);
+        AllReduce<T>(*buf0, ring_id, dev_ctx);
+      }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step4";
 #endif
@@ -1479,6 +1494,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
             ln_mean_data,
             ln_var_data);
       } else {
+        auto *ln_scale_data = ln_scales[i]->data<U>();
+        auto *ln_bias_data = ln_biases[i]->data<U>();
+        auto *out_linear_bias_data = out_linear_biases[i]->data<T>();
+        auto *residual_data = (i == 0 ? x_data : buf1->data<T>());
+        fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
+            dev_ctx,
+            buf0->data<T>(),
+            residual_data,
+            out_linear_bias_data,
+            ln_scale_data,
+            ln_bias_data,
+            buf0->data<T>(),
+            dropout_mask_out_data,
+            buf1->data<T>(),
+            ln_mean_data,
+            ln_var_data);
       }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step5";
 #endif
@@ -1504,13 +1535,22 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
 #endif
 
       // step8. ffn matmul2
-      ffn2_linear_compute.ComputeForward(
-          ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr);
+      if (pre_layer_norm) {
+        ffn2_linear_compute.ComputeForward(
+            ffn2_weights[i], &ffn1_dropout_out, nullptr, buf1, nullptr);
+      } else {
+        ffn2_linear_compute.ComputeForward(
+            ffn2_weights[i], &ffn1_dropout_out, nullptr, buf0, nullptr);
+      }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step8.0";
 #endif
-      AllReduce<T>(*buf1, ring_id, dev_ctx);
+      if (pre_layer_norm) {
+        AllReduce<T>(*buf1, ring_id, dev_ctx);
+      } else {
+        AllReduce<T>(*buf0, ring_id, dev_ctx);
+      }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step8.1";
 #endif
@@ -1543,12 +1583,28 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
               dropout_mask_out_data);
         }
       } else {
+        auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
+        auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
+        ffn2_fused_dropout_helper.LayernormResidualDropoutBias(
+            dev_ctx,
+            buf0->data<T>(),
+            buf1->data<T>(),
+            ffn2_biases[i]->data<T>(),
+            ln_scale_data,
+            ln_bias_data,
+            buf0->data<T>(),
+            dropout_mask_out_data,
+            buf1->data<T>(),
+            ln_mean_data,
+            ln_var_data);
       }
 #ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
       VLOG(0) << "step9";
 #endif
-      x_data = buf1->data<T>();
-      std::swap(buf0, buf1);
+      if (pre_layer_norm) {
+        x_data = buf1->data<T>();
+        std::swap(buf0, buf1);
+      }
     }
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
index ecfc8a5bc292cd0b740cccaa3fb059f75ee891a4..65276e9c92e96aaa08a3fa365942bb41f4a13315 100644
--- a/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fused_multi_transformer_op.py
@@ -548,5 +548,60 @@ class TestFusedMultiTransformerOpGenCacheKVFp16(TestFusedMultiTransformerOp):
         self.layers = 3  # odd layers
 
 
+class TestFusedMultiTransformerOpPostLayerNormFp16(TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.x_type = np.float16
+        self.layers = 3  # odd layers
+        self.pre_layer_norm = False
+
+
+class TestFusedMultiTransformerOpCacheKVPostLayerNorm(
+        TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_cache_kv = True
+        self.query_length = 1
+        self.key_length, self.value_length = 1, 1
+        self.layers = 3  # odd layers
+        self.pre_layer_norm = False
+
+
+class TestFusedMultiTransformerOpCacheKVPostLayerNormFp16(
+        TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_cache_kv = True
+        self.query_length = 1
+        self.key_length, self.value_length = 1, 1
+        self.x_type = np.float16
+        self.pre_layer_norm = False
+
+
+class TestFusedMultiTransformerOpGenCacheKVPostLayerNorm(
+        TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_cache_kv = True
+        self.gen_cache_kv = True
+        self.pre_layer_norm = False
+
+
+class TestFusedMultiTransformerOpGenCacheKVPostLayerNormFp16(
+        TestFusedMultiTransformerOp):
+
+    def config(self):
+        super().config()
+        self.has_cache_kv = True
+        self.gen_cache_kv = True
+        self.x_type = np.float16
+        self.layers = 3  # odd layers
+        self.pre_layer_norm = False
+
+
 if __name__ == "__main__":
     unittest.main()