diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
index 42cc6bdfb52d83133cea605b31ba0edfbc4b61cc..3ebb9f9e640cc0c8e56b06cd36614fba327cad4e 100644
--- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
+++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu
@@ -146,7 +146,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
         &out_desc, mat_type, N, M, N));
 
     cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
-    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024;
+    // NOTE(zengjinle): I do not know whether the 4MB workspace size is
+    // "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
+    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
     cudaStream_t stream = dev_ctx.stream();
     memory::allocation::AllocationPtr workspace =
         memory::Alloc(dev_ctx, workspace_size);
@@ -356,7 +358,9 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
     }
 
     cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
-    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024;
+    // NOTE(zengjinle): I do not know whether the 4MB workspace size is
+    // "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
+    size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
     const cublasLtMatmulAlgo_t* algo = nullptr;
     cudaStream_t stream = dev_ctx.stream();
 
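
For context on what this single size variable controls, below is a minimal standalone sketch (not the Paddle kernel itself; the helper RunMatmulWithWorkspace and its parameters are illustrative assumptions) of how a cuBLASLt workspace cap typically flows into a matmul call: the same byte count bounds the algorithm heuristic via CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES and sizes the buffer handed to cublasLtMatmul.

// A minimal sketch, assuming descriptors (op_desc, x_desc, ...) were created
// by the caller. The helper name and signature are hypothetical, not from
// this patch; error handling is omitted for brevity.
#include <cublasLt.h>
#include <cuda_runtime.h>
#include <cstdint>

void RunMatmulWithWorkspace(cublasLtHandle_t lt_handle,
                            cublasLtMatmulDesc_t op_desc,
                            cublasLtMatrixLayout_t x_desc,
                            cublasLtMatrixLayout_t y_desc,
                            cublasLtMatrixLayout_t out_desc,
                            const void* alpha, const void* x, const void* y,
                            const void* beta, void* out, cudaStream_t stream) {
  // The same cap the patch switches to: 4MB rather than 4GB.
  size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;

  void* workspace = nullptr;
  cudaMalloc(&workspace, workspace_size);  // Paddle uses memory::Alloc here.

  // Restrict the heuristic to algorithms that fit inside the workspace.
  cublasLtMatmulPreference_t preference;
  cublasLtMatmulPreferenceCreate(&preference);
  uint64_t max_ws = workspace_size;
  cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
      &max_ws, sizeof(max_ws));

  cublasLtMatmulHeuristicResult_t heuristic;
  int found = 0;
  cublasLtMatmulAlgoGetHeuristic(lt_handle, op_desc, x_desc, y_desc,
                                 out_desc, out_desc, preference,
                                 /*requestedAlgoCount=*/1, &heuristic, &found);
  // algo == nullptr lets cuBLASLt pick internally (as the grad kernel does).
  const cublasLtMatmulAlgo_t* algo = (found > 0) ? &heuristic.algo : nullptr;

  // The buffer and size passed here must honor the cap advertised to the
  // heuristic; a smaller buffer than advertised can fail at runtime.
  cublasLtMatmul(lt_handle, op_desc, alpha, x, x_desc, y, y_desc, beta,
                 out, out_desc, out, out_desc, algo,
                 workspace, workspace_size, stream);

  cublasLtMatmulPreferenceDestroy(preference);
  cudaFree(workspace);
}

Because both the heuristic cap and the cublasLtMatmul call read the same workspace_size variable, the patch only needs to change the one expression; the 4MB value itself is taken from the NVIDIA MLPerf BERT settings, as the NOTE in the diff says.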