From 6d436f6edcae84e56c61de044d70316ab45f77e3 Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Tue, 28 Jun 2022 10:28:51 +0800
Subject: [PATCH] fix cublasLt workspace size (#43877)

---
 paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
index 42cc6bdfb5..3ebb9f9e64 100644
--- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
+++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
@@ -146,7 +146,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
         &out_desc, mat_type, N, M, N));
 
     cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
-    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024;
+    // NOTE(zengjinle): I do not know whether the 4MB workspace size is
+    // "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
+    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
     cudaStream_t stream = dev_ctx.stream();
     memory::allocation::AllocationPtr workspace =
         memory::Alloc(dev_ctx, workspace_size);
@@ -356,7 +358,9 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
     }
 
     cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
-    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024;
+    // NOTE(zengjinle): I do not know whether the 4MB workspace size is
+    // "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
+    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
     const cublasLtMatmulAlgo_t* algo = nullptr;
 
     cudaStream_t stream = dev_ctx.stream();
-- 
GitLab
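
Context for the change: the old expression carried an extra factor of 1024, so each kernel launch allocated a 4 GiB scratch buffer rather than the intended 4 MiB. In the cuBLASLt API that these kernels drive, the workspace size plays two roles: it caps which matmul algorithms the heuristic search may return, and it is the size of the scratch buffer handed to cublasLtMatmul. The standalone sketch below is illustrative only, not Paddle code: the FP32 data type, the 64x64 problem size, and the omitted status checks are simplifying assumptions, and it presumes CUDA 11+ for CUBLAS_COMPUTE_32F.

// Minimal sketch of how a cuBLASLt workspace cap is used, under the
// assumptions stated above. Error handling is omitted for brevity.
#include <cublasLt.h>
#include <cuda_runtime.h>

int main() {
  const int M = 64, N = 64, K = 64;
  // Same cap as the patched kernels: 4 MiB, computed in size_t arithmetic.
  size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;

  cublasLtHandle_t lt_handle;
  cublasLtCreate(&lt_handle);

  cublasLtMatmulDesc_t op_desc;
  cublasLtMatmulDescCreate(&op_desc, CUBLAS_COMPUTE_32F, CUDA_R_32F);

  cublasLtMatrixLayout_t a_desc, b_desc, c_desc;
  cublasLtMatrixLayoutCreate(&a_desc, CUDA_R_32F, M, K, M);
  cublasLtMatrixLayoutCreate(&b_desc, CUDA_R_32F, K, N, K);
  cublasLtMatrixLayoutCreate(&c_desc, CUDA_R_32F, M, N, M);

  // Role 1: the cap restricts which algorithms the heuristic may return.
  cublasLtMatmulPreference_t preference;
  cublasLtMatmulPreferenceCreate(&preference);
  cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
      &workspace_size, sizeof(workspace_size));

  cublasLtMatmulHeuristicResult_t heuristic;
  int returned = 0;
  cublasLtMatmulAlgoGetHeuristic(lt_handle, op_desc, a_desc, b_desc, c_desc,
                                 c_desc, preference, 1, &heuristic, &returned);
  if (returned == 0) return 1;  // no algorithm fits within the cap

  float *a, *b, *c;
  void *workspace;
  cudaMalloc(&a, sizeof(float) * M * K);
  cudaMalloc(&b, sizeof(float) * K * N);
  cudaMalloc(&c, sizeof(float) * M * N);
  cudaMalloc(&workspace, workspace_size);
  cudaMemset(a, 0, sizeof(float) * M * K);
  cudaMemset(b, 0, sizeof(float) * K * N);

  // Role 2: a scratch buffer of (at least) the promised size is passed to
  // the actual matmul call together with the chosen algorithm.
  float alpha = 1.0f, beta = 0.0f;
  cublasLtMatmul(lt_handle, op_desc, &alpha, a, a_desc, b, b_desc, &beta,
                 c, c_desc, c, c_desc, &heuristic.algo, workspace,
                 workspace_size, /*stream=*/0);
  cudaDeviceSynchronize();

  cublasLtMatmulPreferenceDestroy(preference);
  cublasLtMatrixLayoutDestroy(a_desc);
  cublasLtMatrixLayoutDestroy(b_desc);
  cublasLtMatrixLayoutDestroy(c_desc);
  cublasLtMatmulDescDestroy(op_desc);
  cublasLtDestroy(lt_handle);
  cudaFree(a); cudaFree(b); cudaFree(c); cudaFree(workspace);
  return 0;
}

Built with something like: nvcc sketch.cu -lcublasLt -o sketch. The point the sketch makes about the patch: because the same workspace_size value feeds both the heuristic preference and the allocation, shrinking it from 4 GiB to 4 MiB both avoids the oversized cudaMalloc and steers cuBLASLt toward algorithms that fit the smaller scratch budget.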