From eff3ee5e7459f8d2f5cb799a100063ce7cc99701 Mon Sep 17 00:00:00 2001 From: whs Date: Mon, 25 Oct 2021 10:25:26 +0800 Subject: [PATCH] Fix grid sampler while input size is [1] (#36183) --- paddle/fluid/operators/grid_sampler_op.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paddle/fluid/operators/grid_sampler_op.h b/paddle/fluid/operators/grid_sampler_op.h index b1857b49eed..da386052c7d 100644 --- a/paddle/fluid/operators/grid_sampler_op.h +++ b/paddle/fluid/operators/grid_sampler_op.h @@ -82,6 +82,9 @@ static inline void clip(const platform::CPUDeviceContext& ctx, auto grid_abs = grid_slice_t.abs(); auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; grid_slice_t.device(place) = extra.cwiseMin(double_range - extra); + if (max_val == 0) { + grid_slice_t.device(place) = grid_slice_t.constant(static_cast(0)); + } } else { auto double_range = static_cast((max_val + 1) * 2); auto grid_abs = (grid_slice_t + static_cast(0.5)).abs(); @@ -128,6 +131,9 @@ static inline void clipWithMask(const platform::CPUDeviceContext& ctx, grid_scale_t * ((is_neg == one_more_flip).template cast() - (is_neg != one_more_flip).template cast()); grid_slice_t.device(place) = extra.cwiseMin(double_range - extra); + if (max_val == 0) { + grid_slice_t.device(place) = grid_slice_t.constant(static_cast(0)); + } } else { auto double_range = static_cast((max_val + 1) * 2); auto grid_abs = (grid_slice_t + static_cast(0.5)).abs(); -- GitLab