未验证 提交 fb4215b2 编写于 作者: J JingZhuangzhuang 提交者: GitHub

fix batch_norm op kernel (#40171)

上级 fcae3430
...@@ -460,10 +460,14 @@ void BatchNormKernel(const Context &ctx, ...@@ -460,10 +460,14 @@ void BatchNormKernel(const Context &ctx,
void *reserve_space_ptr = nullptr; void *reserve_space_ptr = nullptr;
void *workspace_ptr = nullptr; void *workspace_ptr = nullptr;
DenseTensor workspace_tensor; DenseTensor workspace_tensor;
DenseTensor reserve_space_tensor;
// Create reserve space and workspace for batch norm. // Create reserve space and workspace for batch norm.
// Create tensor for each batchnorm op, it will be used in the // Create tensor for each batchnorm op, it will be used in the
// backward. Thus this tensor shouldn't be temp. // backward. Thus this tensor shouldn't be temp.
// auto *reserve_space = ctx.Output<Tensor>("ReserveSpace"); // auto *reserve_space = ctx.Output<Tensor>("ReserveSpace");
if (reserve_space == nullptr) {
reserve_space = &reserve_space_tensor;
}
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
reserve_space, reserve_space,
phi::errors::NotFound( phi::errors::NotFound(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册