未验证 提交 42abcc08 编写于 作者: H hong 提交者: GitHub

fix batch norm memory issue (#41717)

* try to fix batch norm memory issue

* fix batch norm memory alloc bug

* polish some code
上级 e7f0aa38
......@@ -570,7 +570,8 @@ void BatchNormGradRawKernel(const Context &ctx,
/*sizeInBytes=*/&workspace_size));
workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
workspace_ptr = ctx.template Alloc<T>(&workspace_tensor);
workspace_ptr =
static_cast<void *>(ctx.template Alloc<uint8_t>(&workspace_tensor));
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationBackwardEx(
......@@ -603,8 +604,8 @@ void BatchNormGradRawKernel(const Context &ctx,
/*activationDesc=*/nullptr,
/*workspace=*/workspace_ptr,
/*workSpaceSizeInBytes=*/workspace_size,
/*reserveSpace=*/const_cast<T *>(
reserve_space->template data<T>()),
/*reserveSpace=*/const_cast<uint8_t *>(
reserve_space->template data<uint8_t>()),
/*reserveSpaceSizeInBytes=*/reserve_space_size));
#endif // CUDNN_VERSION_MIN(7, 4, 1)
if (!called) {
......
......@@ -498,9 +498,11 @@ void BatchNormKernel(const Context &ctx,
/*sizeInBytes=*/&reserve_space_size));
reserve_space->Resize({static_cast<int64_t>(reserve_space_size)});
reserve_space_ptr = ctx.template Alloc<T>(reserve_space);
reserve_space_ptr =
static_cast<void *>(ctx.template Alloc<uint8_t>(reserve_space));
workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
workspace_ptr = ctx.template Alloc<T>(&workspace_tensor);
workspace_ptr =
static_cast<void *>(ctx.template Alloc<uint8_t>(&workspace_tensor));
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
handle,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册