未验证 提交 bd813d35 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] fix lerp grad kernel logic (#44705)

上级 e9b92018
@@ -34,14 +34,22 @@ static void LerpGradFunction(const Context& ctx,
   auto* dy = y_grad;
   auto dout_dims = dout.dims();
-  auto dx_dims = phi::funcs::ExtendDims2Rank(dx->dims(), D);
-  auto dy_dims = phi::funcs::ExtendDims2Rank(dy->dims(), D);
+  DDim dx_dims;
+  DDim dy_dims;
   auto w_dims = phi::funcs::ExtendDims2Rank(w.dims(), D);
   Eigen::DSizes<int, D> dx_bcast_dims;
   Eigen::DSizes<int, D> dy_bcast_dims;
   Eigen::DSizes<int, D> w_bcast_dims;
-  phi::funcs::GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
-  phi::funcs::GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);
+  if (dx) {
+    dx_dims = phi::funcs::ExtendDims2Rank(dx->dims(), D);
+    phi::funcs::GetBroadcastDims<D>(dx_dims, dout_dims, &dx_bcast_dims);
+  }
+  if (dy) {
+    dy_dims = phi::funcs::ExtendDims2Rank(dy->dims(), D);
+    phi::funcs::GetBroadcastDims<D>(dy_dims, dout_dims, &dy_bcast_dims);
+  }
   phi::funcs::GetBroadcastDims<D>(w_dims, dout_dims, &w_bcast_dims);
   auto eigen_w = phi::EigenTensor<T, D>::From(w, w_dims);
@@ -50,11 +58,16 @@ static void LerpGradFunction(const Context& ctx,
   Eigen::DSizes<int, D * 2> dx_reshape_dims;
   Eigen::DSizes<int, D * 2> dy_reshape_dims;
   Eigen::DSizes<int, D> reduce_dims;
   for (int i = 0; i < dout_dims.size(); ++i) {
-    dx_reshape_dims[2 * i] = dx_bcast_dims[i];
-    dx_reshape_dims[2 * i + 1] = dx_dims[i];
-    dy_reshape_dims[2 * i] = dy_bcast_dims[i];
-    dy_reshape_dims[2 * i + 1] = dy_dims[i];
+    if (dx) {
+      dx_reshape_dims[2 * i] = dx_bcast_dims[i];
+      dx_reshape_dims[2 * i + 1] = dx_dims[i];
+    }
+    if (dy) {
+      dy_reshape_dims[2 * i] = dy_bcast_dims[i];
+      dy_reshape_dims[2 * i + 1] = dy_dims[i];
+    }
     reduce_dims[i] = 2 * i;
   }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册