提交 4a3a45b6 编写于 作者: W Wilber 提交者: cyj1986

modify nearest_interpolate when attr align_corners=false (bug_fix) (#1969)

* modify slice op and add slice test

* modify nearest_polate when align_corners=false (bugfix)
上级 f3124b30
......@@ -459,8 +459,10 @@ void nearest_interp(const float* src,
#pragma omp parallel for collapse(2) schedule(static)
for (int h = 0; h < h_out; ++h) {
for (int w = 0; w < w_out; ++w) {
int near_x = static_cast<int>(scale_w_new * w + 0.5);
int near_y = static_cast<int>(scale_h_new * h + 0.5);
int near_x = (with_align) ? static_cast<int>(scale_w_new * w + 0.5)
: static_cast<int>(scale_w_new * w);
int near_y = (with_align) ? static_cast<int>(scale_h_new * h + 0.5)
: static_cast<int>(scale_h_new * h);
near_x = near_x < 0 ? 0 : near_x;
near_y = near_y < 0 ? 0 : near_y;
dst[h * w_out + w] = src[near_y * w_in + near_x];
......
......@@ -51,9 +51,11 @@ void resize_nearest_align(std::vector<const lite::Tensor*> inputs,
int src_index = n * src_stride_batch + c * src_stride_c;
for (int h = 0; h < hout; ++h) {
for (int w = 0; w < wout; ++w) {
dtype fw = scale_w * w + 0.5;
int fw = (with_align) ? static_cast<int>(scale_w * w + 0.5)
: static_cast<int>(scale_w * w);
fw = (fw < 0) ? 0 : fw;
dtype fh = scale_h * h + 0.5;
int fh = (with_align) ? static_cast<int>(scale_h * h + 0.5)
: static_cast<int>(scale_h * h);
fh = (fh < 0) ? 0 : fh;
int w_start = static_cast<int>(fw);
int h_start = static_cast<int>(fh);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册