提交 6468cdeb 编写于 作者: L Liu Yiqun

Optimize the InferShape of reshape and reshape2.

上级 b7f6ed3d
......@@ -102,6 +102,15 @@ class DDimLite {
DDimLite Slice(int start, int end) const;
bool CheckPositive() const {
for (size_t i = 0; i < size(); ++i) {
if (data_[i] <= 0) {
return false;
}
}
return true;
}
DDimLite Flatten2D(int col) const {
return DDimLite(std::vector<value_type>(
{Slice(0, col).production(), Slice(col, size()).production()}));
......
......@@ -27,14 +27,15 @@ bool ReshapeOp::CheckShape() const {
}
bool ReshapeOp::InferShape() const {
auto shape_tensor_vct = param_.shape_tensor_vct;
auto &shape_tensor_vct = param_.shape_tensor_vct;
auto *shape_tensor = param_.shape_tensor;
auto shape_vct = param_.shape_vct;
auto &shape_vct = param_.shape_vct;
std::vector<int> final_shape;
if (shape_tensor_vct.size() > 0) {
final_shape.resize(shape_tensor_vct.size());
for (int i = 0; i < shape_tensor_vct.size(); i++) {
final_shape.push_back(shape_tensor_vct[i]->data<int>()[0]);
final_shape[i] = shape_tensor_vct[i]->data<int>()[0];
}
} else if (shape_tensor != nullptr) {
auto *shape_tensor_data = shape_tensor->data<int>();
......@@ -46,7 +47,7 @@ bool ReshapeOp::InferShape() const {
LOG(FATAL) << "input shape error";
}
auto x_dims = param_.x->dims();
auto &x_dims = param_.x->dims();
auto output_dims = ValidateShape(final_shape, x_dims);
param_.output->Resize(output_dims);
auto out_lod = param_.output->mutable_lod();
......@@ -98,8 +99,9 @@ bool Reshape2Op::CheckShape() const {
bool Reshape2Op::InferShape() const {
ReshapeOp::InferShape();
auto x_dims = param_.x->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 0);
auto &x_dims = param_.x->dims();
DDim xshape_dims(x_dims.size() + 1);
xshape_dims[0] = 0;
for (size_t i = 0; i < x_dims.size(); i++) {
xshape_dims[i + 1] = x_dims[i];
}
......@@ -117,19 +119,15 @@ bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
}
DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
const lite::DDim::value_type input_size = input_dims.production();
auto input_shape = input_dims.Vectorize();
bool all_positive = std::all_of(
input_shape.cbegin(), input_shape.cend(), [](lite::DDim::value_type i) {
return i > 0;
});
// only one dimension can be set to -1, whose size will be automatically
const DDim::value_type input_size = input_dims.production();
// Only one dimension can be set to -1, whose size will be automatically
// infered.
const int unk_dim_val = -1;
const int copy_dim_val = 0;
std::vector<lite::DDim::value_type> output_shape(shape.size(), 0);
lite::DDim::value_type capacity = 1;
DDim output_dims(shape.size());
DDim::value_type capacity = 1;
int unk_dim_idx = -1;
for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] == unk_dim_val) {
......@@ -137,7 +135,7 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
<< "Only one input dimension of Attr(shape) can be unknown.";
unk_dim_idx = i;
} else if (shape[i] == copy_dim_val) {
CHECK_LT(static_cast<int>(i), input_shape.size())
CHECK_LT(static_cast<int>(i), input_dims.size())
<< "The index of dimension to copy from input shape must be less "
"than the size of input shape.";
} else {
......@@ -145,28 +143,28 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
"be negtive except one unknown dimension.";
}
capacity *= (shape[i] ? static_cast<lite::DDim::value_type>(shape[i])
: input_shape[i]);
output_shape[i] = (shape[i] ? static_cast<lite::DDim::value_type>(shape[i])
: input_shape[i]);
DDim::value_type output_dim_i =
shape[i] ? static_cast<DDim::value_type>(shape[i]) : input_dims[i];
output_dims[i] = output_dim_i;
capacity *= output_dim_i;
}
if (unk_dim_idx != -1) {
if (all_positive) {
if (input_dims.CheckPositive()) {
// input_size < 0 and is un-determinate in compile time, skip the check,
// for example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
// capacity = -24, input_size = -8, output_shape[0] = 0
// capacity = -24, input_size = -8, output_dims[0] = 0
// the following check will fail.
output_shape[unk_dim_idx] = -input_size / capacity;
CHECK_EQ(output_shape[unk_dim_idx] * capacity, -input_size)
output_dims[unk_dim_idx] = -input_size / capacity;
CHECK_EQ(output_dims[unk_dim_idx] * capacity, -input_size)
<< "Invalid shape is given.";
} else {
output_shape[unk_dim_idx] = -1;
output_dims[unk_dim_idx] = -1;
}
} else {
CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
}
return lite::DDim(output_shape);
return output_dims;
}
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册