diff --git a/paddle/fluid/operators/grid_sampler_op.h b/paddle/fluid/operators/grid_sampler_op.h index b1857b49eede0db931479fdf75e3a407558df1c0..da386052c7dc01f405f5030922f3801bd998ce62 100644 --- a/paddle/fluid/operators/grid_sampler_op.h +++ b/paddle/fluid/operators/grid_sampler_op.h @@ -82,6 +82,9 @@ static inline void clip(const platform::CPUDeviceContext& ctx, auto grid_abs = grid_slice_t.abs(); auto extra = grid_abs - (grid_abs / double_range).floor() * double_range; grid_slice_t.device(place) = extra.cwiseMin(double_range - extra); + if (max_val == 0) { + grid_slice_t.device(place) = grid_slice_t.constant(static_cast(0)); + } } else { auto double_range = static_cast((max_val + 1) * 2); auto grid_abs = (grid_slice_t + static_cast(0.5)).abs(); @@ -128,6 +131,9 @@ static inline void clipWithMask(const platform::CPUDeviceContext& ctx, grid_scale_t * ((is_neg == one_more_flip).template cast() - (is_neg != one_more_flip).template cast()); grid_slice_t.device(place) = extra.cwiseMin(double_range - extra); + if (max_val == 0) { + grid_slice_t.device(place) = grid_slice_t.constant(static_cast(0)); + } } else { auto double_range = static_cast((max_val + 1) * 2); auto grid_abs = (grid_slice_t + static_cast(0.5)).abs();