From c0f993f63b660b5316bf2e6dc4624adda918e829 Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 27 Jan 2022 13:26:57 +0800 Subject: [PATCH] Fix slice error in jit.to_static mode (#39251) * fix slice bug * fix syntax error --- paddle/fluid/operators/slice_op.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/slice_op.h b/paddle/fluid/operators/slice_op.h index 15d52880ed9..d9ef45343d8 100644 --- a/paddle/fluid/operators/slice_op.h +++ b/paddle/fluid/operators/slice_op.h @@ -43,6 +43,10 @@ inline void DealTensorArray(const framework::ExecutionContext& ctx, end = std::max(end, static_cast(0)); end = std::min(end, in_size); + if (starts[0] == -1 && end == 0) { + end = start + 1; + } + PADDLE_ENFORCE_GT(end, start, platform::errors::InvalidArgument( "Attr(ends) should be greater than attr(starts) in " @@ -330,7 +334,7 @@ class SliceGradKernel : public framework::OpKernel { auto decrease_axis = ctx.Attr>("decrease_axis"); auto decrease_size = decrease_axis.size(); if (decrease_size > 0) { - if (decrease_size == (size_t)in_dims.size()) { + if (decrease_size == static_cast(in_dims.size())) { // all dims decrease std::vector origin_out_shape(decrease_size, 1); out_dims = framework::make_ddim(std::vector(decrease_size, 1)); -- GitLab