diff --git a/paddle/fluid/operators/fused/attn_gemm.h b/paddle/fluid/operators/fused/attn_gemm.h index c0157c8cb04dda07850960e6b63aae4a65a9c12d..28d9454c2babc118b99f7ee016a8b3da34cb811a 100644 --- a/paddle/fluid/operators/fused/attn_gemm.h +++ b/paddle/fluid/operators/fused/attn_gemm.h @@ -14,12 +14,13 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h" +#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/platform/float16.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/primitive/kernel_primitives.h" namespace paddle { namespace operators { @@ -44,13 +45,43 @@ class AttnMatMul { input_size_(input_size), compute_bias_(compute_bias) {} - ~AttnMatMul() {} - void ComputeForward(const phi::DenseTensor* weight, const phi::DenseTensor* input, const phi::DenseTensor* bias, phi::DenseTensor* output, - phi::DenseTensor* bias_out) { + phi::DenseTensor* bias_out, + bool fused = false) { + VLOG(6) << "input.shape={" << input->dims() << "}, weight.shape={" + << weight->dims() << "}, output.shape={" << output->dims() + << "}, batch_size=" << bsz_seq_ << ", output_size=" << output_size_ + << ", input_size=" << input_size_ << ", transA=" << transA_ + << ", transB=" << transB_ << ", compute_bias=" << compute_bias_ + << ", fused=" << fused; + +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 + if (compute_bias_ && fused) { + PADDLE_ENFORCE_EQ( + !output || output == bias_out, + true, + phi::errors::InvalidArgument( + "The output (= input * weight) is expected to be nullptr or the " + "same as bias_out when fused is true.")); + ComputeFusedGemmEpilogueForward(dev_ctx_, + input, + weight, + bias, + bsz_seq_, // M + output_size_, // N + input_size_, // K + transA_, + transB_, + "none", + bias_out, + nullptr); + return; + } +#endif + // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major. // here: (transa, transb): nt, input * weight. CBLAS_TRANSPOSE transA = transA_ ? CblasTrans : CblasNoTrans; @@ -85,7 +116,29 @@ class AttnMatMul { phi::DenseTensor* d_input, phi::DenseTensor* d_weight, phi::DenseTensor* d_bias, - bool use_addto = false) { + bool use_addto = false, + bool fused = false) { +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 + if (compute_bias_ && fused) { + ComputeFusedGemmEpilogueBackward(dev_ctx_, + d_output, + input, + weight, + nullptr, + bsz_seq_, // M + output_size_, // N + input_size_, // K + transA_, + transB_, + "none", + d_input, + d_weight, + d_bias, + use_addto); + return; + } +#endif + T alpha = static_cast(1.0); T beta_dA = use_addto ? static_cast(1.0) : static_cast(0.0); T beta_dB = static_cast(0.0); diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index fd8f6c6014a6a63f29589b8c149b6f1ec7db901a..1d83c7a62b1d94031c0b6bbde81dd5504a056068 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h" #include "paddle/phi/kernels/funcs/dropout_impl.cu.h" diff --git a/paddle/fluid/operators/fused/fused_gate_attention_op.cu b/paddle/fluid/operators/fused/fused_gate_attention_op.cu index b4fb69a58226ca1cf304e8f633f83db1e33b8899..ca7b70b220f380aa5df98b0fea3a49fd02bbb907 100644 --- a/paddle/fluid/operators/fused/fused_gate_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_gate_attention_op.cu @@ -209,7 +209,8 @@ void ComputeGatingLinearForward(const framework::ExecutionContext &ctx, const GateAttentionConfig &config, const phi::DenseTensor *query, const phi::DenseTensor *fmha_out, - phi::DenseTensor *gate_out) { + phi::DenseTensor *gate_bias_out, + bool use_fused_matmul_bias) { auto *gate_weight = ctx.Input("GateWeight"); auto *gate_bias = ctx.Input("GateBias"); @@ -220,14 +221,18 @@ void ComputeGatingLinearForward(const framework::ExecutionContext &ctx, int m = config.batch_size * config.seq_len_m * config.seq_len_r; int n = config.num_heads * config.head_dim; int k = config.q_dim; - auto gate_attn_compute = + auto gate_linear = AttnMatMul(ctx.cuda_device_context(), false, false, m, n, k, true); - gate_attn_compute.ComputeForward( - gate_weight, query, gate_bias, gate_out, gate_out); + gate_linear.ComputeForward(gate_weight, + query, + gate_bias, + gate_bias_out, + gate_bias_out, + use_fused_matmul_bias); // gate_out = sigmoid(gate_out) * fmha_out - std::vector ins = {gate_out, fmha_out}; - std::vector outs = {gate_out}; + std::vector ins = {gate_bias_out, fmha_out}; + std::vector outs = {gate_bias_out}; phi::funcs::ElementwiseKernel( ctx.cuda_device_context(), ins, &outs, SigmoidMultiplyFunctor()); } @@ -239,10 +244,12 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, const phi::DenseTensor *fmha_out, const phi::DenseTensor *gate_out_grad, phi::DenseTensor *query_grad, - phi::DenseTensor *fmha_out_grad) { + phi::DenseTensor *fmha_out_grad, + bool use_fused_matmul_bias) { const auto *gate_weight = ctx.Input("GateWeight"); const auto *gate_bias = ctx.Input("GateBias"); auto &dev_ctx = ctx.template device_context(); + // Re-compute gate_bias_out phi::DenseTensor gate_bias_out; gate_bias_out.Resize(config.gate_out_dims); @@ -251,10 +258,14 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, int m = config.batch_size * config.seq_len_m * config.seq_len_r; int n = config.num_heads * config.head_dim; int k = config.q_dim; - auto gate_attn_compute = + auto gate_linear = AttnMatMul(ctx.cuda_device_context(), false, false, m, n, k, true); - gate_attn_compute.ComputeForward( - gate_weight, query, gate_bias, &gate_bias_out, &gate_bias_out); + gate_linear.ComputeForward(gate_weight, + query, + gate_bias, + &gate_bias_out, + &gate_bias_out, + use_fused_matmul_bias); // Gradient of sigmoid(gate_bias_out) * fmha_out // Compute inplace and save gate_bias_out_grad to gate_bias_out. 
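Note: the gating hunks above and below implement gate_out = sigmoid(Linear(query)) * fmha_out and its gradient, delegating the elementwise part to SigmoidMultiplyFunctor and its gradient counterpart. A minimal sketch of the math those functors are expected to perform (illustrative only; the helper names SigmoidMultiplyForward/SigmoidMultiplyBackward below are not part of the Paddle sources):

#include <cmath>

// Forward:  gate_out = sigmoid(gate_bias_out) * fmha_out
template <typename T>
T SigmoidMultiplyForward(T gate_bias_out, T fmha_out) {
  T s = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-gate_bias_out));
  return s * fmha_out;
}

// Backward: with s = sigmoid(gate_bias_out) and incoming gradient d_gate_out,
//   d_gate_bias_out = d_gate_out * fmha_out * s * (1 - s)
//   d_fmha_out      = d_gate_out * s
template <typename T>
void SigmoidMultiplyBackward(T gate_bias_out, T fmha_out, T d_gate_out,
                             T* d_gate_bias_out, T* d_fmha_out) {
  T s = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-gate_bias_out));
  *d_gate_bias_out = d_gate_out * fmha_out * s * (static_cast<T>(1) - s);
  *d_fmha_out = d_gate_out * s;
}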
@@ -272,19 +283,22 @@ void ComputeGatingLinearBackward(const framework::ExecutionContext &ctx, dev_ctx.Alloc(gate_weight_grad, gate_weight_grad->numel() * sizeof(T)); dev_ctx.Alloc(gate_bias_grad, gate_bias_grad->numel() * sizeof(T)); - gate_attn_compute.ComputeBackward(query, - gate_weight, - &gate_bias_out, - query_grad, - gate_weight_grad, - gate_bias_grad); + gate_linear.ComputeBackward(query, + gate_weight, + &gate_bias_out, + query_grad, + gate_weight_grad, + gate_bias_grad, + false, + use_fused_matmul_bias); } template void ComputeOutputLinearForward(const framework::ExecutionContext &ctx, const GateAttentionConfig &config, const phi::DenseTensor *fmha_or_gate_out, - phi::DenseTensor *out) { + phi::DenseTensor *out, + bool use_fused_matmul_bias) { const auto *out_linear_weight = ctx.Input("OutLinearWeight"); const auto *out_linear_bias = ctx.Input("OutLinearBias"); @@ -293,17 +307,22 @@ void ComputeOutputLinearForward(const framework::ExecutionContext &ctx, int m = config.batch_size * config.seq_len_m * config.seq_len_r; int n = config.q_dim; int k = config.num_heads * config.head_dim; - auto out_linear_compute = + auto out_linear = AttnMatMul(ctx.cuda_device_context(), false, false, m, n, k, true); - out_linear_compute.ComputeForward( - out_linear_weight, fmha_or_gate_out, out_linear_bias, out, out); + out_linear.ComputeForward(out_linear_weight, + fmha_or_gate_out, + out_linear_bias, + out, + out, + use_fused_matmul_bias); } template void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx, const GateAttentionGradConfig &config, const phi::DenseTensor *input, - phi::DenseTensor *input_grad) { + phi::DenseTensor *input_grad, + bool use_fused_matmul_bias) { auto &dev_ctx = ctx.template device_context(); const auto *out_grad = ctx.Input(framework::GradVarName("Out")); @@ -323,14 +342,16 @@ void ComputeOutputLinearBackward(const framework::ExecutionContext &ctx, int m = config.batch_size * config.seq_len_m * config.seq_len_r; int n = config.q_dim; int k = config.num_heads * config.head_dim; - auto out_linear_compute = + auto out_linear = AttnMatMul(ctx.cuda_device_context(), false, false, m, n, k, true); - out_linear_compute.ComputeBackward(input, - out_linear_weight, - out_grad, - input_grad, - out_linear_weight_grad, - out_linear_bias_grad); + out_linear.ComputeBackward(input, + out_linear_weight, + out_grad, + input_grad, + out_linear_weight_grad, + out_linear_bias_grad, + false, + use_fused_matmul_bias); } template @@ -358,6 +379,7 @@ class FusedGateAttentionOpKernel : public framework::OpKernel { const bool merge_qkv = ctx.Attr("merge_qkv"); const bool has_gating = ctx.Attr("has_gating"); + bool use_fused_matmul_bias = true; auto &dev_ctx = ctx.template device_context(); AllocWithDebugInfo(dev_ctx, "softmax_out", softmax_out); AllocWithDebugInfo(dev_ctx, "fmha_out", fmha_out); @@ -413,12 +435,14 @@ class FusedGateAttentionOpKernel : public framework::OpKernel { // 3. Gating Linear if (has_gating) { - ComputeGatingLinearForward(ctx, config, query, fmha_out, gate_out); + ComputeGatingLinearForward( + ctx, config, query, fmha_out, gate_out, use_fused_matmul_bias); } // 4. Output Linear phi::DenseTensor *fmha_or_gate_out = has_gating ? 
gate_out : fmha_out; - ComputeOutputLinearForward(ctx, config, fmha_or_gate_out, out); + ComputeOutputLinearForward( + ctx, config, fmha_or_gate_out, out, use_fused_matmul_bias); } }; @@ -454,6 +478,7 @@ class FusedGateAttentionGradKernel : public framework::OpKernel { bool has_gating = ctx.Attr("has_gating"); bool merge_qkv = ctx.Attr("merge_qkv"); + bool use_fused_matmul_bias = true; auto &dev_ctx = ctx.template device_context(); AllocWithDebugInfo(dev_ctx, "query_grad", query_grad); @@ -468,7 +493,8 @@ class FusedGateAttentionGradKernel : public framework::OpKernel { phi::DenseTensor gate_out_grad; gate_out_grad.Resize(config.gate_out_dims); AllocWithDebugInfo(dev_ctx, "gate_out_grad", &gate_out_grad); - ComputeOutputLinearBackward(ctx, config, gate_out, &gate_out_grad); + ComputeOutputLinearBackward( + ctx, config, gate_out, &gate_out_grad, use_fused_matmul_bias); // 2. Gradient of Gating Linear // Forward: gate_out = Sigmoid(Linear(fmha_out)) * fmha_out @@ -478,10 +504,12 @@ class FusedGateAttentionGradKernel : public framework::OpKernel { fmha_out, &gate_out_grad, query_grad, - &fmha_out_grad); + &fmha_out_grad, + use_fused_matmul_bias); } else { // 1. Gradient of Output Linear: out = Linear(fmha_grad) - ComputeOutputLinearBackward(ctx, config, fmha_out, &fmha_out_grad); + ComputeOutputLinearBackward( + ctx, config, fmha_out, &fmha_out_grad, use_fused_matmul_bias); } // 3. Gradient of FMHA diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc index 187eb4fc07ea2f3dbc422ab8c618190a4c75bfec..fc55b68fafeb6d68752ce6628fe2b3454e56719c 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc @@ -14,7 +14,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" - #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu index 0964a00cb23350cba9f457bbc8658d05892f386d..a6a136391de2514cadaf7ce9b06788e97929d3b9 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu @@ -16,14 +16,14 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/framework/scope_guard.h" #include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { +#if CUDA_VERSION >= 11060 + template class FusedGemmEpilogueKernel : public framework::OpKernel { public: @@ -42,294 +42,36 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { bool trans_y = ctx.Attr("trans_y"); std::string activation = ctx.Attr("activation"); - VLOG(10) << "trans_x = " << trans_x << " , trans_y = " << trans_y - << " , activation = " << activation; - bool enable_auxiliary = reserve_space == nullptr ? false : true; - dev_ctx.Alloc(out, out->numel() * sizeof(T)); - auto* out_data = out->data(); + // (M * K) * (K * N) auto x_mat_dims = phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1); - // (M * K) * (K * N) int64_t M = trans_x ? 
x_mat_dims[1] : x_mat_dims[0]; int64_t K = trans_y ? y->dims()[1] : y->dims()[0]; int64_t N = trans_y ? y->dims()[0] : y->dims()[1]; - cudaDataType_t mat_type = CUDA_R_32F; - cudaDataType_t scale_type = CUDA_R_32F; - cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; - if (std::is_same::value) { - mat_type = CUDA_R_16F; - } - if (std::is_same::value) { - mat_type = CUDA_R_16BF; - } - if (std::is_same::value) { - mat_type = CUDA_R_64F; - scale_type = CUDA_R_64F; - compute_type = CUBLAS_COMPUTE_64F; - } - - cublasLtMatmulDesc_t operation_desc = NULL; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( - &operation_desc, compute_type, scale_type)); - cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &transx, - sizeof(transx))); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc, - CUBLASLT_MATMUL_DESC_TRANSA, - &transy, - sizeof(transy))); - - cublasLtEpilogue_t epiloque_func = - get_epilogue_type_(activation, enable_auxiliary); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epiloque_func, - sizeof(epiloque_func))); - const T* bias_data = bias->data(); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &bias_data, - sizeof(bias_data))); - - if (enable_auxiliary && activation != "none") { - // Note (Ming Huang): The initialization of ReseveSpace is happened in the - // dev_ctx.Alloc. Therefore, we set real date type up here. - if (activation == "relu") { - paddle::experimental::DataType rs_type = - paddle::experimental::DataType::BOOL; - size_t reserve_space_size = - phi::product(reserve_space->dims()) * SizeOf(rs_type); - dev_ctx.Alloc(reserve_space, rs_type, reserve_space_size); - } else { - size_t reserve_space_size = - phi::product(reserve_space->dims()) * sizeof(T); - dev_ctx.Alloc(reserve_space, reserve_space_size); - } - - void* aux_data = reserve_space->data(); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, - &aux_data, - sizeof(aux_data))); - int64_t aux_ld = N; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, - &aux_ld, - sizeof(aux_ld))); - } - - cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL; - if (trans_x) - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &x_desc, mat_type, M, K, M)); - else - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &x_desc, mat_type, K, M, K)); - if (trans_y) - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &y_desc, mat_type, K, N, K)); - else - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &y_desc, mat_type, N, K, N)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &out_desc, mat_type, N, M, N)); - - cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); - // NOTE(zengjinle): I do not know whether the 4MB workspace size is - // "enough". I just followed the settings from the NVIDIA MLPerf BERT code. 
- size_t workspace_size = static_cast(4) * 1024 * 1024; - cudaStream_t stream = dev_ctx.stream(); - memory::allocation::AllocationPtr workspace = memory::Alloc( - dev_ctx.GetPlace(), - workspace_size, - phi::Stream(reinterpret_cast(dev_ctx.stream()))); - - double alpha64 = 1.0, beta64 = 0.0; - float alpha32 = 1.0f, beta32 = 0.0f; - void *alpha = nullptr, *beta = nullptr; - if (std::is_same::value) { - alpha = &alpha64; - beta = &beta64; - } else { - alpha = &alpha32; - beta = &beta32; - } - - const auto* y_data = y->data(); - const auto* x_data = x->data(); - - auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle, - operation_desc, - y_desc, - x_desc, - out_desc, - alpha, - beta, - y_data, - x_data, - out_data, - stream, - workspace->ptr(), - workspace_size); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmul(lt_handle, - operation_desc, - alpha, - y_data, - y_desc, - x_data, - x_desc, - beta, - out_data, - out_desc, - out_data, - out_desc, - algo, - workspace->ptr(), - workspace_size, - stream)); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescDestroy(operation_desc)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(y_desc)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(x_desc)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(out_desc)); - } - - private: - static cublasLtEpilogue_t get_epilogue_type_(const std::string& activation, - bool enable_auxiliary) { - if (activation == "relu") { - return enable_auxiliary ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS - : CUBLASLT_EPILOGUE_RELU_BIAS; - } else if (activation == "gelu") { - return enable_auxiliary ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS - : CUBLASLT_EPILOGUE_GELU_BIAS; - } else if (activation == "none") { - return CUBLASLT_EPILOGUE_BIAS; - } else { - PADDLE_ENFORCE_EQ( - true, - false, - platform::errors::InvalidArgument( - "The activation attribute of fused_gemm_epilogue op should be" - " one of {\"none\", \"relu\", \"gelu\"}. But received %s." 
- "But received activation=%s.", - activation)); - } + ComputeFusedGemmEpilogueForward(dev_ctx, + x, + y, + bias, + M, + N, + K, + trans_x, + trans_y, + activation, + out, + reserve_space); } }; -enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 }; - -template -struct FusedGEMMGradTrait; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradB = FusedGEMMGradInType::kDY; - static constexpr auto kXGradATrans = false; - static constexpr auto kXGradBTrans = true; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDX; - static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradATrans = true; - static constexpr auto kYGradBTrans = false; -}; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDY; - static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradATrans = false; - static constexpr auto kXGradBTrans = true; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDX; - static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradATrans = false; - static constexpr auto kYGradBTrans = false; -}; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradB = FusedGEMMGradInType::kDY; - static constexpr auto kXGradATrans = false; - static constexpr auto kXGradBTrans = false; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradB = FusedGEMMGradInType::kDX; - static constexpr auto kYGradATrans = true; - static constexpr auto kYGradBTrans = false; -}; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDY; - static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradATrans = true; - static constexpr auto kXGradBTrans = true; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradB = FusedGEMMGradInType::kDX; - static constexpr auto kYGradATrans = true; - static constexpr auto kYGradBTrans = true; -}; - -static constexpr auto BoolToCuBlasEnum(bool transpose) { - return transpose ? CUBLAS_OP_T : CUBLAS_OP_N; -} - template class FusedGemmEpilogueGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - bool transpose_x = ctx.Attr("trans_x"); - bool transpose_y = ctx.Attr("trans_y"); - - if (transpose_x) { - if (transpose_y) { - ComputeImpl(ctx); - } else { - ComputeImpl(ctx); - } - } else { - if (transpose_y) { - ComputeImpl(ctx); - } else { - ComputeImpl(ctx); - } - } - } - - private: - template - static void ComputeImpl(const framework::ExecutionContext& ctx) { - using Trait = FusedGEMMGradTrait; auto& dev_ctx = ctx.template device_context(); + const phi::DenseTensor* dout = ctx.Input("DOut"); const phi::DenseTensor* x = ctx.Input("X"); const phi::DenseTensor* y = ctx.Input("Y"); @@ -342,352 +84,33 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { std::string activation_grad = ctx.Attr("activation_grad"); - VLOG(10) << "trans_x = " << TransX << " , trans_y = " << TransY - << " , activation_grad = " << activation_grad; - - auto x_mat_dims = - phi::flatten_to_2d(x->dims(), TransX ? 1 : x->dims().size() - 1); + bool trans_x = ctx.Attr("trans_x"); + bool trans_y = ctx.Attr("trans_y"); // (M * K) * (K * N) - int64_t M = TransX ? 
x_mat_dims[1] : x_mat_dims[0]; - int64_t K = TransY ? y->dims()[1] : y->dims()[0]; - int64_t N = TransY ? y->dims()[0] : y->dims()[1]; - - VLOG(10) << "M = " << M << " , K = " << K << " , N = " << N; - - cudaDataType_t mat_type = CUDA_R_32F; - cudaDataType_t scale_type = CUDA_R_32F; - cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; - if (std::is_same::value) { - mat_type = CUDA_R_16F; - } - if (std::is_same::value) { - mat_type = CUDA_R_16BF; - } - if (std::is_same::value) { - mat_type = CUDA_R_64F; - scale_type = CUDA_R_64F; - compute_type = CUBLAS_COMPUTE_64F; - } - - cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); - // NOTE(zengjinle): I do not know whether the 4MB workspace size is - // "enough". I just followed the settings from the NVIDIA MLPerf BERT code. - size_t workspace_size = static_cast(4) * 1024 * 1024; - const cublasLtMatmulAlgo_t* algo = nullptr; - cudaStream_t stream = dev_ctx.stream(); - - double alpha64 = 1.0, beta64 = 0.0; - float alpha32 = 1.0f, beta32 = 0.0f; - void *alpha = nullptr, *beta = nullptr; - if (std::is_same::value) { - alpha = &alpha64; - beta = &beta64; - } else { - alpha = &alpha32; - beta = &beta32; - } - - cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr; - cublasLtMatrixLayout_t x_desc = nullptr, x_trans_desc = nullptr; - cublasLtMatrixLayout_t y_desc = nullptr, y_trans_desc = nullptr; - cublasLtMatrixLayout_t dx_desc = nullptr, dy_desc = nullptr; - cublasLtMatmulDesc_t dx_operation_desc = nullptr, - dy_operation_desc = nullptr; - - DEFINE_PADDLE_SCOPE_GUARD([&] { - auto descs = {dout_desc, - dout_trans_desc, - x_desc, - x_trans_desc, - y_desc, - y_trans_desc, - dx_desc, - dy_desc}; - for (auto desc : descs) { - if (desc) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(desc)); - } - } - - if (dx_operation_desc) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc)); - } - - if (dy_operation_desc) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc)); - } - }); - - auto x_row = TransX ? K : M; - auto x_col = TransX ? M : K; - auto y_row = TransY ? N : K; - auto y_col = TransY ? K : N; - auto z_row = TransX ? N : M; - auto z_col = TransX ? M : N; - - // dx = func(dout, y) - if (dx) { - constexpr auto kXGradAIsDZ = (Trait::kXGradA == FusedGEMMGradInType::kDZ); - cublasLtMatrixLayout_t *dx_dout_desc, *dx_y_desc; - - if (TransX) { - dx_dout_desc = &dout_trans_desc; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutCreate( - dx_dout_desc, mat_type, z_row, z_col, z_row)); - } else { - dx_dout_desc = &dout_desc; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutCreate( - dx_dout_desc, mat_type, z_col, z_row, z_col)); - } - - dx_y_desc = &y_trans_desc; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - dx_y_desc, mat_type, y_col, y_row, y_col)); - - auto& a_desc = kXGradAIsDZ ? (*dx_dout_desc) : (*dx_y_desc); - auto& b_desc = kXGradAIsDZ ? 
(*dx_y_desc) : (*dx_dout_desc); - auto a_trans = BoolToCuBlasEnum(Trait::kXGradATrans); - auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans); - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &dx_desc, mat_type, x_col, x_row, x_col)); - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( - &dx_operation_desc, compute_type, scale_type)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &a_trans, - sizeof(a_trans))); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, - CUBLASLT_MATMUL_DESC_TRANSA, - &b_trans, - sizeof(b_trans))); - - cublasLtEpilogue_t epiloque_func_for_dx = - get_epilogue_type_(activation_grad); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epiloque_func_for_dx, - sizeof(epiloque_func_for_dx))); - - if (activation_grad != "none") { - auto* aux_data = reserve_space->data(); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, - &aux_data, - sizeof(aux_data))); - int64_t aux_ld = TransX ? M : K; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, - &aux_ld, - sizeof(aux_ld))); - } - - auto dx_workspace = memory::Alloc( - dev_ctx.GetPlace(), - workspace_size, - phi::Stream(reinterpret_cast(dev_ctx.stream()))); - - auto* dx_data = dev_ctx.Alloc(dx, dx->numel() * sizeof(T)); - const auto* y_data = y->data(); - const auto* dout_data = dout->data(); - const auto* a_data = kXGradAIsDZ ? dout_data : y_data; - const auto* b_data = kXGradAIsDZ ? y_data : dout_data; - - auto algo = - GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle, - dx_operation_desc, - b_desc, - a_desc, - dx_desc, - alpha, - beta, - b_data, - a_data, - dx_data, - stream, - dx_workspace->ptr(), - workspace_size); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmul(lt_handle, - dx_operation_desc, - alpha, - b_data, - b_desc, - a_data, - a_desc, - beta, - dx_data, - dx_desc, - dx_data, - dx_desc, - algo, - dx_workspace->ptr(), - workspace_size, - stream)); - } - - // dy = func(dout, x) - if (dy) { - constexpr auto kYGradAIsDZ = (Trait::kYGradA == FusedGEMMGradInType::kDZ); - - cublasLtMatrixLayout_t *dy_dout_desc = nullptr, *dy_x_desc = nullptr; - if (TransX) { - dy_dout_desc = &dout_trans_desc; - if (dout_trans_desc == nullptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutCreate( - dy_dout_desc, mat_type, z_row, z_col, z_row)); - } - } else { - dy_dout_desc = &dout_desc; - if (dout_desc == nullptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutCreate( - dy_dout_desc, mat_type, z_col, z_row, z_col)); - } - } - - dy_x_desc = &x_trans_desc; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - dy_x_desc, mat_type, x_col, x_row, x_col)); - - auto& a_desc = kYGradAIsDZ ? (*dy_dout_desc) : (*dy_x_desc); - auto& b_desc = kYGradAIsDZ ? 
(*dy_x_desc) : (*dy_dout_desc); - auto a_trans = BoolToCuBlasEnum(Trait::kYGradATrans); - auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans); - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &dy_desc, mat_type, y_col, y_row, y_col)); - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( - &dy_operation_desc, compute_type, scale_type)); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dy_operation_desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &a_trans, - sizeof(a_trans))); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dy_operation_desc, - CUBLASLT_MATMUL_DESC_TRANSA, - &b_trans, - sizeof(b_trans))); - - cublasLtEpilogue_t epiloque_func_for_dy; - if (dbias == nullptr) { - epiloque_func_for_dy = CUBLASLT_EPILOGUE_DEFAULT; - } else { - if (TransY) { - epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADB; - } else { - epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADA; - } - } - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dy_operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epiloque_func_for_dy, - sizeof(epiloque_func_for_dy))); - - if (dbias) { - auto* dbias_data = dev_ctx.Alloc(dbias, dbias->numel() * sizeof(T)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dy_operation_desc, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &dbias_data, - sizeof(dbias_data))); - } - - auto dy_workspace = memory::Alloc( - dev_ctx.GetPlace(), - workspace_size, - phi::Stream(reinterpret_cast(dev_ctx.stream()))); - auto* dy_data = dev_ctx.Alloc(dy, dy->numel() * sizeof(T)); - const auto* dout_data = dout->data(); - const auto* x_data = x->data(); - const auto* a_data = kYGradAIsDZ ? dout_data : x_data; - const auto* b_data = kYGradAIsDZ ? x_data : dout_data; - - auto algo = - GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle, - dy_operation_desc, - b_desc, - a_desc, - dy_desc, - alpha, - beta, - b_data, - a_data, - dy_data, - stream, - dy_workspace->ptr(), - workspace_size); - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmul(lt_handle, - dy_operation_desc, - alpha, - b_data, - b_desc, - a_data, - a_desc, - beta, - dy_data, - dy_desc, - dy_data, - dy_desc, - algo, - dy_workspace->ptr(), - workspace_size, - stream)); - } - } + auto x_mat_dims = + phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1); + int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0]; + int64_t K = trans_y ? y->dims()[1] : y->dims()[0]; + int64_t N = trans_y ? y->dims()[0] : y->dims()[1]; - private: - static cublasLtEpilogue_t get_epilogue_type_( - const std::string& activation_grad) { - if (activation_grad == "relu_grad") { - return CUBLASLT_EPILOGUE_DRELU; - } else if (activation_grad == "gelu_grad") { - return CUBLASLT_EPILOGUE_DGELU; - } else if (activation_grad == "none") { - return CUBLASLT_EPILOGUE_DEFAULT; - } else { - PADDLE_ENFORCE_EQ( - true, - false, - platform::errors::InvalidArgument( - "The activation_grad attribute of fused_gemm_epilogue op should " - "be" - " one of {\"none\", \"relu\", \"gelu\"}. But received %s." 
- "But received activation_grad=%s.", - activation_grad)); - } + ComputeFusedGemmEpilogueBackward(dev_ctx, + dout, + x, + y, + reserve_space, + M, + N, + K, + trans_x, + trans_y, + activation_grad, + dx, + dy, + dbias); } }; +#endif } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h index 0f4fa9f88954db17400f8f59bed646fac2c3c30b..059cf66fac6445553715e24ca0332e1b233cb903 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h @@ -15,21 +15,26 @@ limitations under the License. */ #pragma once +#include +#include +#include + #ifdef PADDLE_WITH_CUDA -#include -#include "cuda.h" // NOLINT +#include // NOLINT +#include "cuda.h" // NOLINT #if CUDA_VERSION >= 11060 -#include -#include -#include - #include "gflags/gflags.h" +#include "paddle/fluid/framework/scope_guard.h" #include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/float16.h" +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/backends/gpu/cuda/cuda_helper.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/utils/optional.h" DECLARE_int64(cublaslt_exhaustive_search_times); @@ -39,27 +44,27 @@ namespace operators { class GemmEpilogueAlgoCache { public: - static GemmEpilogueAlgoCache &Instance() { + static GemmEpilogueAlgoCache& Instance() { static GemmEpilogueAlgoCache instance( FLAGS_cublaslt_exhaustive_search_times); return instance; } - GemmEpilogueAlgoCache(GemmEpilogueAlgoCache const &) = delete; - void operator=(GemmEpilogueAlgoCache const &) = delete; + GemmEpilogueAlgoCache(GemmEpilogueAlgoCache const&) = delete; + void operator=(GemmEpilogueAlgoCache const&) = delete; - cublasLtMatmulAlgo_t *GetGemmAlgo(cublasLtHandle_t lt_handle, + cublasLtMatmulAlgo_t* GetGemmAlgo(cublasLtHandle_t lt_handle, cublasLtMatmulDesc_t op_desc, cublasLtMatrixLayout_t a_desc, cublasLtMatrixLayout_t b_desc, cublasLtMatrixLayout_t c_desc, - const void *alpha, - const void *beta, - const void *a, - const void *b, - void *c, + const void* alpha, + const void* beta, + const void* a, + const void* b, + void* c, cudaStream_t stream, - void *workspace, + void* workspace, size_t workspace_size) { if (search_times_ <= 0) return nullptr; @@ -207,7 +212,7 @@ class GemmEpilogueAlgoCache { << ") not found in GemmEpilogueAlgoCache"; std::lock_guard lock(cache_mutex_); - auto &algo_in_map = map_[seed]; + auto& algo_in_map = map_[seed]; algo_in_map = ret; return &algo_in_map; } @@ -223,8 +228,8 @@ class GemmEpilogueAlgoCache { std::mutex cache_mutex_; void HashMatmulDesc_(cublasLtMatmulDesc_t desc, - int64_t *seed, - const std::hash &hash_fn) { + int64_t* seed, + const std::hash& hash_fn) { size_t size_to_write; int trans_a, trans_b; uint32_t epilogue; @@ -258,8 +263,8 @@ class GemmEpilogueAlgoCache { } void HashMatrixLayoutDesc_(cublasLtMatrixLayout_t desc, - int64_t *seed, - const std::hash &hash_fn) { + int64_t* seed, + const std::hash& hash_fn) { size_t size_to_write; uint32_t dtype; int32_t batch; @@ -317,15 +322,665 @@ class GemmEpilogueAlgoCache { HashValue_(seed, hash_fn, static_cast(batch_offset)); } - void HashValue_(int64_t *seed, - const std::hash &hash_fn, + void HashValue_(int64_t* seed, + const std::hash& hash_fn, int64_t value) { *seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); } }; +static 
cublasLtEpilogue_t GetEpilogueType(const std::string& activation, + bool enable_auxiliary) { + if (activation == "relu") { + return enable_auxiliary ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS + : CUBLASLT_EPILOGUE_RELU_BIAS; + } else if (activation == "gelu") { + return enable_auxiliary ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS + : CUBLASLT_EPILOGUE_GELU_BIAS; + } else if (activation == "none") { + return CUBLASLT_EPILOGUE_BIAS; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The activation attribute of fused_gemm_epilogue op should be" + " one of {\"none\", \"relu\", \"gelu\"}. But received %s." + "But received activation=%s.", + activation)); + } +} + +template +void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx, + const phi::DenseTensor* x, + const phi::DenseTensor* y, + const phi::DenseTensor* bias, + int64_t M, + int64_t N, + int64_t K, + bool trans_x, + bool trans_y, + const std::string& activation, + phi::DenseTensor* out, + phi::DenseTensor* reserve_space) { + using MT = typename phi::dtype::MPTypeTrait::Type; + + VLOG(6) << "x.shape={" << x->dims() << "}, y.shape={" << y->dims() + << "}, out.shape={" << out->dims() << "}, M=" << M << ", N=" << N + << ", K=" << K << ", trans_x=" << trans_x << ", trans_y=" << trans_y + << ", activation=" << activation + << ", reserve_space=" << reserve_space; + + bool enable_auxiliary = reserve_space == nullptr ? false : true; + auto* out_data = out->data(); + + cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType(); + cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType(); + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + if (std::is_same::value) { + compute_type = CUBLAS_COMPUTE_64F; + } + + cublasLtMatmulDesc_t operation_desc = NULL; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( + &operation_desc, compute_type, scale_type)); + cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &transx, sizeof(transx))); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transy, sizeof(transy))); + + cublasLtEpilogue_t epiloque_func = + GetEpilogueType(activation, enable_auxiliary); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epiloque_func, + sizeof(epiloque_func))); + const T* bias_data = bias->data(); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, + &bias_data, + sizeof(bias_data))); + + if (enable_auxiliary && activation != "none") { + // Note (Ming Huang): The initialization of ReseveSpace is happened in the + // dev_ctx.Alloc. Therefore, we set real date type up here. 
+ if (activation == "relu") { + paddle::experimental::DataType rs_type = + paddle::experimental::DataType::BOOL; + size_t reserve_space_size = + phi::product(reserve_space->dims()) * SizeOf(rs_type); + dev_ctx.Alloc(reserve_space, rs_type, reserve_space_size); + } else { + size_t reserve_space_size = + phi::product(reserve_space->dims()) * sizeof(T); + dev_ctx.Alloc(reserve_space, reserve_space_size); + } + + void* aux_data = reserve_space->data(); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + &aux_data, + sizeof(aux_data))); + int64_t aux_ld = N; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + &aux_ld, + sizeof(aux_ld))); + } + + cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL; + if (trans_x) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &x_desc, mat_type, M, K, M)); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &x_desc, mat_type, K, M, K)); + } + if (trans_y) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &y_desc, mat_type, K, N, K)); + } else { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &y_desc, mat_type, N, K, N)); + } + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &out_desc, mat_type, N, M, N)); + + cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); + // NOTE(zengjinle): I do not know whether the 4MB workspace size is + // "enough". I just followed the settings from the NVIDIA MLPerf BERT code. + size_t workspace_size = static_cast(4) * 1024 * 1024; + cudaStream_t stream = dev_ctx.stream(); + memory::allocation::AllocationPtr workspace = memory::Alloc( + dev_ctx.GetPlace(), + workspace_size, + phi::Stream(reinterpret_cast(dev_ctx.stream()))); + + MT alpha = static_cast(1); + MT beta = static_cast(0); + + const auto* y_data = y->data(); + const auto* x_data = x->data(); + + auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle, + operation_desc, + y_desc, + x_desc, + out_desc, + &alpha, + &beta, + y_data, + x_data, + out_data, + stream, + workspace->ptr(), + workspace_size); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(lt_handle, + operation_desc, + &alpha, + y_data, + y_desc, + x_data, + x_desc, + &beta, + out_data, + out_desc, + out_data, + out_desc, + algo, + workspace->ptr(), + workspace_size, + stream)); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescDestroy(operation_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(y_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(x_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(out_desc)); +} + +enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 }; + +template +struct FusedGEMMGradTrait; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradB = FusedGEMMGradInType::kDY; + static constexpr auto kXGradATrans = false; + static constexpr auto kXGradBTrans = true; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDX; + static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradATrans = true; + static constexpr auto kYGradBTrans = false; +}; + +template <> +struct 
FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDY; + static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradATrans = false; + static constexpr auto kXGradBTrans = true; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDX; + static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradATrans = false; + static constexpr auto kYGradBTrans = false; +}; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradB = FusedGEMMGradInType::kDY; + static constexpr auto kXGradATrans = false; + static constexpr auto kXGradBTrans = false; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradB = FusedGEMMGradInType::kDX; + static constexpr auto kYGradATrans = true; + static constexpr auto kYGradBTrans = false; +}; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDY; + static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradATrans = true; + static constexpr auto kXGradBTrans = true; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradB = FusedGEMMGradInType::kDX; + static constexpr auto kYGradATrans = true; + static constexpr auto kYGradBTrans = true; +}; + +static constexpr auto BoolToCuBlasEnum(bool transpose) { + return transpose ? CUBLAS_OP_T : CUBLAS_OP_N; +} + +static cublasLtEpilogue_t GetEpilogueGradType( + const std::string& activation_grad) { + if (activation_grad == "relu_grad") { + return CUBLASLT_EPILOGUE_DRELU; + } else if (activation_grad == "gelu_grad") { + return CUBLASLT_EPILOGUE_DGELU; + } else if (activation_grad == "none") { + return CUBLASLT_EPILOGUE_DEFAULT; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The activation_grad attribute of fused_gemm_epilogue op should " + "be one of {\"none\", \"relu\", \"gelu\"}. But received %s." + "But received activation_grad=%s.", + activation_grad)); + } +} + +template +void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, + const phi::DenseTensor* dout, + const phi::DenseTensor* x, + const phi::DenseTensor* y, + const phi::DenseTensor* reserve_space, + int64_t M, + int64_t N, + int64_t K, + const std::string activation_grad, + phi::DenseTensor* dx, + phi::DenseTensor* dy, + phi::DenseTensor* dbias, + bool use_addto) { + using MT = typename phi::dtype::MPTypeTrait::Type; + using Trait = FusedGEMMGradTrait; + + cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType(); + cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType(); + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + if (std::is_same::value) { + compute_type = CUBLAS_COMPUTE_64F; + } + + cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); + // NOTE(zengjinle): I do not know whether the 4MB workspace size is + // "enough". I just followed the settings from the NVIDIA MLPerf BERT code. + size_t workspace_size = static_cast(4) * 1024 * 1024; + const cublasLtMatmulAlgo_t* algo = nullptr; + cudaStream_t stream = dev_ctx.stream(); + + MT alpha = static_cast(1.0); + MT beta_dx = use_addto ? 
static_cast(1.0) : static_cast(0.0); + MT beta_dy = static_cast(0.0); + + cublasLtMatrixLayout_t dout_desc = nullptr, dout_trans_desc = nullptr; + cublasLtMatrixLayout_t x_desc = nullptr, x_trans_desc = nullptr; + cublasLtMatrixLayout_t y_desc = nullptr, y_trans_desc = nullptr; + cublasLtMatrixLayout_t dx_desc = nullptr, dy_desc = nullptr; + cublasLtMatmulDesc_t dx_operation_desc = nullptr, dy_operation_desc = nullptr; + + DEFINE_PADDLE_SCOPE_GUARD([&] { + auto descs = {dout_desc, + dout_trans_desc, + x_desc, + x_trans_desc, + y_desc, + y_trans_desc, + dx_desc, + dy_desc}; + for (auto desc : descs) { + if (desc) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(desc)); + } + } + + if (dx_operation_desc) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc)); + } + + if (dy_operation_desc) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc)); + } + }); + + auto x_row = TransX ? K : M; + auto x_col = TransX ? M : K; + auto y_row = TransY ? N : K; + auto y_col = TransY ? K : N; + auto z_row = TransX ? N : M; + auto z_col = TransX ? M : N; + + // dx = func(dout, y) + if (dx) { + constexpr auto kXGradAIsDZ = (Trait::kXGradA == FusedGEMMGradInType::kDZ); + cublasLtMatrixLayout_t *dx_dout_desc, *dx_y_desc; + + if (TransX) { + dx_dout_desc = &dout_trans_desc; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + dx_dout_desc, mat_type, z_row, z_col, z_row)); + } else { + dx_dout_desc = &dout_desc; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + dx_dout_desc, mat_type, z_col, z_row, z_col)); + } + + dx_y_desc = &y_trans_desc; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + dx_y_desc, mat_type, y_col, y_row, y_col)); + + auto& a_desc = kXGradAIsDZ ? (*dx_dout_desc) : (*dx_y_desc); + auto& b_desc = kXGradAIsDZ ? (*dx_y_desc) : (*dx_dout_desc); + auto a_trans = BoolToCuBlasEnum(Trait::kXGradATrans); + auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans); + + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &dx_desc, mat_type, x_col, x_row, x_col)); + + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( + &dx_operation_desc, compute_type, scale_type)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, + CUBLASLT_MATMUL_DESC_TRANSB, + &a_trans, + sizeof(a_trans))); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, + CUBLASLT_MATMUL_DESC_TRANSA, + &b_trans, + sizeof(b_trans))); + + cublasLtEpilogue_t epiloque_func_for_dx = + GetEpilogueGradType(activation_grad); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epiloque_func_for_dx, + sizeof(epiloque_func_for_dx))); + + if (activation_grad != "none") { + auto* aux_data = reserve_space->data(); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + &aux_data, + sizeof(aux_data))); + int64_t aux_ld = TransX ? 
M : K; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + &aux_ld, + sizeof(aux_ld))); + } + + auto dx_workspace = memory::Alloc( + dev_ctx.GetPlace(), + workspace_size, + phi::Stream(reinterpret_cast(dev_ctx.stream()))); + + auto* dx_data = dev_ctx.Alloc(dx, dx->numel() * sizeof(T)); + const auto* y_data = y->data(); + const auto* dout_data = dout->data(); + const auto* a_data = kXGradAIsDZ ? dout_data : y_data; + const auto* b_data = kXGradAIsDZ ? y_data : dout_data; + + auto algo = + GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle, + dx_operation_desc, + b_desc, + a_desc, + dx_desc, + &alpha, + &beta_dx, + b_data, + a_data, + dx_data, + stream, + dx_workspace->ptr(), + workspace_size); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmul(lt_handle, + dx_operation_desc, + &alpha, + b_data, + b_desc, + a_data, + a_desc, + &beta_dx, + dx_data, + dx_desc, + dx_data, + dx_desc, + algo, + dx_workspace->ptr(), + workspace_size, + stream)); + } + + // dy = func(dout, x) + if (dy) { + constexpr auto kYGradAIsDZ = (Trait::kYGradA == FusedGEMMGradInType::kDZ); + + cublasLtMatrixLayout_t *dy_dout_desc = nullptr, *dy_x_desc = nullptr; + if (TransX) { + dy_dout_desc = &dout_trans_desc; + if (dout_trans_desc == nullptr) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutCreate( + dy_dout_desc, mat_type, z_row, z_col, z_row)); + } + } else { + dy_dout_desc = &dout_desc; + if (dout_desc == nullptr) { + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutCreate( + dy_dout_desc, mat_type, z_col, z_row, z_col)); + } + } + + dy_x_desc = &x_trans_desc; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + dy_x_desc, mat_type, x_col, x_row, x_col)); + + auto& a_desc = kYGradAIsDZ ? (*dy_dout_desc) : (*dy_x_desc); + auto& b_desc = kYGradAIsDZ ? 
(*dy_x_desc) : (*dy_dout_desc); + auto a_trans = BoolToCuBlasEnum(Trait::kYGradATrans); + auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans); + + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &dy_desc, mat_type, y_col, y_row, y_col)); + + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( + &dy_operation_desc, compute_type, scale_type)); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dy_operation_desc, + CUBLASLT_MATMUL_DESC_TRANSB, + &a_trans, + sizeof(a_trans))); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dy_operation_desc, + CUBLASLT_MATMUL_DESC_TRANSA, + &b_trans, + sizeof(b_trans))); + + cublasLtEpilogue_t epiloque_func_for_dy; + if (dbias == nullptr) { + epiloque_func_for_dy = CUBLASLT_EPILOGUE_DEFAULT; + } else { + if (TransY) { + epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADB; + } else { + epiloque_func_for_dy = CUBLASLT_EPILOGUE_BGRADA; + } + } + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dy_operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epiloque_func_for_dy, + sizeof(epiloque_func_for_dy))); + + if (dbias) { + auto* dbias_data = dev_ctx.Alloc(dbias, dbias->numel() * sizeof(T)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dy_operation_desc, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, + &dbias_data, + sizeof(dbias_data))); + } + + auto dy_workspace = memory::Alloc( + dev_ctx.GetPlace(), + workspace_size, + phi::Stream(reinterpret_cast(dev_ctx.stream()))); + auto* dy_data = dev_ctx.Alloc(dy, dy->numel() * sizeof(T)); + const auto* dout_data = dout->data(); + const auto* x_data = x->data(); + const auto* a_data = kYGradAIsDZ ? dout_data : x_data; + const auto* b_data = kYGradAIsDZ ? 
x_data : dout_data; + + auto algo = + GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle, + dy_operation_desc, + b_desc, + a_desc, + dy_desc, + &alpha, + &beta_dy, + b_data, + a_data, + dy_data, + stream, + dy_workspace->ptr(), + workspace_size); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmul(lt_handle, + dy_operation_desc, + &alpha, + b_data, + b_desc, + a_data, + a_desc, + &beta_dy, + dy_data, + dy_desc, + dy_data, + dy_desc, + algo, + dy_workspace->ptr(), + workspace_size, + stream)); + } +} + +template +void ComputeFusedGemmEpilogueBackward(const phi::GPUContext& dev_ctx, + const phi::DenseTensor* dout, + const phi::DenseTensor* x, + const phi::DenseTensor* y, + const phi::DenseTensor* reserve_space, + int64_t M, + int64_t N, + int64_t K, + bool trans_x, + bool trans_y, + const std::string& activation_grad, + phi::DenseTensor* dx, + phi::DenseTensor* dy, + phi::DenseTensor* dbias, + bool use_addto = false) { + VLOG(10) << "M=" << M << ", K=" << K << ", N=" << N << ", trans_x=" << trans_x + << ", trans_y=" << trans_y + << ", activation_grad=" << activation_grad; + + if (trans_x) { + if (trans_y) { + ComputeFusedGemmEpilogueBackwardImpl(dev_ctx, + dout, + x, + y, + reserve_space, + M, + N, + K, + activation_grad, + dx, + dy, + dbias, + use_addto); + } else { + ComputeFusedGemmEpilogueBackwardImpl(dev_ctx, + dout, + x, + y, + reserve_space, + M, + N, + K, + activation_grad, + dx, + dy, + dbias, + use_addto); + } + } else { + if (trans_y) { + ComputeFusedGemmEpilogueBackwardImpl(dev_ctx, + dout, + x, + y, + reserve_space, + M, + N, + K, + activation_grad, + dx, + dy, + dbias, + use_addto); + } else { + ComputeFusedGemmEpilogueBackwardImpl(dev_ctx, + dout, + x, + y, + reserve_space, + M, + N, + K, + activation_grad, + dx, + dy, + dbias, + use_addto); + } + } +} + } // namespace operators } // namespace paddle - #endif #endif diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index 11e1eedae86f96bb7ec6cc08859f242e82817e35..13b06fcac70bb3f8b89fcf5e07fd5b8b2c4a0778 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -1,8 +1,11 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index 3a77f763983e83d88437f7964406aec6b54fe389..5d0dab032e012c22d6911dada1770350ce065bc9 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -1,14 +1,18 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ + // This file has been adapted from FasterTransformer file: // https://github.com/NVIDIA/FasterTransformer/blob/v4.0/fastertransformer/cuda/masked_multihead_attention.cu // We add License in the head. diff --git a/paddle/phi/backends/gpu/cuda/cuda_helper.h b/paddle/phi/backends/gpu/cuda/cuda_helper.h index 7463edc5d9ff60dbe9f8a255458af122dec75f33..8d5ffec14e8a6626ae3adc57656137c415f5bcd0 100644 --- a/paddle/phi/backends/gpu/cuda/cuda_helper.h +++ b/paddle/phi/backends/gpu/cuda/cuda_helper.h @@ -14,6 +14,12 @@ #pragma once +#ifdef PADDLE_WITH_CUDA +#include <cuda_runtime.h>  // NOLINT + +#include "paddle/phi/common/bfloat16.h" +#include "paddle/phi/common/float16.h" + namespace phi { namespace backends { namespace gpu { @@ -69,6 +75,22 @@ namespace gpu { for (index_type i = __index__; __index__ < (num); \ __index__ += __stride__, i = __index__) +template <typename T> +cudaDataType_t ToCudaDataType() { + if (std::is_same<T, float>::value) { + return CUDA_R_32F; + } else if (std::is_same<T, double>::value) { + return CUDA_R_64F; + } else if (std::is_same<T, phi::dtype::float16>::value) { + return CUDA_R_16F; +#if CUDA_VERSION >= 11000 + } else if (std::is_same<T, phi::dtype::bfloat16>::value) { + return CUDA_R_16BF; +#endif + } +} + } // namespace gpu } // namespace backends } // namespace phi +#endif
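Note: taken together, the attn_gemm.h and fused_gemm_epilogue_op.h changes let callers opt into the cublasLt bias epilogue through a single flag. A minimal usage sketch under stated assumptions (FusedLinearBias is a hypothetical wrapper name; dev_ctx, tensor shapes, and allocations are assumed to be prepared by the caller, as the gate-attention kernels above do; CUDA >= 11.6 is required for the fused path, otherwise AttnMatMul falls back to the unfused blas.GEMM plus broadcast bias-add):

template <typename T>
void FusedLinearBias(const phi::GPUContext& dev_ctx,
                     const phi::DenseTensor* input,   // [m, k]
                     const phi::DenseTensor* weight,  // [k, n]
                     const phi::DenseTensor* bias,    // [n]
                     phi::DenseTensor* bias_out,      // [m, n], pre-allocated
                     int m, int n, int k) {
  AttnMatMul<T> linear(dev_ctx, /*transA=*/false, /*transB=*/false,
                       m, n, k, /*compute_bias=*/true);
  // With fused=true, ComputeForward dispatches to
  // ComputeFusedGemmEpilogueForward<T> with activation "none", i.e. a single
  // cublasLtMatmul using the CUBLASLT_EPILOGUE_BIAS epilogue. In this mode the
  // output argument may be nullptr or the same tensor as bias_out.
  linear.ComputeForward(weight, input, bias,
                        /*output=*/bias_out, /*bias_out=*/bias_out,
                        /*fused=*/true);
}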