提交 6468cdeb 编写于 作者: L Liu Yiqun

Optimize the InferShape of reshape and reshape2.

上级 b7f6ed3d
...@@ -102,6 +102,15 @@ class DDimLite { ...@@ -102,6 +102,15 @@ class DDimLite {
DDimLite Slice(int start, int end) const; DDimLite Slice(int start, int end) const;
bool CheckPositive() const {
for (size_t i = 0; i < size(); ++i) {
if (data_[i] <= 0) {
return false;
}
}
return true;
}
DDimLite Flatten2D(int col) const { DDimLite Flatten2D(int col) const {
return DDimLite(std::vector<value_type>( return DDimLite(std::vector<value_type>(
{Slice(0, col).production(), Slice(col, size()).production()})); {Slice(0, col).production(), Slice(col, size()).production()}));
......
...@@ -27,14 +27,15 @@ bool ReshapeOp::CheckShape() const { ...@@ -27,14 +27,15 @@ bool ReshapeOp::CheckShape() const {
} }
bool ReshapeOp::InferShape() const { bool ReshapeOp::InferShape() const {
auto shape_tensor_vct = param_.shape_tensor_vct; auto &shape_tensor_vct = param_.shape_tensor_vct;
auto *shape_tensor = param_.shape_tensor; auto *shape_tensor = param_.shape_tensor;
auto shape_vct = param_.shape_vct; auto &shape_vct = param_.shape_vct;
std::vector<int> final_shape; std::vector<int> final_shape;
if (shape_tensor_vct.size() > 0) { if (shape_tensor_vct.size() > 0) {
final_shape.resize(shape_tensor_vct.size());
for (int i = 0; i < shape_tensor_vct.size(); i++) { for (int i = 0; i < shape_tensor_vct.size(); i++) {
final_shape.push_back(shape_tensor_vct[i]->data<int>()[0]); final_shape[i] = shape_tensor_vct[i]->data<int>()[0];
} }
} else if (shape_tensor != nullptr) { } else if (shape_tensor != nullptr) {
auto *shape_tensor_data = shape_tensor->data<int>(); auto *shape_tensor_data = shape_tensor->data<int>();
...@@ -46,7 +47,7 @@ bool ReshapeOp::InferShape() const { ...@@ -46,7 +47,7 @@ bool ReshapeOp::InferShape() const {
LOG(FATAL) << "input shape error"; LOG(FATAL) << "input shape error";
} }
auto x_dims = param_.x->dims(); auto &x_dims = param_.x->dims();
auto output_dims = ValidateShape(final_shape, x_dims); auto output_dims = ValidateShape(final_shape, x_dims);
param_.output->Resize(output_dims); param_.output->Resize(output_dims);
auto out_lod = param_.output->mutable_lod(); auto out_lod = param_.output->mutable_lod();
...@@ -98,8 +99,9 @@ bool Reshape2Op::CheckShape() const { ...@@ -98,8 +99,9 @@ bool Reshape2Op::CheckShape() const {
bool Reshape2Op::InferShape() const { bool Reshape2Op::InferShape() const {
ReshapeOp::InferShape(); ReshapeOp::InferShape();
auto x_dims = param_.x->dims(); auto &x_dims = param_.x->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0); DDim xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (size_t i = 0; i < x_dims.size(); i++) { for (size_t i = 0; i < x_dims.size(); i++) {
xshape_dims[i + 1] = x_dims[i]; xshape_dims[i + 1] = x_dims[i];
} }
...@@ -117,19 +119,15 @@ bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -117,19 +119,15 @@ bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
} }
DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) { DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
const lite::DDim::value_type input_size = input_dims.production(); const DDim::value_type input_size = input_dims.production();
auto input_shape = input_dims.Vectorize();
bool all_positive = std::all_of( // Only one dimension can be set to -1, whose size will be automatically
input_shape.cbegin(), input_shape.cend(), [](lite::DDim::value_type i) {
return i > 0;
});
// only one dimension can be set to -1, whose size will be automatically
// infered. // infered.
const int unk_dim_val = -1; const int unk_dim_val = -1;
const int copy_dim_val = 0; const int copy_dim_val = 0;
std::vector<lite::DDim::value_type> output_shape(shape.size(), 0); DDim output_dims(shape.size());
lite::DDim::value_type capacity = 1; DDim::value_type capacity = 1;
int unk_dim_idx = -1; int unk_dim_idx = -1;
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == unk_dim_val) { if (shape[i] == unk_dim_val) {
...@@ -137,7 +135,7 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) { ...@@ -137,7 +135,7 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
<< "Only one input dimension of Attr(shape) can be unknown."; << "Only one input dimension of Attr(shape) can be unknown.";
unk_dim_idx = i; unk_dim_idx = i;
} else if (shape[i] == copy_dim_val) { } else if (shape[i] == copy_dim_val) {
CHECK_LT(static_cast<int>(i), input_shape.size()) CHECK_LT(static_cast<int>(i), input_dims.size())
<< "The index of dimension to copy from input shape must be less " << "The index of dimension to copy from input shape must be less "
"than the size of input shape."; "than the size of input shape.";
} else { } else {
...@@ -145,28 +143,28 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) { ...@@ -145,28 +143,28 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
"be negtive except one unknown dimension."; "be negtive except one unknown dimension.";
} }
capacity *= (shape[i] ? static_cast<lite::DDim::value_type>(shape[i]) DDim::value_type output_dim_i =
: input_shape[i]); shape[i] ? static_cast<DDim::value_type>(shape[i]) : input_dims[i];
output_shape[i] = (shape[i] ? static_cast<lite::DDim::value_type>(shape[i]) output_dims[i] = output_dim_i;
: input_shape[i]); capacity *= output_dim_i;
} }
if (unk_dim_idx != -1) { if (unk_dim_idx != -1) {
if (all_positive) { if (input_dims.CheckPositive()) {
// input_size < 0 and is un-determinate in compile time, skip the check, // input_size < 0 and is un-determinate in compile time, skip the check,
// for example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], // for example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
// capacity = -24, input_size = -8, output_shape[0] = 0 // capacity = -24, input_size = -8, output_dims[0] = 0
// the following check will fail. // the following check will fail.
output_shape[unk_dim_idx] = -input_size / capacity; output_dims[unk_dim_idx] = -input_size / capacity;
CHECK_EQ(output_shape[unk_dim_idx] * capacity, -input_size) CHECK_EQ(output_dims[unk_dim_idx] * capacity, -input_size)
<< "Invalid shape is given."; << "Invalid shape is given.";
} else { } else {
output_shape[unk_dim_idx] = -1; output_dims[unk_dim_idx] = -1;
} }
} else { } else {
CHECK_EQ(capacity, input_size) << "Invalid shape is given."; CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
} }
return lite::DDim(output_shape); return output_dims;
} }
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册