未验证 提交 d338b2f8 编写于 作者: W wangzhen38 提交者: GitHub

[bug fix] group norm backward (#54341)

上级 f55eb06f
...@@ -291,7 +291,9 @@ void GroupNormGradKernel(const Context& dev_ctx, ...@@ -291,7 +291,9 @@ void GroupNormGradKernel(const Context& dev_ctx,
const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1] const int W = (data_layout == DataLayout::kNCHW ? x_dims[x_dims.size() - 1]
: x_dims[x_dims.size() - 2]); : x_dims[x_dims.size() - 2]);
dev_ctx.template Alloc<T>(d_x); if (d_x) {
dev_ctx.template Alloc<T>(d_x);
}
phi::funcs::SetConstant<GPUContext, T> set_zero; phi::funcs::SetConstant<GPUContext, T> set_zero;
phi::funcs::SetConstant<GPUContext, AccT> set_zero_AccT; phi::funcs::SetConstant<GPUContext, AccT> set_zero_AccT;
DenseTensor ds, db; DenseTensor ds, db;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册