未验证 提交 eda8df71 编写于 作者: R RuohengMa 提交者: GitHub

[XPU] substitute new api kernel for combinatorial adaptive avg_pool2d_grad kernel (#53528)

上级 da963eab
...@@ -112,30 +112,6 @@ void Pool2dGradKernel(const Context& ctx, ...@@ -112,30 +112,6 @@ void Pool2dGradKernel(const Context& ctx,
true); true);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
// When output dim is 1 * 1 (1 * 1 * 1 in pool_3d), use scale
// and broadcast kernels to get same output, but better performance.
if (out_h == 1 && out_w == 1 && std::is_same<T, float>::value) {
xpu::ctx_guard RAII_GUARD(ctx.x_context());
float scale = 1.0 / (in_h * in_w);
float* scaled_dy = RAII_GUARD.alloc_l3_or_gm<float>(n * c);
r = xpu::scale(ctx.x_context(),
dout.data<float>(),
scaled_dy,
n * c,
true,
scale,
0.0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
r = xpu::broadcast(ctx.x_context(),
scaled_dy,
dx->data<float>(),
{n, c, 1, 1},
{n, c, in_h, in_w});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast");
return;
}
r = xpu::adaptive_avg_pool2d_grad<XPUType>( r = xpu::adaptive_avg_pool2d_grad<XPUType>(
ctx.x_context(), ctx.x_context(),
reinterpret_cast<const XPUType*>(dout.data<T>()), reinterpret_cast<const XPUType*>(dout.data<T>()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册