未验证 提交 0b388226 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #14345 from heavengate/fix_grid_sampler

fix #14344 : win compile error, EigenTenor * float unsupport. test=develop
...@@ -63,12 +63,19 @@ static void CalcGridLocations(const platform::CPUDeviceContext& ctx, ...@@ -63,12 +63,19 @@ static void CalcGridLocations(const platform::CPUDeviceContext& ctx,
Tensor ones; Tensor ones;
ones.mutable_data<T>({n, h, w}, ctx.GetPlace()); ones.mutable_data<T>({n, h, w}, ctx.GetPlace());
auto ones_t = EigenTensor<T, 3>::From(ones).setConstant(1.0); auto ones_t = EigenTensor<T, 3>::From(ones).setConstant(1.0);
Tensor half_xmax, half_ymax;
half_xmax.mutable_data<T>({n, h, w}, ctx.GetPlace());
auto half_xmax_t =
EigenTensor<T, 3>::From(half_xmax).setConstant(0.5 * x_max);
half_ymax.mutable_data<T>({n, h, w}, ctx.GetPlace());
auto half_ymax_t =
EigenTensor<T, 3>::From(half_ymax).setConstant(0.5 * y_max);
// scale grid to [0, h-1/w-1] // scale grid to [0, h-1/w-1]
auto grid_x_t = EigenTensor<T, 3>::From(grid_x); auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(grid_y); auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
grid_x_t.device(place) = 0.5 * ((grid_x_t + ones_t) * x_max); grid_x_t.device(place) = (grid_x_t + ones_t) * half_xmax_t;
grid_y_t.device(place) = 0.5 * ((grid_y_t + ones_t) * y_max); grid_y_t.device(place) = (grid_y_t + ones_t) * half_ymax_t;
// calculate coords of 4 corner points // calculate coords of 4 corner points
x_w->mutable_data<T>({n, h, w}, ctx.GetPlace()); x_w->mutable_data<T>({n, h, w}, ctx.GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册