diff --git a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h index 0fc491c8c21d3e609067629e3c4b2278e8728b8f..b47acbda0da2d213930476467a45f00fa1e544ba 100644 --- a/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/lerp_grad_kernel_impl.h @@ -34,14 +34,22 @@ static void LerpGradFunction(const Context& ctx, auto* dy = y_grad; auto dout_dims = dout.dims(); - auto dx_dims = phi::funcs::ExtendDims2Rank(dx->dims(), D); - auto dy_dims = phi::funcs::ExtendDims2Rank(dy->dims(), D); + DDim dx_dims; + DDim dy_dims; + auto w_dims = phi::funcs::ExtendDims2Rank(w.dims(), D); Eigen::DSizes dx_bcast_dims; Eigen::DSizes dy_bcast_dims; Eigen::DSizes w_bcast_dims; - phi::funcs::GetBroadcastDims(dx_dims, dout_dims, &dx_bcast_dims); - phi::funcs::GetBroadcastDims(dy_dims, dout_dims, &dy_bcast_dims); + + if (dx) { + dx_dims = phi::funcs::ExtendDims2Rank(dx->dims(), D); + phi::funcs::GetBroadcastDims(dx_dims, dout_dims, &dx_bcast_dims); + } + if (dy) { + dy_dims = phi::funcs::ExtendDims2Rank(dy->dims(), D); + phi::funcs::GetBroadcastDims(dy_dims, dout_dims, &dy_bcast_dims); + } phi::funcs::GetBroadcastDims(w_dims, dout_dims, &w_bcast_dims); auto eigen_w = phi::EigenTensor::From(w, w_dims); @@ -50,11 +58,16 @@ static void LerpGradFunction(const Context& ctx, Eigen::DSizes dx_reshape_dims; Eigen::DSizes dy_reshape_dims; Eigen::DSizes reduce_dims; + for (int i = 0; i < dout_dims.size(); ++i) { - dx_reshape_dims[2 * i] = dx_bcast_dims[i]; - dx_reshape_dims[2 * i + 1] = dx_dims[i]; - dy_reshape_dims[2 * i] = dy_bcast_dims[i]; - dy_reshape_dims[2 * i + 1] = dy_dims[i]; + if (dx) { + dx_reshape_dims[2 * i] = dx_bcast_dims[i]; + dx_reshape_dims[2 * i + 1] = dx_dims[i]; + } + if (dy) { + dy_reshape_dims[2 * i] = dy_bcast_dims[i]; + dy_reshape_dims[2 * i + 1] = dy_dims[i]; + } reduce_dims[i] = 2 * i; }