From 42abcc08c6f3186d23ab1f477e45c5b58067d0b5 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Fri, 15 Apr 2022 09:56:37 +0800 Subject: [PATCH] fix batch norm memory issue (#41717) * try to fix batch norm memory issue * fix batch norm memory alloc bug * polish some code --- paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu | 7 ++++--- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 6 ++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu index 09bce3c9895..e15b4cc10d9 100644 --- a/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu @@ -570,7 +570,8 @@ void BatchNormGradRawKernel(const Context &ctx, /*sizeInBytes=*/&workspace_size)); workspace_tensor.Resize({static_cast(workspace_size)}); - workspace_ptr = ctx.template Alloc(&workspace_tensor); + workspace_ptr = + static_cast(ctx.template Alloc(&workspace_tensor)); PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::cudnnBatchNormalizationBackwardEx( @@ -603,8 +604,8 @@ void BatchNormGradRawKernel(const Context &ctx, /*activationDesc=*/nullptr, /*workspace=*/workspace_ptr, /*workSpaceSizeInBytes=*/workspace_size, - /*reserveSpace=*/const_cast( - reserve_space->template data()), + /*reserveSpace=*/const_cast( + reserve_space->template data()), /*reserveSpaceSizeInBytes=*/reserve_space_size)); #endif // CUDNN_VERSION_MIN(7, 4, 1) if (!called) { diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 74a523f4ecf..361e62e5660 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -498,9 +498,11 @@ void BatchNormKernel(const Context &ctx, /*sizeInBytes=*/&reserve_space_size)); reserve_space->Resize({static_cast(reserve_space_size)}); - reserve_space_ptr = ctx.template Alloc(reserve_space); + reserve_space_ptr = 
static_cast(ctx.template Alloc(reserve_space)); workspace_tensor.Resize({static_cast(workspace_size)}); - workspace_ptr = ctx.template Alloc(&workspace_tensor); + workspace_ptr = + static_cast(ctx.template Alloc(&workspace_tensor)); PADDLE_ENFORCE_GPU_SUCCESS( paddle::platform::dynload::cudnnBatchNormalizationForwardTrainingEx( handle, -- GitLab