From fb4215b2d1765e305f687d2d1ca5f19c90f7eeb1 Mon Sep 17 00:00:00 2001 From: JingZhuangzhuang <75348594+JZZ-NOTE@users.noreply.github.com> Date: Wed, 9 Mar 2022 10:21:50 +0800 Subject: [PATCH] fix batch_norm op kernel (#40171) --- paddle/phi/kernels/gpu/batch_norm_kernel.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/paddle/phi/kernels/gpu/batch_norm_kernel.cu b/paddle/phi/kernels/gpu/batch_norm_kernel.cu index 6ad12245d2a..49b550f51e6 100644 --- a/paddle/phi/kernels/gpu/batch_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/batch_norm_kernel.cu @@ -460,10 +460,14 @@ void BatchNormKernel(const Context &ctx, void *reserve_space_ptr = nullptr; void *workspace_ptr = nullptr; DenseTensor workspace_tensor; + DenseTensor reserve_space_tensor; // Create reserve space and workspace for batch norm. // Create tensor for each batchnorm op, it will be used in the // backward. Thus this tensor shouldn't be temp. // auto *reserve_space = ctx.Output("ReserveSpace"); + if (reserve_space == nullptr) { + reserve_space = &reserve_space_tensor; + } PADDLE_ENFORCE_NOT_NULL( reserve_space, phi::errors::NotFound( -- GitLab