From 668db938508078e19cde1d39c7694f7b162148ad Mon Sep 17 00:00:00 2001 From: whs Date: Mon, 25 Oct 2021 16:21:37 +0800 Subject: [PATCH] [cherry-pick]Fix grid sampler (#36625) * Fix grid sampler * Fix code format --- paddle/fluid/operators/grid_sampler_op.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paddle/fluid/operators/grid_sampler_op.h b/paddle/fluid/operators/grid_sampler_op.h index b1857b49eed..da386052c7d 100644 --- a/paddle/fluid/operators/grid_sampler_op.h +++ b/paddle/fluid/operators/grid_sampler_op.h @@ -82,6 +82,9 @@ static inline void clip(const platform::CPUDeviceContext& ctx, auto grid_abs = grid_slice_t.abs(); auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; grid_slice_t.device(place) = extra.cwiseMin(double_range - extra); + if (max_val == 0) { + grid_slice_t.device(place) = grid_slice_t.constant(static_cast(0)); + } } else { auto double_range = static_cast((max_val + 1) * 2); auto grid_abs = (grid_slice_t + static_cast(0.5)).abs(); @@ -128,6 +131,9 @@ static inline void clipWithMask(const platform::CPUDeviceContext& ctx, grid_scale_t * ((is_neg == one_more_flip).template cast() - (is_neg != one_more_flip).template cast()); grid_slice_t.device(place) = extra.cwiseMin(double_range - extra); + if (max_val == 0) { + grid_slice_t.device(place) = grid_slice_t.constant(static_cast(0)); + } } else { auto double_range = static_cast((max_val + 1) * 2); auto grid_abs = (grid_slice_t + static_cast(0.5)).abs(); -- GitLab