未验证 提交 6d436f6e 编写于 作者: S sneaxiy 提交者: GitHub

Fix cublasLt workspace size: reduce the requested workspace from 4 GiB to 4 MiB (#43877)

上级 a8113a65
...@@ -146,7 +146,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> { ...@@ -146,7 +146,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
&out_desc, mat_type, N, M, N)); &out_desc, mat_type, N, M, N));
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024; // NOTE(zengjinle): I do not know whether the 4MB workspace size is
// "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
cudaStream_t stream = dev_ctx.stream(); cudaStream_t stream = dev_ctx.stream();
memory::allocation::AllocationPtr workspace = memory::allocation::AllocationPtr workspace =
memory::Alloc(dev_ctx, workspace_size); memory::Alloc(dev_ctx, workspace_size);
...@@ -356,7 +358,9 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -356,7 +358,9 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
} }
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024 * 1024; // NOTE(zengjinle): I do not know whether the 4MB workspace size is
// "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
const cublasLtMatmulAlgo_t* algo = nullptr; const cublasLtMatmulAlgo_t* algo = nullptr;
cudaStream_t stream = dev_ctx.stream(); cudaStream_t stream = dev_ctx.stream();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册