diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
index f52bc2a7f54d180762d4d1698b60ea2fb7fba267..f56baef1d26726605a38782294cdba8a29c968ff 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu
@@ -14,6 +14,517 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
+#if CUDA_VERSION >= 11060  // Use cublasLt to fuse the FFN operation.
+
+template <typename T>
+class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    using U = LayerNormParamType<T>;
+    auto &dev_ctx = ctx.cuda_device_context();
+
+    auto *time_step = ctx.Input<Tensor>("TimeStep");
+    // 0. input
+    auto *input_x = ctx.Input<Tensor>("X");
+    const auto input_x_dims = input_x->dims();
+    int bsz = input_x_dims[0];
+    int seq_len = input_x_dims[1];
+    int dim_embed = input_x_dims[2];
+    int bsz_seq = bsz * seq_len;
+
+    // 1. layer norm
+    const auto pre_layer_norm = ctx.Attr<bool>("pre_layer_norm");
+    const float epsilon = ctx.Attr<float>("epsilon");
+    auto ln_scales = ctx.MultiInput<Tensor>("LnScale");
+    auto ln_biases = ctx.MultiInput<Tensor>("LnBias");
+
+    auto ln_compute = AttnLayerNorm<T>(dev_ctx, epsilon, bsz_seq, dim_embed);
+    Tensor ln_mean, ln_var;
+    ln_mean.Resize({{bsz_seq}});
+    auto *ln_mean_data =
+        dev_ctx.Alloc<U>(&ln_mean, ln_mean.numel() * sizeof(U));
+    ln_var.Resize({{bsz_seq}});
+    auto *ln_var_data = dev_ctx.Alloc<U>(&ln_var, ln_var.numel() * sizeof(U));
+
+    // 2. qkv
+    // x: qkv's input [batch_size, seq_len, dim_embed]
+    // y: qkv's weight: [3, num_head, dim_head, dim_embed]
+    auto qkv_weights = ctx.MultiInput<Tensor>("QKVW");
+    auto qkv_biases = ctx.MultiInput<Tensor>("QKVBias");
+    const bool trans_qkvw = ctx.Attr<bool>("trans_qkvw");
+    const auto qkv_w_dims = qkv_weights[0]->dims();
+    int num_head = trans_qkvw ? qkv_w_dims[1] : qkv_w_dims[2];
+    int dim_head = trans_qkvw ? qkv_w_dims[2] : qkv_w_dims[3];
+    int hidden_size = num_head * dim_head;
+    int output_size = 3 * hidden_size;
+    int input_size = dim_embed;
+
+    bool compute_bias = qkv_biases.size() > 0 && time_step == nullptr;
+    // (transA, transB, compute_bias) = (false, trans_qkvw, false)
+
+    // Since the QKVBias is fused into the QKVBiasAddTransposeSplit kernel,
+    // compute_bias is set to false here.
+    auto qkv_compute = AttnMatMul<T>(dev_ctx,
+                                     false,
+                                     trans_qkvw,
+                                     bsz_seq,
+                                     output_size,
+                                     input_size,
+                                     /*compute_bias=*/false);
+
+    Tensor qkv_out;
+    qkv_out.Resize({{bsz, seq_len, 3, num_head, dim_head}});
+    auto *qkv_out_data =
+        dev_ctx.Alloc<T>(&qkv_out, qkv_out.numel() * sizeof(T));
+
+    // 3. fmha
+    AttnDropoutParam attn_param(
+        true, "upscale_in_train", 0.0, true, true, 0, nullptr);
+    auto fmha_compute =
+        FMHARef<T>(dev_ctx, bsz, seq_len, num_head, dim_head, attn_param);
+    auto *src_mask = ctx.Input<Tensor>("SrcMask");
+    auto cache_kvs = ctx.MultiInput<Tensor>("CacheKV");
+    auto cache_kv_outs = ctx.MultiOutput<Tensor>("CacheKVOut");
+    // auto *time_step = ctx.Input<Tensor>("TimeStep");
+    auto pre_caches = ctx.MultiInput<Tensor>("PreCaches");
+    int cache_offset = 0;
+    if (pre_caches.size() > 0) {
+      cache_offset = pre_caches[0]->dims()[3];
+    }
+
+    auto out_seq_len = seq_len;
+    if (time_step) {
+      PADDLE_ENFORCE_EQ(time_step->place(),
+                        platform::CPUPlace(),
+                        platform::errors::PreconditionNotMet(
+                            "The place of input(TimeStep) must be CPUPlace."));
+      // cache_seq_len
+      int time_step_value = time_step->data<int>()[0];
+      PADDLE_ENFORCE_GT(time_step_value,
+                        0,
+                        platform::errors::PreconditionNotMet(
+                            "The value of time_step must be > 0, but now is %d",
+                            time_step_value));
+      PADDLE_ENFORCE_EQ(
+          seq_len,
+          1,
+          platform::errors::PreconditionNotMet(
+              "In the decode stage, the seq_len of input must be 1, but now is %d",
+              seq_len));
+      out_seq_len += time_step_value;
+    } else {
+      out_seq_len += cache_offset;
+    }
+
+    Tensor q_transpose_out, kv_transpose_out, qk_out;
+    q_transpose_out.Resize({{bsz, num_head, seq_len, dim_head}});
+    auto *q_transpose_out_data =
+        dev_ctx.Alloc<T>(&q_transpose_out, q_transpose_out.numel() * sizeof(T));
+
+    kv_transpose_out.Resize({{2, bsz, num_head, seq_len, dim_head}});
+    auto *kv_transpose_out_data = dev_ctx.Alloc<T>(
+        &kv_transpose_out, kv_transpose_out.numel() * sizeof(T));
+
+    qk_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
+    auto *qk_out_data = dev_ctx.Alloc<T>(&qk_out, qk_out.numel() * sizeof(T));
+
+    Tensor src_mask_out;
+    if (cache_offset > 0) {
+      src_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
+      auto *src_mask_out_data =
+          dev_ctx.Alloc<T>(&src_mask_out, src_mask_out.numel() * sizeof(T));
+    }
+
+    // [2, bs, num_head, cache_seq_len + seq_len, head_dim]
+    Tensor pre_cache_kv_out;
+    if (cache_offset > 0) {
+      pre_cache_kv_out.Resize(
+          {{2, bsz, num_head, seq_len + cache_offset, dim_head}});
+      auto *pre_cache_kv_out_data = dev_ctx.Alloc<T>(
+          &pre_cache_kv_out, pre_cache_kv_out.numel() * sizeof(T));
+    }
+
+    Tensor softmax_out;
+    Tensor attn_dropout_mask_out, attn_dropout_out;
+    Tensor qktv_out, fmha_out;
+    softmax_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
+    auto *softmax_out_data =
+        dev_ctx.Alloc<T>(&softmax_out, softmax_out.numel() * sizeof(T));
+
+    attn_dropout_mask_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
+    auto *attn_dropout_mask_out_data = dev_ctx.Alloc<T>(
+        &attn_dropout_mask_out, attn_dropout_mask_out.numel() * sizeof(T));
+    attn_dropout_out.Resize({{bsz, num_head, seq_len, out_seq_len}});
+    auto *attn_dropout_data_data = dev_ctx.Alloc<T>(
+        &attn_dropout_out, attn_dropout_out.numel() * sizeof(T));
+
+    qktv_out.Resize({{bsz, num_head, seq_len, dim_head}});
+    auto *qktv_out_data =
+        dev_ctx.Alloc<T>(&qktv_out, qktv_out.numel() * sizeof(T));
+    fmha_out.Resize({{bsz, seq_len, num_head, dim_head}});
+    auto *fmha_out_data =
+        dev_ctx.Alloc<T>(&fmha_out, fmha_out.numel() * sizeof(T));
+
+    // 4. out_linear
+    auto out_linear_weights = ctx.MultiInput<Tensor>("OutLinearW");
+    auto out_linear_biases = ctx.MultiInput<Tensor>("OutLinearBias");
+    int ring_id = ctx.Attr<int>("ring_id");
+    // (transA, transB, compute_bias) = (false, false, false)
+    auto out_linear_compute = AttnMatMul<T>(
+        dev_ctx, false, false, bsz_seq, dim_embed, hidden_size, false);
+
+    // 5. ln(residual + bias)
+    DropoutParam dropout_param2(true, 0, true, true, 0.0, nullptr, 0);
+    FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
+        dev_ctx, bsz_seq, dim_embed, dropout_param2, epsilon);
+    auto ffn_ln_scales = ctx.MultiInput<Tensor>("FFNLnScale");
+    auto ffn_ln_biases = ctx.MultiInput<Tensor>("FFNLnBias");
+    Tensor bias_dropout_residual_out, dropout_mask_out;
+    T *bias_dropout_residual_out_data = nullptr;
+    if (pre_layer_norm) {
+      bias_dropout_residual_out.Resize({{bsz, seq_len, dim_embed}});
+      bias_dropout_residual_out_data =
+          dev_ctx.Alloc<T>(&bias_dropout_residual_out,
+                           bias_dropout_residual_out.numel() * sizeof(T));
+    }
+    dropout_mask_out.Resize({{bsz, seq_len, dim_embed}});
+    auto *dropout_mask_out_data = dev_ctx.Alloc<uint8_t>(
+        &dropout_mask_out, dropout_mask_out.numel() * sizeof(uint8_t));
+
+    // 6. ffn1 matmul + bias_add + gelu.
+    auto ffn1_weights = ctx.MultiInput<Tensor>("FFN1Weight");
+    auto ffn1_biases = ctx.MultiInput<Tensor>("FFN1Bias");
+    auto ffn1_weight_dim = ffn1_weights[0]->dims();
+
+    int dim_ffn = ffn1_weight_dim[1];
+
+    Tensor ffn1_out;
+    ffn1_out.Resize({{bsz_seq, dim_ffn}});
+    auto *ffn1_out_data =
+        dev_ctx.Alloc<T>(&ffn1_out, ffn1_out.numel() * sizeof(T));
+
+    auto ffn1_linear_bias_gelu = CublasFusedMLP<T>(dev_ctx);
+    const phi::DDim ffn1_input_shape({bsz_seq, dim_embed});
+    ffn1_linear_bias_gelu.Setup(
+        ffn1_input_shape, ffn1_weight_dim, false, false);
+
+    // 8. ffn2 matmul + bias_add + residual.
+    auto ffn2_weights = ctx.MultiInput<Tensor>("FFN2Weight");
+    auto ffn2_biases = ctx.MultiInput<Tensor>("FFN2Bias");
+
+    auto ffn2_linear_bias_residual = CublasFusedMLP<T>(dev_ctx);
+    ffn2_linear_bias_residual.Setup(
+        ffn1_out.dims(), ffn2_weights[0]->dims(), false, false);
+
+    // 9. ffn2 residual bias
+    DropoutParam ffn2_dropout_param(true, 0, true, true, 0.0, nullptr, 0);
+    FusedDropoutLayerNormHelper<T, uint8_t> ffn2_fused_dropout_helper(
+        dev_ctx, bsz_seq, dim_embed, ffn2_dropout_param, epsilon);
+
+    // calc
+    auto *out = ctx.Output<Tensor>("Out");
+    auto *from_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
+    Tensor *from_tensor = out;
+    Tensor tmp_out;
+    tmp_out.Resize({{bsz, seq_len, dim_embed}});
+    auto *tmp_out_data =
+        dev_ctx.Alloc<T>(&tmp_out, tmp_out.numel() * sizeof(T));
+
+    auto *x_data = input_x->data<T>();
+    Tensor *buf0 = nullptr;
+    Tensor *buf1 = nullptr;
+
+    // step0: x --> buf1
+    // step1: buf1 --> buf0
+    // step2: buf0 --> buf1
+    int layers = qkv_weights.size();
+    if (pre_layer_norm) {
+      if (layers & 1) {
+        // odd, set buf1 as out
+        buf0 = &tmp_out;
+        buf1 = out;
+      } else {
+        // even, set buf0 as out
+        buf0 = out;
+        buf1 = &tmp_out;
+      }
+    } else {
+      buf0 = &tmp_out;
+      buf1 = out;
+    }
+
+    for (int i = 0; i < layers; ++i) {
+      // step1. layer_norm
+      if (i == 0 && pre_layer_norm) {
+        auto *ln_scale_data = ln_scales[i]->data<U>();
+        auto *ln_bias_data = ln_biases[i]->data<U>();
+        // TODO(wangxi): can remove mean var in inference
+        ln_compute.ComputeForward(x_data,
+                                  ln_scale_data,
+                                  ln_bias_data,
+                                  buf1->data<T>(),
+                                  ln_mean_data,
+                                  ln_var_data);
+      }
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
+      VLOG(0) << "step1";
+#endif
+
+      // step2. qkv
+      const Tensor *qkv_bias = qkv_biases.size() > 0 ? qkv_biases[i] : nullptr;
+      // NOTE: in the decoder stage, the bias is fused in fmha
+      const Tensor *bias = time_step ? nullptr : qkv_bias;
+      if (!pre_layer_norm && i == 0) {
+        qkv_compute.ComputeForward(
+            qkv_weights[i], input_x, bias, &qkv_out, &qkv_out);
+      } else {
+        qkv_compute.ComputeForward(
+            qkv_weights[i], buf1, bias, &qkv_out, &qkv_out);
+      }
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
+      VLOG(0) << "step2";
+#endif
+
+      // step3. fmha
+      const Tensor *cache_kv = cache_kvs.size() > 0 ? cache_kvs[i] : nullptr;
+      Tensor *cache_kv_out = cache_kv ? cache_kv_outs[i] : nullptr;
+
+      if (time_step) {  // generation decoder stage
+        // [2, batch_size, num_head, max_seq_len, head_size]
+        int max_seq_len = cache_kv->dims()[3];
+        fmha<T>(dev_ctx,
+                qkv_out,
+                *qkv_bias,
+                *src_mask,
+                cache_kv_out,
+                &fmha_out,
+                bsz,
+                max_seq_len,
+                num_head,
+                dim_head,
+                time_step->data<int>()[0],
+                1. / sqrt(dim_head));
+      } else if (cache_kv_out) {  // generation context stage
+        const Tensor *pre_cache_kv_tensor =
+            pre_caches.size() > 0 ? pre_caches[i] : nullptr;
+        Tensor *pre_cache_kv_out_tmp =
+            cache_offset > 0 ? &pre_cache_kv_out : nullptr;
+        Tensor *src_mask_tmp = cache_offset > 0 ? &src_mask_out : nullptr;
+        qkv_bias_add_transpose_split<T>(dev_ctx,
+                                        q_transpose_out_data,
+                                        kv_transpose_out_data,
+                                        qkv_out_data,
+                                        qkv_bias->data<T>(),
+                                        bsz,
+                                        num_head,
+                                        seq_len,
+                                        dim_head,
+                                        compute_bias);
+        fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
+                                                    pre_cache_kv_tensor,
+                                                    src_mask,
+                                                    &q_transpose_out,
+                                                    &kv_transpose_out,
+                                                    pre_cache_kv_out_tmp,
+                                                    &qk_out,
+                                                    src_mask_tmp,
+                                                    &softmax_out,
+                                                    &attn_dropout_mask_out,
+                                                    &attn_dropout_out,
+                                                    &qktv_out,
+                                                    &fmha_out);
+
+        const T *k_ptr = nullptr;
+        const T *v_ptr = nullptr;
+
+        if (cache_offset > 0) {
+          // [2, bsz, num_head, cache_offset + seq_len, head_dim]
+          const T *kv_data = pre_cache_kv_out.data<T>();
+          k_ptr = kv_data;
+          int64_t k_size = bsz * num_head * (seq_len + cache_offset) * dim_head;
+          v_ptr = k_ptr + k_size;
+        } else {
+          // [3, bsz, num_head, seq_len, head_dim]
+          int64_t k_size = bsz * seq_len * num_head * dim_head;
+          const T *q_ptr = q_transpose_out_data;
+          k_ptr = kv_transpose_out_data;
+          v_ptr = k_ptr + k_size;
+        }
+
+        // [2, bsz, num_head, max_seq_len, head_dim]
+        int max_seq_len = cache_kv_out->dims()[3];
+        T *cache_kv_data = cache_kv_out->data<T>();
+        int64_t cache_k_size = bsz * num_head * max_seq_len * dim_head;
+
+        T *cache_k_ptr = cache_kv_data;
+        T *cache_v_ptr = cache_kv_data + cache_k_size;
+
+        const int seq_len_tmp = seq_len + cache_offset;
+        write_cache_kv<T>(dev_ctx,
+                          cache_k_ptr,
+                          cache_v_ptr,
+                          k_ptr,
+                          v_ptr,
+                          bsz,
+                          num_head,
+                          seq_len_tmp,
+                          max_seq_len,
+                          dim_head);
+      } else {  // not generation
+        // TODO(wangxi): can remove dropout in inference
+        qkv_bias_add_transpose_split<T>(dev_ctx,
+                                        q_transpose_out_data,
+                                        kv_transpose_out_data,
+                                        qkv_out_data,
+                                        qkv_bias->data<T>(),
+                                        bsz,
+                                        num_head,
+                                        seq_len,
+                                        dim_head,
+                                        compute_bias);
+        fmha_compute.ComputeForwardWithoutTranspose(qkv_out,
+                                                    cache_kv,
+                                                    src_mask,
+                                                    &q_transpose_out,
+                                                    &kv_transpose_out,
+                                                    cache_kv_out,
+                                                    &qk_out,
+                                                    nullptr,
+                                                    &softmax_out,
+                                                    &attn_dropout_mask_out,
+                                                    &attn_dropout_out,
+                                                    &qktv_out,
+                                                    &fmha_out);
+      }
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
+      VLOG(0) << "step3";
+#endif
+
+      if (pre_layer_norm) {
+        out_linear_compute.ComputeForward(
+            out_linear_weights[i], &fmha_out, nullptr, buf1, nullptr);
+        AllReduce<T>(*buf1, ring_id, buf1->numel(), dev_ctx);
+      } else {
+        out_linear_compute.ComputeForward(
+            out_linear_weights[i], &fmha_out, nullptr, buf0, nullptr);
+        AllReduce<T>(*buf0, ring_id, buf0->numel(), dev_ctx);
+      }
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
+      VLOG(0) << "step4";
+#endif
+
+      // step5. ln(residual + dropout(input + bias))
+      if (pre_layer_norm) {
+        auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
+        auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
+        auto *out_linear_bias_data = out_linear_biases[i]->data<T>();
+
+        // inplace
+        fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
+            dev_ctx,
+            buf1->data<T>(),
+            x_data,
+            out_linear_bias_data,
+            ln_scale_data,
+            ln_bias_data,
+            bias_dropout_residual_out_data,
+            dropout_mask_out_data,
+            buf1->data<T>(),
+            ln_mean_data,
+            ln_var_data);
+      } else {
+        auto *ln_scale_data = ln_scales[i]->data<U>();
+        auto *ln_bias_data = ln_biases[i]->data<U>();
+        auto *out_linear_bias_data = out_linear_biases[i]->data<T>();
+        auto *residual_data = (i == 0 ? x_data : buf1->data<T>());
+        fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
+            dev_ctx,
+            buf0->data<T>(),
+            residual_data,
+            out_linear_bias_data,
+            ln_scale_data,
+            ln_bias_data,
+            buf0->data<T>(),
+            dropout_mask_out_data,
+            buf1->data<T>(),
+            ln_mean_data,
+            ln_var_data);
+      }
+
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
+      VLOG(0) << "step5";
+#endif
+      // step6. ffn1 matmul + bias_add + gelu.
+
+      ffn1_linear_bias_gelu.ComputeForward(
+          buf1, ffn1_weights[i], ffn1_biases[i], nullptr, &ffn1_out, "gelu");
+
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
+      VLOG(0) << "step6";
+#endif
+
+      // step7. ffn2 matmul + bias_add + residual.
+      if (pre_layer_norm) {
+        ffn2_linear_bias_residual.ComputeForward(&ffn1_out,
+                                                 ffn2_weights[i],
+                                                 ffn2_biases[i],
+                                                 &bias_dropout_residual_out,
+                                                 buf1,
+                                                 "none");
+
+      } else {
+        ffn2_linear_bias_residual.ComputeForward(
+            &ffn1_out, ffn2_weights[i], ffn2_biases[i], buf1, buf0, "none");
+      }
+
+      if (pre_layer_norm) {
+        AllReduce<T>(*buf1, ring_id, buf1->numel(), dev_ctx);
+      } else {
+        AllReduce<T>(*buf0, ring_id, buf0->numel(), dev_ctx);
+      }
+
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
+      VLOG(0) << "step7";
+#endif
+
+      // step8. layer norm, or do nothing (the bias_add + residual has already
+      // been fused into CublasFusedMLP).
+      if (pre_layer_norm) {
+        if (i < layers - 1) {
+          auto *ln_scale_data = ln_scales[i + 1]->data<U>();
+          auto *ln_bias_data = ln_biases[i + 1]->data<U>();
+          ffn2_fused_dropout_helper.LayerNorm(dev_ctx,
+                                              buf1->data<T>(),
+                                              ln_scale_data,
+                                              ln_bias_data,
+                                              buf0->data<T>(),
+                                              ln_mean_data,
+                                              ln_var_data);
+        }
+      } else {
+        auto *ln_scale_data = ffn_ln_scales[i]->data<U>();
+        auto *ln_bias_data = ffn_ln_biases[i]->data<U>();
+        ffn2_fused_dropout_helper.LayerNorm(dev_ctx,
+                                            buf0->data<T>(),
+                                            ln_scale_data,
+                                            ln_bias_data,
+                                            buf1->data<T>(),
+                                            ln_mean_data,
+                                            ln_var_data);
+      }
+#ifdef _DEBUG_FUSED_MULTI_TRANSFORMER
+      VLOG(0) << "step8";
+#endif
+      if (pre_layer_norm) {
+        x_data = buf1->data<T>();
+        std::swap(buf0, buf1);
+      }
+    }
+  }
+};
+
+#else
+
 template <typename T>
 class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
  public:
@@ -550,6 +1061,8 @@ class FusedMultiTransformerOpKernel : public framework::OpKernel<T> {
   }
 };
 
+#endif  // CUDA_VERSION >= 11060
+
 }  // namespace operators
 }  // namespace paddle
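The sketch below is not part of the patch; it isolates the per-layer FFN call sequence that the kernel above now performs, assuming 2-D activations and hypothetical tensor names (ln_out, ffn1_out, ffn2_out, residual). FFN1 fuses matmul + bias + GELU into one cublasLt call, and FFN2 fuses matmul + bias while adding the residual through beta = 1.0, using the CublasFusedMLP helper that the header change below introduces.

// Illustrative sketch only (not from the patch): per-layer FFN with the
// cublasLt-fused helper. All tensor names are placeholders.
#include "paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h"

namespace paddle {
namespace operators {

template <typename T>
void FfnWithCublasFusedMLP(const phi::GPUContext &dev_ctx,
                           const phi::DenseTensor &ln_out,     // [bsz_seq, dim_embed]
                           const phi::DenseTensor &ffn1_w,     // [dim_embed, dim_ffn]
                           const phi::DenseTensor &ffn1_bias,  // [dim_ffn]
                           const phi::DenseTensor &ffn2_w,     // [dim_ffn, dim_embed]
                           const phi::DenseTensor &ffn2_bias,  // [dim_embed]
                           phi::DenseTensor *residual,         // [bsz_seq, dim_embed]
                           phi::DenseTensor *ffn1_out,         // [bsz_seq, dim_ffn]
                           phi::DenseTensor *ffn2_out) {       // [bsz_seq, dim_embed]
  // FFN1: one cublasLt call with epilogue CUBLASLT_EPILOGUE_GELU_BIAS
  // (matmul + bias_add + gelu); no residual, so beta = 0.
  auto ffn1 = CublasFusedMLP<T>(dev_ctx);
  ffn1.Setup(ln_out.dims(), ffn1_w.dims(), /*trans_x=*/false, /*trans_w=*/false);
  ffn1.ComputeForward(&ln_out, &ffn1_w, &ffn1_bias, nullptr, ffn1_out, "gelu");

  // FFN2: epilogue CUBLASLT_EPILOGUE_BIAS (matmul + bias_add); the residual
  // is accumulated by passing it as the C operand with beta = 1.0.
  auto ffn2 = CublasFusedMLP<T>(dev_ctx);
  ffn2.Setup(ffn1_out->dims(), ffn2_w.dims(), /*trans_x=*/false, /*trans_w=*/false);
  ffn2.ComputeForward(ffn1_out, &ffn2_w, &ffn2_bias, residual, ffn2_out, "none");
}

}  // namespace operators
}  // namespace paddle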
diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
index 79fc561698989babed35887840362d51f4343603..c36ee69723e45261f30efd2fe6bae3e719bbdd07 100644
--- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
+++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
@@ -26,7 +26,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/attn_gemm.h"
 #include "paddle/fluid/operators/fused/fmha_ref.h"
 #include "paddle/fluid/operators/fused/fused_dropout_helper.h"
+#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
+#include "paddle/fluid/platform/dynload/cublasLt.h"
 #include "paddle/phi/api/include/tensor.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -37,6 +39,8 @@ limitations under the License. */
 #include "paddle/fluid/platform/device/gpu/nccl_helper.h"
 #endif
 
+DECLARE_bool(gemm_use_half_precision_compute_type);
+
 namespace paddle {
 namespace operators {
 
@@ -1336,10 +1340,10 @@ inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) {
   constexpr int kBlockSize = 128;
   constexpr int kNumWaves = 16;
-  const int device_id = paddle::platform::GetCurrentDeviceId();
-  const int sm_count = paddle::platform::GetGPUMultiProcessors(device_id);
+  const int device_id = phi::backends::gpu::GetCurrentDeviceId();
+  const int sm_count = phi::backends::gpu::GetGPUMultiProcessors(device_id);
   const int max_thread_per_multiprocessor =
-      paddle::platform::GetGPUMultiProcessors(device_id);
+      phi::backends::gpu::GetGPUMultiProcessors(device_id);
   *num_blocks =
       std::max<int>(1,
@@ -1400,6 +1404,249 @@ void qkv_bias_add_transpose_split(const phi::GPUContext &dev_ctx,
   }
 }
 
+#if CUDA_VERSION >= 11060
+// Only used in inference.
+template <typename T>
+class CublasFusedMLP {
+ public:
+  // (m, n, k) = bsz_seq, hidden_feature, in_feature
+  explicit CublasFusedMLP(const phi::GPUContext &dev_ctx) : dev_ctx_(dev_ctx) {
+    // Set the math type.
+    cudaDataType_t mat_type = CUDA_R_32F;
+    cudaDataType_t scale_type = CUDA_R_32F;
+    cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
+
+    if (std::is_same<T, platform::float16>::value) {
+      mat_type = CUDA_R_16F;
+      if (FLAGS_gemm_use_half_precision_compute_type) {
+        compute_type = CUBLAS_COMPUTE_16F;
+        scale_type = CUDA_R_16F;
+      }
+    }
+    if (std::is_same<T, platform::bfloat16>::value) {
+      mat_type = CUDA_R_16BF;
+    }
+    if (std::is_same<T, double>::value) {
+      mat_type = CUDA_R_64F;
+      scale_type = CUDA_R_64F;
+      compute_type = CUBLAS_COMPUTE_64F;
+    }
+
+    // Create the descriptors with placeholder shapes; the real shapes are set
+    // later in Setup().
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
+        &operation_desc_, compute_type, scale_type));
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
+        &x_desc_, mat_type, 1, 1, 1));
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
+        &w_desc_, mat_type, 1, 1, 1));
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
+        &out_desc_, mat_type, 1, 1, 1));
+  }
+
+  ~CublasFusedMLP() {
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cublasLtMatmulDescDestroy(operation_desc_));
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cublasLtMatrixLayoutDestroy(x_desc_));
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cublasLtMatrixLayoutDestroy(w_desc_));
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cublasLtMatrixLayoutDestroy(out_desc_));
+  }
+
+  // Set up the GEMM problem from the input and weight tensor shapes.
+  void Setup(const phi::DDim &x_shape,
+             const phi::DDim &w_shape,
+             bool trans_x,
+             bool trans_w) {
+    int64_t M = trans_x ? x_shape[1] : x_shape[0];
+    int64_t K = trans_w ? w_shape[1] : w_shape[0];
+    int64_t N = trans_w ? w_shape[0] : w_shape[1];
+
+    cublasOperation_t cublas_transA = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
+    cublasOperation_t cublas_transB = trans_w ? CUBLAS_OP_T : CUBLAS_OP_N;
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cublasLtMatmulDescSetAttribute(
+            operation_desc_,
+            CUBLASLT_MATMUL_DESC_TRANSB,
+            &cublas_transA,
+            sizeof(cublas_transA)));
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cublasLtMatmulDescSetAttribute(
+            operation_desc_,
+            CUBLASLT_MATMUL_DESC_TRANSA,
+            &cublas_transB,
+            sizeof(cublas_transB)));
+
+    /*
+      cuBLAS uses column-major layout, so x(M, K) * w(K, N) = out(M, N) is
+      computed as w_t(N, K) * x_t(K, M) = out_t(N, M).
+    */
+    SetCublasMatrixLayout_(x_desc_, cublas_transA, K, M);
+    SetCublasMatrixLayout_(w_desc_, cublas_transB, N, K);
+    SetCublasMatrixLayout_(out_desc_, CUBLAS_OP_N, N, M);
+  }
+
+  void ComputeForward(const phi::DenseTensor *input,
+                      const phi::DenseTensor *weight,
+                      const phi::DenseTensor *bias,
+                      phi::DenseTensor *residual,
+                      phi::DenseTensor *output,
+                      const std::string &activation) {
+    // here: (transa, transb): nt, input * weight, (M x K) * (K x N).
+    cublasLtHandle_t lt_handle = dev_ctx_.cublaslt_handle();
+    size_t workspace_size = static_cast<size_t>(16) * 1024 * 1024;
+    cudaStream_t stream = dev_ctx_.stream();
+    memory::allocation::AllocationPtr workspace =
+        memory::Alloc(dev_ctx_.GetPlace(),
+                      workspace_size,
+                      phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
+
+    const bool add_residual = (residual == nullptr) ? false : true;
+    const bool add_bias = (bias == nullptr) ? false : true;
+    if (add_bias) {
+      SetCublasBiasPtr_(bias);
+    }
+
+    // Set the cublasLt epilogue.
+    cublasLtEpilogue_t epiloque_func = GetEpilogueType_(activation, add_bias);
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cublasLtMatmulDescSetAttribute(
+            operation_desc_,
+            CUBLASLT_MATMUL_DESC_EPILOGUE,
+            &epiloque_func,
+            sizeof(epiloque_func)));
+
+    const auto *x_data = input->data<T>();
+    const auto *w_data = weight->data<T>();
+    auto *residual_data =
+        add_residual ? residual->data<T>() : output->data<T>();
+    auto *out_data = output->data<T>();
+
+    // If add_residual is true, compute result + 1.0 * residual; otherwise
+    // compute result + 0.0 * out.
+    double alpha64 = 1.0, beta64 = add_residual ? 1.0 : 0.0;
+    float alpha32 = 1.0f, beta32 = add_residual ? 1.0f : 0.0f;
+    void *alpha = nullptr, *beta = nullptr;
+    if (std::is_same<T, double>::value) {
+      alpha = &alpha64;
+      beta = &beta64;
+    } else {
+      alpha = &alpha32;
+      beta = &beta32;
+    }
+
+    auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
+                                                              operation_desc_,
+                                                              w_desc_,
+                                                              x_desc_,
+                                                              out_desc_,
+                                                              alpha,
+                                                              beta,
+                                                              w_data,
+                                                              x_data,
+                                                              out_data,
+                                                              stream,
+                                                              workspace->ptr(),
+                                                              workspace_size);
+
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cublasLtMatmul(lt_handle,
+                                          operation_desc_,
+                                          alpha,
+                                          w_data,
+                                          w_desc_,
+                                          x_data,
+                                          x_desc_,
+                                          beta,
+                                          residual_data,
+                                          out_desc_,
+                                          out_data,
+                                          out_desc_,
+                                          algo /*algo*/,
+                                          workspace->ptr() /*workspace*/,
+                                          workspace_size,
+                                          stream));
+  }
+
+ private:
+  static cublasLtEpilogue_t GetEpilogueType_(const std::string &activation,
+                                             const bool add_bias) {
+    if (activation == "relu") {
+      if (add_bias) {
+        return CUBLASLT_EPILOGUE_RELU_BIAS;
+      } else {
+        return CUBLASLT_EPILOGUE_RELU;
+      }
+    } else if (activation == "gelu") {
+      if (add_bias) {
+        return CUBLASLT_EPILOGUE_GELU_BIAS;
+      } else {
+        return CUBLASLT_EPILOGUE_GELU;
+      }
+    } else if (activation == "none") {
+      if (add_bias) {
+        return CUBLASLT_EPILOGUE_BIAS;
+      } else {
+        return CUBLASLT_EPILOGUE_DEFAULT;
+      }
+    } else {
+      PADDLE_ENFORCE_EQ(
+          true,
+          false,
+          platform::errors::InvalidArgument(
+              "The activation attribute of fused_gemm_epilogue op should be"
+              " one of {\"none\", \"relu\", \"gelu\"}. But received %s.",
+              activation));
+    }
+  }
+ "But received activation=%s.", + activation)); + } + } + + void SetCublasMatrixLayout_(cublasLtMatrixLayout_t layout_desc, + cublasOperation_t cublas_trans, + const size_t cublas_m, + const size_t cublas_n) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutSetAttribute( + layout_desc, + CUBLASLT_MATRIX_LAYOUT_ROWS, + cublas_trans == CUBLAS_OP_N ? &cublas_m : &cublas_n, + sizeof(cublas_m))); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutSetAttribute( + layout_desc, + CUBLASLT_MATRIX_LAYOUT_COLS, + cublas_trans == CUBLAS_OP_N ? &cublas_n : &cublas_m, + sizeof(cublas_m))); + const size_t cublas_ld = cublas_trans == CUBLAS_OP_N ? cublas_m : cublas_n; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutSetAttribute( + layout_desc, + CUBLASLT_MATRIX_LAYOUT_LD, + &cublas_ld, + sizeof(cublas_ld))); + } + + void SetCublasBiasPtr_(const phi::DenseTensor *bias) { + const T *bias_data = bias->data(); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc_, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, + &bias_data, + sizeof(bias_data))); + } + + const phi::GPUContext &dev_ctx_; + cublasLtMatmulDesc_t operation_desc_; + cublasLtMatrixLayout_t x_desc_; + cublasLtMatrixLayout_t w_desc_; + cublasLtMatrixLayout_t out_desc_; +}; + +#endif // PADDLE_FLUID_OPERATORS_FUSED_FUSED_MULTI_TRANSFORMER_OP_CU_H_ + } // namespace } // namespace operators