diff --git a/lite/core/tensor.h b/lite/core/tensor.h index 5852922c01b06e253c7ca726eadd3132a53a3954..144fc540c680c217a655eeb184cb5742799ba922 100644 --- a/lite/core/tensor.h +++ b/lite/core/tensor.h @@ -102,6 +102,15 @@ class DDimLite { DDimLite Slice(int start, int end) const; + bool CheckPositive() const { + for (size_t i = 0; i < size(); ++i) { + if (data_[i] <= 0) { + return false; + } + } + return true; + } + DDimLite Flatten2D(int col) const { return DDimLite(std::vector( {Slice(0, col).production(), Slice(col, size()).production()})); diff --git a/lite/operators/reshape_op.cc b/lite/operators/reshape_op.cc index 35f38148591690b4a1e563327112807ad12e6e2b..b8ade2733c196e5ecfef161c21f011d6a893586e 100644 --- a/lite/operators/reshape_op.cc +++ b/lite/operators/reshape_op.cc @@ -27,14 +27,15 @@ bool ReshapeOp::CheckShape() const { } bool ReshapeOp::InferShape() const { - auto shape_tensor_vct = param_.shape_tensor_vct; + auto &shape_tensor_vct = param_.shape_tensor_vct; auto *shape_tensor = param_.shape_tensor; - auto shape_vct = param_.shape_vct; + auto &shape_vct = param_.shape_vct; std::vector final_shape; if (shape_tensor_vct.size() > 0) { + final_shape.resize(shape_tensor_vct.size()); for (int i = 0; i < shape_tensor_vct.size(); i++) { - final_shape.push_back(shape_tensor_vct[i]->data()[0]); + final_shape[i] = shape_tensor_vct[i]->data()[0]; } } else if (shape_tensor != nullptr) { auto *shape_tensor_data = shape_tensor->data(); @@ -46,7 +47,7 @@ bool ReshapeOp::InferShape() const { LOG(FATAL) << "input shape error"; } - auto x_dims = param_.x->dims(); + auto &x_dims = param_.x->dims(); auto output_dims = ValidateShape(final_shape, x_dims); param_.output->Resize(output_dims); auto out_lod = param_.output->mutable_lod(); @@ -98,8 +99,9 @@ bool Reshape2Op::CheckShape() const { bool Reshape2Op::InferShape() const { ReshapeOp::InferShape(); - auto x_dims = param_.x->dims(); - std::vector xshape_dims(x_dims.size() + 1, 0); + auto &x_dims = param_.x->dims(); + DDim xshape_dims(x_dims.size() + 1); + xshape_dims[0] = 0; for (size_t i = 0; i < x_dims.size(); i++) { xshape_dims[i + 1] = x_dims[i]; } @@ -117,19 +119,15 @@ bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { } DDim ValidateShape(const std::vector &shape, const DDim &input_dims) { - const lite::DDim::value_type input_size = input_dims.production(); - auto input_shape = input_dims.Vectorize(); - bool all_positive = std::all_of( - input_shape.cbegin(), input_shape.cend(), [](lite::DDim::value_type i) { - return i > 0; - }); - // only one dimension can be set to -1, whose size will be automatically + const DDim::value_type input_size = input_dims.production(); + + // Only one dimension can be set to -1, whose size will be automatically // infered. const int unk_dim_val = -1; const int copy_dim_val = 0; - std::vector output_shape(shape.size(), 0); - lite::DDim::value_type capacity = 1; + DDim output_dims(shape.size()); + DDim::value_type capacity = 1; int unk_dim_idx = -1; for (size_t i = 0; i < shape.size(); ++i) { if (shape[i] == unk_dim_val) { @@ -137,7 +135,7 @@ DDim ValidateShape(const std::vector &shape, const DDim &input_dims) { << "Only one input dimension of Attr(shape) can be unknown."; unk_dim_idx = i; } else if (shape[i] == copy_dim_val) { - CHECK_LT(static_cast(i), input_shape.size()) + CHECK_LT(static_cast(i), input_dims.size()) << "The index of dimension to copy from input shape must be less " "than the size of input shape."; } else { @@ -145,28 +143,28 @@ DDim ValidateShape(const std::vector &shape, const DDim &input_dims) { "be negtive except one unknown dimension."; } - capacity *= (shape[i] ? static_cast(shape[i]) - : input_shape[i]); - output_shape[i] = (shape[i] ? static_cast(shape[i]) - : input_shape[i]); + DDim::value_type output_dim_i = + shape[i] ? static_cast(shape[i]) : input_dims[i]; + output_dims[i] = output_dim_i; + capacity *= output_dim_i; } if (unk_dim_idx != -1) { - if (all_positive) { + if (input_dims.CheckPositive()) { // input_size < 0 and is un-determinate in compile time, skip the check, // for example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], - // capacity = -24, input_size = -8, output_shape[0] = 0 + // capacity = -24, input_size = -8, output_dims[0] = 0 // the following check will fail. - output_shape[unk_dim_idx] = -input_size / capacity; - CHECK_EQ(output_shape[unk_dim_idx] * capacity, -input_size) + output_dims[unk_dim_idx] = -input_size / capacity; + CHECK_EQ(output_dims[unk_dim_idx] * capacity, -input_size) << "Invalid shape is given."; } else { - output_shape[unk_dim_idx] = -1; + output_dims[unk_dim_idx] = -1; } } else { CHECK_EQ(capacity, input_size) << "Invalid shape is given."; } - return lite::DDim(output_shape); + return output_dims; } } // namespace operators