未验证 提交 42abcc08 编写于 作者: H hong 提交者: GitHub

fix batch norm memory issue (#41717)

* try to fix batch norm memory issue

* fix batch norm memory alloc bug

* polish some code
上级 e7f0aa38
...@@ -570,7 +570,8 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -570,7 +570,8 @@ void BatchNormGradRawKernel(const Context &ctx,
/*sizeInBytes=*/&workspace_size)); /*sizeInBytes=*/&workspace_size));
workspace_tensor.Resize({static_cast<int64_t>(workspace_size)}); workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
workspace_ptr = ctx.template Alloc<T>(&workspace_tensor); workspace_ptr =
static_cast<void *>(ctx.template Alloc<uint8_t>(&workspace_tensor));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationBackwardEx( paddle::platform::dynload::cudnnBatchNormalizationBackwardEx(
...@@ -603,8 +604,8 @@ void BatchNormGradRawKernel(const Context &ctx, ...@@ -603,8 +604,8 @@ void BatchNormGradRawKernel(const Context &ctx,
/*activationDesc=*/nullptr, /*activationDesc=*/nullptr,
/*workspace=*/workspace_ptr, /*workspace=*/workspace_ptr,
/*workSpaceSizeInBytes=*/workspace_size, /*workSpaceSizeInBytes=*/workspace_size,
/*reserveSpace=*/const_cast<T *>( /*reserveSpace=*/const_cast<uint8_t *>(
reserve_space->template data<T>()), reserve_space->template data<uint8_t>()),
/*reserveSpaceSizeInBytes=*/reserve_space_size)); /*reserveSpaceSizeInBytes=*/reserve_space_size));
#endif // CUDNN_VERSION_MIN(7, 4, 1) #endif // CUDNN_VERSION_MIN(7, 4, 1)
if (!called) { if (!called) {
......
...@@ -498,9 +498,11 @@ void BatchNormKernel(const Context &ctx, ...@@ -498,9 +498,11 @@ void BatchNormKernel(const Context &ctx,
/*sizeInBytes=*/&reserve_space_size)); /*sizeInBytes=*/&reserve_space_size));
reserve_space->Resize({static_cast<int64_t>(reserve_space_size)}); reserve_space->Resize({static_cast<int64_t>(reserve_space_size)});
reserve_space_ptr = ctx.template Alloc<T>(reserve_space); reserve_space_ptr =
static_cast<void *>(ctx.template Alloc<uint8_t>(reserve_space));
workspace_tensor.Resize({static_cast<int64_t>(workspace_size)}); workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
workspace_ptr = ctx.template Alloc<T>(&workspace_tensor); workspace_ptr =
static_cast<void *>(ctx.template Alloc<uint8_t>(&workspace_tensor));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationForwardTrainingEx( paddle::platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
handle, handle,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册