diff --git a/paddle/fluid/operators/strided_slice_op.h b/paddle/fluid/operators/strided_slice_op.h index 5baacc7ea1350b9f5b7ba81ff5a31c3e75c46853..9f7ade5ec44d9eb4106e28e70dcf1c6357740ca2 100644 --- a/paddle/fluid/operators/strided_slice_op.h +++ b/paddle/fluid/operators/strided_slice_op.h @@ -241,8 +241,6 @@ class StridedSliceKernel : public framework::OpKernel { reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false; } - framework::Tensor tmp; - auto out_dims_origin = out_dims; if (decrease_axis.size() > 0) { std::vector new_out_shape; @@ -263,21 +261,34 @@ class StridedSliceKernel : public framework::OpKernel { out_dims_origin = framework::make_ddim(new_out_shape); } - tmp.mutable_data(out_dims, context.GetPlace()); + bool need_reverse = false; + for (size_t axis = 0; axis < axes.size(); axis++) { + if (reverse_vector[axis] == 1) { + need_reverse = true; + break; + } + } + out->Resize(out_dims); out->mutable_data(context.GetPlace()); auto in_t = framework::EigenTensor::From( *in); - auto tmp_t = - framework::EigenTensor::From( - tmp); auto out_t = framework::EigenTensor::From( *out, out_dims); - tmp_t.device(place) = - in_t.stridedSlice(starts_indices, ends_indices, strides_indices); - out_t.device(place) = tmp_t.reverse(reverse_axis); + if (need_reverse) { + framework::Tensor tmp; + tmp.mutable_data(out_dims, context.GetPlace()); + auto tmp_t = framework::EigenTensor::From(tmp); + tmp_t.device(place) = + in_t.stridedSlice(starts_indices, ends_indices, strides_indices); + out_t.device(place) = tmp_t.reverse(reverse_axis); + } else { + out_t.device(place) = + in_t.stridedSlice(starts_indices, ends_indices, strides_indices); + } if (decrease_axis.size() > 0) { out->Resize(out_dims_origin); @@ -388,22 +399,33 @@ class StridedSliceGradKernel : public framework::OpKernel { reverse_axis[axis_index] = (reverse_vector[axis] == 1) ? true : false; } - framework::Tensor reverse_input; - reverse_input.mutable_data(in_dims, context.GetPlace()); - + bool need_reverse = false; + for (size_t axis = 0; axis < axes.size(); axis++) { + if (reverse_vector[axis] == 1) { + need_reverse = true; + break; + } + } auto in_t = framework::EigenTensor::From( *d_input); - auto reverse_in_t = - framework::EigenTensor::From( - reverse_input); auto out_t = framework::EigenTensor::From( *d_out, out_dims); - - reverse_in_t.device(place) = in_t.reverse(reverse_axis); - out_t.stridedSlice(starts_indices, ends_indices, strides_indices) - .device(place) = reverse_in_t; + if (need_reverse) { + framework::Tensor reverse_input; + reverse_input.mutable_data(in_dims, context.GetPlace()); + auto reverse_in_t = + framework::EigenTensor::From(reverse_input); + + reverse_in_t.device(place) = in_t.reverse(reverse_axis); + out_t.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(place) = reverse_in_t; + } else { + out_t.stridedSlice(starts_indices, ends_indices, strides_indices) + .device(place) = in_t; + } } }; } // namespace operators