提交 c1f8ade2 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4055 [LITE] fix bug: arm cpu fp32 op stride_slice infershape

Merge pull request !4055 from yangruoqi713/stride_slice
......@@ -726,29 +726,28 @@ class StridedSlice : public Primitive {
explicit StridedSlice(schema::Primitive *primitive) : Primitive(primitive) {}
const schema::StridedSlice *GetAttribute() const { return this->primitive->value_as_StridedSlice(); }
int InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::Tensor *> outputs) override;
int NDims() { return this->updated_ndim_; }
int NDims() { return this->ndim_; }
void ApplyNewAxisMask();
std::vector<int> ApplyShrinkMask(std::vector<int> out_shape);
void ApplyBeginMask();
void ApplyEndMask();
void ApplyEllipsisMask();
std::vector<int> UpdatedInShape() { return this->updated_in_shape_; }
std::vector<int> UpdatedBegins() { return this->updated_begins_; }
std::vector<int> UpdatedEnds() { return this->updated_ends_; }
std::vector<int> UpdatedStrides() { return this->updated_strides_; }
std::vector<int> GetInShape() { return this->in_shape_; }
std::vector<int> GetBegins() { return this->begins_; }
std::vector<int> GetEnds() { return this->ends_; }
std::vector<int> GetStrides() { return this->strides_; }
protected:
int updated_ndim_;
int ori_ndim_;
std::vector<int> updated_in_shape_;
std::vector<int> updated_begins_;
std::vector<int> updated_ends_;
std::vector<int> updated_strides_;
std::vector<int> begins_mask_;
std::vector<int> ends_mask_;
std::vector<int> ellipsis_mask_;
std::vector<int> new_axis_mask_;
std::vector<int> shrink_axis_mask_;
int ndim_;
std::vector<int> in_shape_;
std::vector<int> begins_;
std::vector<int> ends_;
std::vector<int> strides_;
std::vector<bool> begins_mask_;
std::vector<bool> ends_mask_;
std::vector<bool> ellipsis_mask_;
std::vector<bool> new_axis_mask_;
std::vector<bool> shrink_axis_mask_;
};
class PriorBox : public Primitive {
......
......@@ -30,14 +30,20 @@ constexpr int kStridedSliceInputNum = 1;
void StridedSlice::ApplyNewAxisMask() {
for (int i = 0; i < new_axis_mask_.size(); i++) {
if (new_axis_mask_.at(i)) {
updated_ndim_ += 1;
updated_in_shape_.insert(updated_in_shape_.begin() + i, 1);
updated_begins_.at(i) = 0;
updated_ends_.at(i) = 1;
updated_strides_.at(i) = 1;
updated_begins_.emplace_back(0);
updated_ends_.emplace_back(updated_in_shape_.at(updated_ndim_ - 1));
updated_strides_.emplace_back(1);
ndim_ += 1;
in_shape_.insert(in_shape_.begin() + i, 1);
begins_.at(i) = 0;
ends_.at(i) = 1;
strides_.at(i) = 1;
begins_.emplace_back(0);
ends_.emplace_back(in_shape_.at(ndim_ - 1));
strides_.emplace_back(1);
begins_mask_.at(i) = false;
ends_mask_.at(i) = false;
ellipsis_mask_.at(i) = false;
shrink_axis_mask_.at(i) = false;
}
}
}
......@@ -47,8 +53,8 @@ std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) {
out_shape.clear();
for (int i = 0; i < shrink_axis_mask_.size(); i++) {
if (shrink_axis_mask_.at(i)) {
updated_ends_.at(i) = updated_begins_.at(i) + 1;
updated_strides_.at(i) = 1;
ends_.at(i) = begins_.at(i) + 1;
strides_.at(i) = 1;
} else {
out_shape.emplace_back(old_out_shape.at(i));
}
......@@ -63,22 +69,26 @@ std::vector<int> StridedSlice::ApplyShrinkMask(std::vector<int> out_shape) {
void StridedSlice::ApplyEllipsisMask() {
for (int i = 0; i < ellipsis_mask_.size(); i++) {
if (ellipsis_mask_.at(i)) {
updated_begins_.at(i) = 0;
updated_ends_.at(i) = updated_in_shape_.at(i);
begins_.at(i) = 0;
ends_.at(i) = in_shape_.at(i);
break;
}
}
}
void StridedSlice::ApplyBeginMask() {
for (int i = 0; i < ori_ndim_; i++) {
updated_begins_.at(i) = 0;
for (int i = 0; i < ndim_; i++) {
if (begins_mask_.at(i)) {
begins_.at(i) = 0;
}
}
}
void StridedSlice::ApplyEndMask() {
for (int i = 0; i < ori_ndim_; i++) {
updated_ends_.at(i) = 0;
for (int i = 0; i < ndim_; i++) {
if (ends_.at(i)) {
ends_.at(i) = in_shape_.at(i);
}
}
}
......@@ -88,7 +98,7 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
MS_LOG(ERROR) << "Invalid output size:" << outputs.size();
return RET_PARAM_INVALID;
}
if (inputs.size() < kStridedSliceInputNum) {
if (inputs.size() != kStridedSliceInputNum) {
MS_LOG(ERROR) << "Invalid input size " << inputs.size();
return RET_PARAM_INVALID;
}
......@@ -97,28 +107,28 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
auto input_shape = input->shape();
std::vector<int> output_shape;
auto strided_slice_prim = this->primitive->value_as_StridedSlice();
updated_ndim_ = static_cast<int>(strided_slice_prim->begin()->size());
ori_ndim_ = updated_ndim_;
MS_ASSERT(updated_ndim_ == static_cast<int>(strided_slice_prim->end()->size()));
MS_ASSERT(updated_ndim_ == static_cast<int>(strided_slice_prim->stride()->size()));
MS_ASSERT(updated_ndim_ == static_cast<int>(input_shape.size()));
for (int i = 0; i < updated_ndim_; i++) {
updated_in_shape_.emplace_back(input_shape.at(i));
updated_begins_.emplace_back((*(strided_slice_prim->begin()))[i]);
updated_ends_.emplace_back((*(strided_slice_prim->end()))[i]);
updated_strides_.emplace_back((*(strided_slice_prim->stride()))[i]);
ndim_ = static_cast<int>(strided_slice_prim->begin()->size());
MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->end()->size()));
MS_ASSERT(ndim_ == static_cast<int>(strided_slice_prim->stride()->size()));
MS_ASSERT(ndim_ == static_cast<int>(input_shape.size()));
for (int i = 0; i < ndim_; i++) {
in_shape_.emplace_back(input_shape.at(i));
begins_.emplace_back((*(strided_slice_prim->begin()))[i]);
ends_.emplace_back((*(strided_slice_prim->end()))[i]);
strides_.emplace_back((*(strided_slice_prim->stride()))[i]);
}
// set all mask to original input shape
begins_mask_.resize(updated_ndim_);
ends_mask_.resize(updated_ndim_);
ellipsis_mask_.resize(updated_ndim_);
new_axis_mask_.resize(updated_ndim_);
shrink_axis_mask_.resize(updated_ndim_);
begins_mask_.resize(ndim_);
ends_mask_.resize(ndim_);
ellipsis_mask_.resize(ndim_);
new_axis_mask_.resize(ndim_);
shrink_axis_mask_.resize(ndim_);
// convert bit to vector
for (int i = 0; i < updated_ndim_; i++) {
for (int i = 0; i < ndim_; i++) {
begins_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->beginMask()) & (1 << i);
ends_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->endMask()) & (1 << i);
ellipsis_mask_.at(i) = static_cast<uint32_t>(strided_slice_prim->ellipsisMask()) & (1 << i);
......@@ -127,29 +137,17 @@ int StridedSlice::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<t
}
ApplyNewAxisMask();
ApplyNewAxisMask();
ApplyBeginMask();
ApplyEndMask();
ApplyEllipsisMask();
output_shape.resize(updated_in_shape_.size());
for (int i = 0; i < updated_in_shape_.size(); i++) {
if (i < ori_ndim_ && new_axis_mask_.at(i)) {
output_shape.clear();
output_shape.resize(in_shape_.size());
for (int i = 0; i < in_shape_.size(); i++) {
if (i < ndim_ && new_axis_mask_.at(i)) {
output_shape.at(i) = 1;
} else {
// begins and ends out of range handling
if (updated_begins_.at(i) >= updated_in_shape_.at(i) || updated_begins_.at(i) < -updated_in_shape_.at(i) ||
updated_ends_.at(i) < -updated_in_shape_.at(i) || updated_ends_.at(i) > updated_in_shape_.at(i)) {
return RET_PARAM_INVALID;
}
updated_begins_.at(i) = updated_begins_.at(i) % updated_in_shape_.at(i);
updated_ends_.at(i) = updated_ends_.at(i) % updated_in_shape_.at(i);
if ((updated_ends_.at(i) <= updated_begins_.at(i) && updated_strides_.at(i) > 0) ||
(updated_ends_.at(i) >= updated_begins_.at(i) && updated_strides_.at(i) < 0)) {
output_shape.at(i) = 0;
} else {
output_shape.at(i) = 1 + (updated_ends_.at(i) - updated_begins_.at(i) - 1) / updated_strides_.at(i);
}
output_shape.at(i) = (ends_.at(i) - begins_.at(i)) / strides_.at(i);
}
}
......
......@@ -1094,13 +1094,13 @@ OpParameter *PopulateStridedSliceParameter(const lite::Primitive *primitive) {
strided_slice_param->op_parameter_.type_ = primitive->Type();
auto n_dims = ((lite::StridedSlice *)primitive)->NDims();
strided_slice_param->num_axes_ = n_dims;
auto begin = ((lite::StridedSlice *)primitive)->UpdatedBegins();
auto begin = ((lite::StridedSlice *)primitive)->GetBegins();
(void)memcpy(strided_slice_param->begins_, (begin.data()), begin.size() * sizeof(int));
auto end = ((lite::StridedSlice *)primitive)->UpdatedEnds();
auto end = ((lite::StridedSlice *)primitive)->GetEnds();
(void)memcpy(strided_slice_param->ends_, (end.data()), end.size() * sizeof(int));
auto stride = ((lite::StridedSlice *)primitive)->UpdatedStrides();
auto stride = ((lite::StridedSlice *)primitive)->GetStrides();
(void)memcpy(strided_slice_param->strides_, (stride.data()), stride.size() * sizeof(int));
auto in_shape = ((lite::StridedSlice *)primitive)->UpdatedInShape();
auto in_shape = ((lite::StridedSlice *)primitive)->GetInShape();
(void)memcpy(strided_slice_param->in_shape_, (in_shape.data()), in_shape.size() * sizeof(int));
return reinterpret_cast<OpParameter *>(strided_slice_param);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册