From bd813d35195a87cb9a6a8e47a62ff3746b948a60 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Thu, 28 Jul 2022 20:44:18 +0800 Subject: [PATCH] [Eager] fix lerp grad kernel logic (#44705) --- .../phi/kernels/impl/lerp_grad_kernel_impl.h | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h index 0fc491c8c21..b47acbda0da 100644 --- a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h @@ -34,14 +34,22 @@ static void LerpGradFunction(const Context& ctx, auto* dy = y_grad; auto dout_dims = dout.dims(); - auto dx_dims = phi::funcs::ExtendDims2Rank(dx->dims(), D); - auto dy_dims = phi::funcs::ExtendDims2Rank(dy->dims(), D); + DDim dx_dims; + DDim dy_dims; + auto w_dims = phi::funcs::ExtendDims2Rank(w.dims(), D); Eigen::DSizes dx_bcast_dims; Eigen::DSizes dy_bcast_dims; Eigen::DSizes w_bcast_dims; - phi::funcs::GetBroadcastDims(dx_dims, dout_dims, &dx_bcast_dims); - phi::funcs::GetBroadcastDims(dy_dims, dout_dims, &dy_bcast_dims); + + if (dx) { + dx_dims = phi::funcs::ExtendDims2Rank(dx->dims(), D); + phi::funcs::GetBroadcastDims(dx_dims, dout_dims, &dx_bcast_dims); + } + if (dy) { + dy_dims = phi::funcs::ExtendDims2Rank(dy->dims(), D); + phi::funcs::GetBroadcastDims(dy_dims, dout_dims, &dy_bcast_dims); + } phi::funcs::GetBroadcastDims(w_dims, dout_dims, &w_bcast_dims); auto eigen_w = phi::EigenTensor::From(w, w_dims); @@ -50,11 +58,16 @@ static void LerpGradFunction(const Context& ctx, Eigen::DSizes dx_reshape_dims; Eigen::DSizes dy_reshape_dims; Eigen::DSizes reduce_dims; + for (int i = 0; i < dout_dims.size(); ++i) { - dx_reshape_dims[2 * i] = dx_bcast_dims[i]; - dx_reshape_dims[2 * i + 1] = dx_dims[i]; - dy_reshape_dims[2 * i] = dy_bcast_dims[i]; - dy_reshape_dims[2 * i + 1] = dy_dims[i]; + if (dx) { + dx_reshape_dims[2 * i] = dx_bcast_dims[i]; + dx_reshape_dims[2 * i + 1] = dx_dims[i]; + } + if (dy) { + dy_reshape_dims[2 * i] = dy_bcast_dims[i]; + dy_reshape_dims[2 * i + 1] = dy_dims[i]; + } reduce_dims[i] = 2 * i; } -- GitLab