diff --git a/mindspore/lite/src/ops/ops.h b/mindspore/lite/src/ops/ops.h index d69b00ed8da911dcf071759878cf8dc750eaf77a..e1dd28980482f091a6f486496cff447ad35150ef 100644 --- a/mindspore/lite/src/ops/ops.h +++ b/mindspore/lite/src/ops/ops.h @@ -726,29 +726,28 @@ class StridedSlice : public Primitive { explicit StridedSlice(schema::Primitive *primitive) : Primitive(primitive) {} const schema::StridedSlice *GetAttribute() const { return this->primitive->value_as_StridedSlice(); } int InferShape(std::vector inputs, std::vector outputs) override; - int NDims() { return this->updated_ndim_; } + int NDims() { return this->ndim_; } void ApplyNewAxisMask(); std::vector ApplyShrinkMask(std::vector out_shape); void ApplyBeginMask(); void ApplyEndMask(); void ApplyEllipsisMask(); - std::vector UpdatedInShape() { return this->updated_in_shape_; } - std::vector UpdatedBegins() { return this->updated_begins_; } - std::vector UpdatedEnds() { return this->updated_ends_; } - std::vector UpdatedStrides() { return this->updated_strides_; } + std::vector GetInShape() { return this->in_shape_; } + std::vector GetBegins() { return this->begins_; } + std::vector GetEnds() { return this->ends_; } + std::vector GetStrides() { return this->strides_; } protected: - int updated_ndim_; - int ori_ndim_; - std::vector updated_in_shape_; - std::vector updated_begins_; - std::vector updated_ends_; - std::vector updated_strides_; - std::vector begins_mask_; - std::vector ends_mask_; - std::vector ellipsis_mask_; - std::vector new_axis_mask_; - std::vector shrink_axis_mask_; + int ndim_; + std::vector in_shape_; + std::vector begins_; + std::vector ends_; + std::vector strides_; + std::vector begins_mask_; + std::vector ends_mask_; + std::vector ellipsis_mask_; + std::vector new_axis_mask_; + std::vector shrink_axis_mask_; }; class PriorBox : public Primitive { diff --git a/mindspore/lite/src/ops/strided_slice.cc b/mindspore/lite/src/ops/strided_slice.cc index 5915de51c798b6f63e8ab0a07aaaf5787120aa74..63f4e4d0b111e833d0a5f34c4737beae9fc72c05 100644 --- a/mindspore/lite/src/ops/strided_slice.cc +++ b/mindspore/lite/src/ops/strided_slice.cc @@ -30,14 +30,20 @@ constexpr int kStridedSliceInputNum = 1; void StridedSlice::ApplyNewAxisMask() { for (int i = 0; i < new_axis_mask_.size(); i++) { if (new_axis_mask_.at(i)) { - updated_ndim_ += 1; - updated_in_shape_.insert(updated_in_shape_.begin() + i, 1); - updated_begins_.at(i) = 0; - updated_ends_.at(i) = 1; - updated_strides_.at(i) = 1; - updated_begins_.emplace_back(0); - updated_ends_.emplace_back(updated_in_shape_.at(updated_ndim_ - 1)); - updated_strides_.emplace_back(1); + ndim_ += 1; + in_shape_.insert(in_shape_.begin() + i, 1); + begins_.at(i) = 0; + ends_.at(i) = 1; + strides_.at(i) = 1; + + begins_.emplace_back(0); + ends_.emplace_back(in_shape_.at(ndim_ - 1)); + strides_.emplace_back(1); + + begins_mask_.at(i) = false; + ends_mask_.at(i) = false; + ellipsis_mask_.at(i) = false; + shrink_axis_mask_.at(i) = false; } } } @@ -47,8 +53,8 @@ std::vector StridedSlice::ApplyShrinkMask(std::vector out_shape) { out_shape.clear(); for (int i = 0; i < shrink_axis_mask_.size(); i++) { if (shrink_axis_mask_.at(i)) { - updated_ends_.at(i) = updated_begins_.at(i) + 1; - updated_strides_.at(i) = 1; + ends_.at(i) = begins_.at(i) + 1; + strides_.at(i) = 1; } else { out_shape.emplace_back(old_out_shape.at(i)); } @@ -63,22 +69,26 @@ std::vector StridedSlice::ApplyShrinkMask(std::vector out_shape) { void StridedSlice::ApplyEllipsisMask() { for (int i = 0; i < ellipsis_mask_.size(); i++) { if (ellipsis_mask_.at(i)) { - updated_begins_.at(i) = 0; - updated_ends_.at(i) = updated_in_shape_.at(i); + begins_.at(i) = 0; + ends_.at(i) = in_shape_.at(i); break; } } } void StridedSlice::ApplyBeginMask() { - for (int i = 0; i < ori_ndim_; i++) { - updated_begins_.at(i) = 0; + for (int i = 0; i < ndim_; i++) { + if (begins_mask_.at(i)) { + begins_.at(i) = 0; + } } } void StridedSlice::ApplyEndMask() { - for (int i = 0; i < ori_ndim_; i++) { - updated_ends_.at(i) = 0; + for (int i = 0; i < ndim_; i++) { + if (ends_.at(i)) { + ends_.at(i) = in_shape_.at(i); + } } } @@ -88,7 +98,7 @@ int StridedSlice::InferShape(std::vector inputs, std::vector inputs, std::vectorshape(); std::vector output_shape; auto strided_slice_prim = this->primitive->value_as_StridedSlice(); - updated_ndim_ = static_cast(strided_slice_prim->begin()->size()); - ori_ndim_ = updated_ndim_; - MS_ASSERT(updated_ndim_ == static_cast(strided_slice_prim->end()->size())); - MS_ASSERT(updated_ndim_ == static_cast(strided_slice_prim->stride()->size())); - MS_ASSERT(updated_ndim_ == static_cast(input_shape.size())); - - for (int i = 0; i < updated_ndim_; i++) { - updated_in_shape_.emplace_back(input_shape.at(i)); - updated_begins_.emplace_back((*(strided_slice_prim->begin()))[i]); - updated_ends_.emplace_back((*(strided_slice_prim->end()))[i]); - updated_strides_.emplace_back((*(strided_slice_prim->stride()))[i]); + ndim_ = static_cast(strided_slice_prim->begin()->size()); + + MS_ASSERT(ndim_ == static_cast(strided_slice_prim->end()->size())); + MS_ASSERT(ndim_ == static_cast(strided_slice_prim->stride()->size())); + MS_ASSERT(ndim_ == static_cast(input_shape.size())); + + for (int i = 0; i < ndim_; i++) { + in_shape_.emplace_back(input_shape.at(i)); + begins_.emplace_back((*(strided_slice_prim->begin()))[i]); + ends_.emplace_back((*(strided_slice_prim->end()))[i]); + strides_.emplace_back((*(strided_slice_prim->stride()))[i]); } // set all mask to original input shape - begins_mask_.resize(updated_ndim_); - ends_mask_.resize(updated_ndim_); - ellipsis_mask_.resize(updated_ndim_); - new_axis_mask_.resize(updated_ndim_); - shrink_axis_mask_.resize(updated_ndim_); + begins_mask_.resize(ndim_); + ends_mask_.resize(ndim_); + ellipsis_mask_.resize(ndim_); + new_axis_mask_.resize(ndim_); + shrink_axis_mask_.resize(ndim_); // convert bit to vector - for (int i = 0; i < updated_ndim_; i++) { + for (int i = 0; i < ndim_; i++) { begins_mask_.at(i) = static_cast(strided_slice_prim->beginMask()) & (1 << i); ends_mask_.at(i) = static_cast(strided_slice_prim->endMask()) & (1 << i); ellipsis_mask_.at(i) = static_cast(strided_slice_prim->ellipsisMask()) & (1 << i); @@ -127,29 +137,17 @@ int StridedSlice::InferShape(std::vector inputs, std::vector= updated_in_shape_.at(i) || updated_begins_.at(i) < -updated_in_shape_.at(i) || - updated_ends_.at(i) < -updated_in_shape_.at(i) || updated_ends_.at(i) > updated_in_shape_.at(i)) { - return RET_PARAM_INVALID; - } - updated_begins_.at(i) = updated_begins_.at(i) % updated_in_shape_.at(i); - updated_ends_.at(i) = updated_ends_.at(i) % updated_in_shape_.at(i); - - if ((updated_ends_.at(i) <= updated_begins_.at(i) && updated_strides_.at(i) > 0) || - (updated_ends_.at(i) >= updated_begins_.at(i) && updated_strides_.at(i) < 0)) { - output_shape.at(i) = 0; - } else { - output_shape.at(i) = 1 + (updated_ends_.at(i) - updated_begins_.at(i) - 1) / updated_strides_.at(i); - } + output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i); } } diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index f4367768c08bf381200c916ec09fa487db768a67..162a9d3275746c53f378063f2c923162f724f752 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -1094,13 +1094,13 @@ OpParameter *PopulateStridedSliceParameter(const lite::Primitive *primitive) { strided_slice_param->op_parameter_.type_ = primitive->Type(); auto n_dims = ((lite::StridedSlice *)primitive)->NDims(); strided_slice_param->num_axes_ = n_dims; - auto begin = ((lite::StridedSlice *)primitive)->UpdatedBegins(); + auto begin = ((lite::StridedSlice *)primitive)->GetBegins(); (void)memcpy(strided_slice_param->begins_, (begin.data()), begin.size() * sizeof(int)); - auto end = ((lite::StridedSlice *)primitive)->UpdatedEnds(); + auto end = ((lite::StridedSlice *)primitive)->GetEnds(); (void)memcpy(strided_slice_param->ends_, (end.data()), end.size() * sizeof(int)); - auto stride = ((lite::StridedSlice *)primitive)->UpdatedStrides(); + auto stride = ((lite::StridedSlice *)primitive)->GetStrides(); (void)memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int)); - auto in_shape = ((lite::StridedSlice *)primitive)->UpdatedInShape(); + auto in_shape = ((lite::StridedSlice *)primitive)->GetInShape(); (void)memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); return reinterpret_cast(strided_slice_param); }