未验证 提交 daf5aa9b 编写于 作者: W whs 提交者: GitHub

Fix round in grid sample op (#27657)

上级 3ccee082
......@@ -238,9 +238,8 @@ __global__ void grid_sample_cuda_kernel(const int nthreads, int n, int out_c,
}
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(round(ix));
int iy_nearest = static_cast<int>(round(iy));
int ix_nearest = static_cast<int>(std::nearbyint(ix));
int iy_nearest = static_cast<int>(std::nearbyint(iy));
auto inp_offset_NC = n * inp_sN;
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c;
......@@ -403,8 +402,8 @@ __global__ void grid_sampler_cuda_backward_kernel(
gGrid_ptr_NHW[1] = giy_mult * giy;
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(::round(ix));
int iy_nearest = static_cast<int>(::round(iy));
int ix_nearest = static_cast<int>(std::nearbyint(ix));
int iy_nearest = static_cast<int>(std::nearbyint(iy));
int gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW;
T* gInp_ptr_NC = grad_input + n * inp_sN;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册