提交 c1f8ade2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4055 [LITE] fix bug: arm cpu fp32 op stride_slice infershape

Merge pull request !4055 from yangruoqi713/stride_slice
...@@ -726,29 +726,28 @@ class StridedSlice : public Primitive { ...@@ -726,29 +726,28 @@ class StridedSlice : public Primitive {
explicit StridedSlice(schema::Primitive *primitive) : Primitive(primitive) {} explicit StridedSlice(schema::Primitive *primitive) : Primitive(primitive) {}
const schema::StridedSlice *GetAttribute() const { return this->primitive->value_as_StridedSlice(); } const schema::StridedSlice *GetAttribute() const { return this->primitive->value_as_StridedSlice(); }
int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override; int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override;
int NDims() { return this->updated_ndim_; } int NDims() { return this->ndim_; }
void ApplyNewAxisMask(); void ApplyNewAxisMask();
std::vector<int> ApplyShrinkMask(std::vector<int> out_shape); std::vector<int> ApplyShrinkMask(std::vector<int> out_shape);
void ApplyBeginMask(); void ApplyBeginMask();
void ApplyEndMask(); void ApplyEndMask();
void ApplyEllipsisMask(); void ApplyEllipsisMask();
std::vector<int> UpdatedInShape() { return this->updated_in_shape_; } std::vector<int> GetInShape() { return this->in_shape_; }
std::vector<int> UpdatedBegins() { return this->updated_begins_; } std::vector<int> GetBegins() { return this->begins_; }
std::vector<int> UpdatedEnds() { return this->updated_ends_; } std::vector<int> GetEnds() { return this->ends_; }
std::vector<int> UpdatedStrides() { return this->updated_strides_; } std::vector<int> GetStrides() { return this->strides_; }
protected: protected:
int updated_ndim_; int ndim_;
int ori_ndim_; std::vector<int> in_shape_;
std::vector<int> updated_in_shape_; std::vector<int> begins_;
std::vector<int> updated_begins_; std::vector<int> ends_;
std::vector<int> updated_ends_; std::vector<int> strides_;
std::vector<int> updated_strides_; std::vector<bool> begins_mask_;
std::vector<int> begins_mask_; std::vector<bool> ends_mask_;
std::vector<int> ends_mask_; std::vector<bool> ellipsis_mask_;
std::vector<int> ellipsis_mask_; std::vector<bool> new_axis_mask_;
std::vector<int> new_axis_mask_; std::vector<bool> shrink_axis_mask_;
std::vector<int> shrink_axis_mask_;
}; };
class PriorBox : public Primitive { class PriorBox : public Primitive {
......
...@@ -30,14 +30,20 @@ constexpr int kStridedSliceInputNum = 1; ...@@ -30,14 +30,20 @@ constexpr int kStridedSliceInputNum = 1;
void StridedSlice::ApplyNewAxisMask() { void StridedSlice::ApplyNewAxisMask() {
for (int i = 0; i < new_axis_mask_.size(); i++) { for (int i = 0; i < new_axis_mask_.size(); i++) {
if (new_axis_mask_.at(i)) { if (new_axis_mask_.at(i)) {
updated_ndim_ += 1; ndim_ += 1;
updated_in_shape_.insert(updated_in_shape_.begin() + i, 1); in_shape_.insert(in_shape_.begin() + i, 1);
updated_begins_.at(i) = 0; begins_.at(i) = 0;
updated_ends_.at(i) = 1; ends_.at(i) = 1;
updated_strides_.at(i) = 1; strides_.at(i) = 1;
updated_begins_.emplace_back(0);
updated_ends_.emplace_back(updated_in_shape_.at(updated_ndim_ - 1)); begins_.emplace_back(0);
updated_strides_.emplace_back(1); ends_.emplace_back(in_shape_.at(ndim_ - 1));
strides_.emplace_back(1);
begins_mask_.at(i) = false;
ends_mask_.at(i) = false;
ellipsis_mask_.at(i) = false;
shrink_axis_mask_.at(i) = false;
} }
} }
} }
...@@ -47,8 +53,8 @@ std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) { ...@@ -47,8 +53,8 @@ std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) {
out_shape.clear(); out_shape.clear();
for (int i = 0; i < shrink_axis_mask_.size(); i++) { for (int i = 0; i < shrink_axis_mask_.size(); i++) {
if (shrink_axis_mask_.at(i)) { if (shrink_axis_mask_.at(i)) {
updated_ends_.at(i) = updated_begins_.at(i) + 1; ends_.at(i) = begins_.at(i) + 1;
updated_strides_.at(i) = 1; strides_.at(i) = 1;
} else { } else {
out_shape.emplace_back(old_out_shape.at(i)); out_shape.emplace_back(old_out_shape.at(i));
} }
...@@ -63,22 +69,26 @@ std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) { ...@@ -63,22 +69,26 @@ std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) {
void StridedSlice::ApplyEllipsisMask() { void StridedSlice::ApplyEllipsisMask() {
for (int i = 0; i < ellipsis_mask_.size(); i++) { for (int i = 0; i < ellipsis_mask_.size(); i++) {
if (ellipsis_mask_.at(i)) { if (ellipsis_mask_.at(i)) {
updated_begins_.at(i) = 0; begins_.at(i) = 0;
updated_ends_.at(i) = updated_in_shape_.at(i); ends_.at(i) = in_shape_.at(i);
break; break;
} }
} }
} }
void StridedSlice::ApplyBeginMask() { void StridedSlice::ApplyBeginMask() {
for (int i = 0; i < ori_ndim_; i++) { for (int i = 0; i < ndim_; i++) {
updated_begins_.at(i) = 0; if (begins_mask_.at(i)) {
begins_.at(i) = 0;
}
} }
} }
void StridedSlice::ApplyEndMask() { void StridedSlice::ApplyEndMask() {
for (int i = 0; i < ori_ndim_; i++) { for (int i = 0; i < ndim_; i++) {
updated_ends_.at(i) = 0; if (ends_.at(i)) {
ends_.at(i) = in_shape_.at(i);
}
} }
} }
...@@ -88,7 +98,7 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t ...@@ -88,7 +98,7 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
MS_LOG(ERROR) << "Invalid output size:" << outputs.size(); MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
if (inputs.size() < kStridedSliceInputNum) { if (inputs.size() != kStridedSliceInputNum) {
MS_LOG(ERROR) << "Invalid input size " << inputs.size(); MS_LOG(ERROR) << "Invalid input size " << inputs.size();
return RET_PARAM_INVALID; return RET_PARAM_INVALID;
} }
...@@ -97,28 +107,28 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t ...@@ -97,28 +107,28 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
auto input_shape = input->shape(); auto input_shape = input->shape();
std::vector<int> output_shape; std::vector<int> output_shape;
auto strided_slice_prim = this->primitive->value_as_StridedSlice(); auto strided_slice_prim = this->primitive->value_as_StridedSlice();
updated_ndim_ = static_cast<int>(strided_slice_prim->begin()->size()); ndim_ = static_cast<int>(strided_slice_prim->begin()->size());
ori_ndim_ = updated_ndim_;
MS_ASSERT(updated_ndim_ == static_cast<int>(strided_slice_prim->end()->size())); MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->end()->size()));
MS_ASSERT(updated_ndim_ == static_cast<int>(strided_slice_prim->stride()->size())); MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->stride()->size()));
MS_ASSERT(updated_ndim_ == static_cast<int>(input_shape.size())); MS_ASSERT(ndim_ == static_cast<int>(input_shape.size()));
for (int i = 0; i < updated_ndim_; i++) { for (int i = 0; i < ndim_; i++) {
updated_in_shape_.emplace_back(input_shape.at(i)); in_shape_.emplace_back(input_shape.at(i));
updated_begins_.emplace_back((*(strided_slice_prim->begin()))[i]); begins_.emplace_back((*(strided_slice_prim->begin()))[i]);
updated_ends_.emplace_back((*(strided_slice_prim->end()))[i]); ends_.emplace_back((*(strided_slice_prim->end()))[i]);
updated_strides_.emplace_back((*(strided_slice_prim->stride()))[i]); strides_.emplace_back((*(strided_slice_prim->stride()))[i]);
} }
// set all mask to original input shape // set all mask to original input shape
begins_mask_.resize(updated_ndim_); begins_mask_.resize(ndim_);
ends_mask_.resize(updated_ndim_); ends_mask_.resize(ndim_);
ellipsis_mask_.resize(updated_ndim_); ellipsis_mask_.resize(ndim_);
new_axis_mask_.resize(updated_ndim_); new_axis_mask_.resize(ndim_);
shrink_axis_mask_.resize(updated_ndim_); shrink_axis_mask_.resize(ndim_);
// convert bit to vector // convert bit to vector
for (int i = 0; i < updated_ndim_; i++) { for (int i = 0; i < ndim_; i++) {
begins_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->beginMask()) & (1 << i); begins_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->beginMask()) & (1 << i);
ends_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->endMask()) & (1 << i); ends_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->endMask()) & (1 << i);
ellipsis_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->ellipsisMask()) & (1 << i); ellipsis_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->ellipsisMask()) & (1 << i);
...@@ -127,29 +137,17 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t ...@@ -127,29 +137,17 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
} }
ApplyNewAxisMask(); ApplyNewAxisMask();
ApplyNewAxisMask(); ApplyBeginMask();
ApplyEndMask(); ApplyEndMask();
ApplyEllipsisMask(); ApplyEllipsisMask();
output_shape.resize(updated_in_shape_.size()); output_shape.clear();
for (int i = 0; i < updated_in_shape_.size(); i++) { output_shape.resize(in_shape_.size());
if (i < ori_ndim_ && new_axis_mask_.at(i)) { for (int i = 0; i < in_shape_.size(); i++) {
if (i < ndim_ && new_axis_mask_.at(i)) {
output_shape.at(i) = 1; output_shape.at(i) = 1;
} else { } else {
// begins and ends out of range handling output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
if (updated_begins_.at(i) >= updated_in_shape_.at(i) || updated_begins_.at(i) < -updated_in_shape_.at(i) ||
updated_ends_.at(i) < -updated_in_shape_.at(i) || updated_ends_.at(i) > updated_in_shape_.at(i)) {
return RET_PARAM_INVALID;
}
updated_begins_.at(i) = updated_begins_.at(i) % updated_in_shape_.at(i);
updated_ends_.at(i) = updated_ends_.at(i) % updated_in_shape_.at(i);
if ((updated_ends_.at(i) <= updated_begins_.at(i) && updated_strides_.at(i) > 0) ||
(updated_ends_.at(i) >= updated_begins_.at(i) && updated_strides_.at(i) < 0)) {
output_shape.at(i) = 0;
} else {
output_shape.at(i) = 1 + (updated_ends_.at(i) - updated_begins_.at(i) - 1) / updated_strides_.at(i);
}
} }
} }
......
...@@ -1094,13 +1094,13 @@ OpParameter *PopulateStridedSliceParameter(const lite::Primitive *primitive) { ...@@ -1094,13 +1094,13 @@ OpParameter *PopulateStridedSliceParameter(const lite::Primitive *primitive) {
strided_slice_param->op_parameter_.type_ = primitive->Type(); strided_slice_param->op_parameter_.type_ = primitive->Type();
auto n_dims = ((lite::StridedSlice *)primitive)->NDims(); auto n_dims = ((lite::StridedSlice *)primitive)->NDims();
strided_slice_param->num_axes_ = n_dims; strided_slice_param->num_axes_ = n_dims;
auto begin = ((lite::StridedSlice *)primitive)->UpdatedBegins(); auto begin = ((lite::StridedSlice *)primitive)->GetBegins();
(void)memcpy(strided_slice_param->begins_, (begin.data()), begin.size() * sizeof(int)); (void)memcpy(strided_slice_param->begins_, (begin.data()), begin.size() * sizeof(int));
auto end = ((lite::StridedSlice *)primitive)->UpdatedEnds(); auto end = ((lite::StridedSlice *)primitive)->GetEnds();
(void)memcpy(strided_slice_param->ends_, (end.data()), end.size() * sizeof(int)); (void)memcpy(strided_slice_param->ends_, (end.data()), end.size() * sizeof(int));
auto stride = ((lite::StridedSlice *)primitive)->UpdatedStrides(); auto stride = ((lite::StridedSlice *)primitive)->GetStrides();
(void)memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int)); (void)memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int));
auto in_shape = ((lite::StridedSlice *)primitive)->UpdatedInShape(); auto in_shape = ((lite::StridedSlice *)primitive)->GetInShape();
(void)memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int)); (void)memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(strided_slice_param); return reinterpret_cast<OpParameter *>(strided_slice_param);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册