From 800f5ce6042fb55fafba206a59854bd7c81e324f Mon Sep 17 00:00:00 2001
From: Yiqun Liu
Date: Tue, 11 Feb 2020 15:59:44 +0800
Subject: [PATCH] Optimize the InferShape of several operators. (#2839)

* Optimize the InferShape of several operators.

test=develop

* Remove the new function, resize and CheckPositive in DDim.

test=develop

* Fix a bug in fc_op's InferShape.

test=develop
---
 lite/core/tensor.cc                   | 22 +++++-----
 lite/core/tensor.h                    |  8 +++-
 lite/kernels/x86/fc_compute.h         | 34 +++-------
 lite/kernels/x86/reshape_compute.h    |  3 +-
 lite/operators/concat_op.cc           | 18 ++++----
 lite/operators/fc_op.cc               | 16 ++++---
 lite/operators/gru_op.cc              | 18 ++++----
 lite/operators/lookup_table_op.cc     | 23 ++++------
 lite/operators/reduce_ops.cc          | 32 ++++++++------
 lite/operators/reshape_op.cc          | 62 +++++++++++++++------------
 lite/operators/reshape_op.h           |  3 +-
 lite/tests/kernels/fc_compute_test.cc |  1 -
 12 files changed, 115 insertions(+), 125 deletions(-)

diff --git a/lite/core/tensor.cc b/lite/core/tensor.cc
index ecfdcf3d11..38a6be6767 100644
--- a/lite/core/tensor.cc
+++ b/lite/core/tensor.cc
@@ -25,21 +25,17 @@ using value_type = int64_t;
 
 value_type DDimLite::production() const {
   value_type res = 1;
-  for (size_t i = 0; i < this->size(); i++) {
-    res *= (*this)[i];
+  for (size_t i = 0; i < data_.size(); i++) {
+    res *= data_[i];
   }
   return res;
 }
 
 value_type DDimLite::count(int start, int end) const {
-  if (start < 0) {
-    start = 0;
-  }
-  if (end > size()) {
-    end = size();
-  }
+  start = std::max(start, 0);
+  end = std::min(end, static_cast<int>(data_.size()));
   if (end < start) {
-    end = start;
+    return 0;
   }
   value_type sum = 1;
   for (auto i = start; i < end; ++i) {
@@ -49,11 +45,13 @@ value_type DDimLite::count(int start, int end) const {
 }
 
 DDimLite DDimLite::Slice(int start, int end) const {
-  std::vector<value_type> vec;
+  start = std::max(start, 0);
+  end = std::min(end, static_cast<int>(data_.size()));
+  std::vector<value_type> new_dim(end - start);
   for (int i = start; i < end; i++) {
-    vec.push_back((*this)[i]);
+    new_dim[i - start] = data_[i];
   }
-  return DDimLite(vec);
+  return DDim(new_dim);
 }
 
 std::string DDimLite::repr() const {
diff --git a/lite/core/tensor.h b/lite/core/tensor.h
index 3e334048fa..04e540002b 100644
--- a/lite/core/tensor.h
+++ b/lite/core/tensor.h
@@ -85,7 +85,11 @@ class DDimLite {
   }
 
   friend bool operator!=(const DDimLite &a, const DDimLite &b) {
-    return !(a == b);
+    if (a.size() != b.size()) return true;
+    for (size_t i = 0; i < a.size(); i++) {
+      if (a[i] != b[i]) return true;
+    }
+    return false;
   }
 
  private:
@@ -118,7 +122,7 @@ class TensorLite {
   }
 
   void Resize(const DDimLite &ddim) { dims_ = ddim; }
-  void Resize(const std::vector<int64_t> &x) { dims_ = DDimLite(x); }
+  void Resize(const std::vector<int64_t> &x) { dims_.ConstructFrom(x); }
 
   const DDimLite &dims() const { return dims_; }
   int64_t numel() const { return dims_.production(); }
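Note: a minimal standalone sketch of the clamping behavior the DDimLite changes above introduce — both bounds of count() are clamped to the valid range, and an empty range now yields 0 instead of 1. The function and variable names below are illustrative, not part of the patch:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Product of dims[start, end), clamping both bounds; an empty range
    // yields 0, matching the new DDimLite::count() above.
    int64_t CountDims(const std::vector<int64_t>& dims, int start, int end) {
      start = std::max(start, 0);
      end = std::min(end, static_cast<int>(dims.size()));
      if (end < start) return 0;
      int64_t product = 1;
      for (int i = start; i < end; ++i) product *= dims[i];
      return product;
    }

    int main() {
      std::vector<int64_t> dims{2, 3, 4};
      assert(CountDims(dims, 0, 2) == 6);     // 2 * 3
      assert(CountDims(dims, -5, 99) == 24);  // clamped to [0, 3)
      assert(CountDims(dims, 2, 1) == 0);     // empty range
      return 0;
    }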
diff --git a/lite/kernels/x86/fc_compute.h b/lite/kernels/x86/fc_compute.h
index 971f5dfa2f..e719b8d221 100644
--- a/lite/kernels/x86/fc_compute.h
+++ b/lite/kernels/x86/fc_compute.h
@@ -31,20 +31,6 @@ namespace lite {
 namespace kernels {
 namespace x86 {
 
-inline void FCOutputSize(const lite::DDim& in_dims,
-                         const lite::DDim& w_dims,
-                         std::vector<int64_t>& out_dims,  // NOLINT
-                         int in_num_col_dims,
-                         bool padding_weights) {
-  auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
-
-  out_dims.reserve(static_cast<size_t>(in_num_col_dims + 1));
-  for (int i = 0; i < in_num_col_dims; ++i) {
-    out_dims.push_back(in_dims[i]);
-  }
-  out_dims.push_back(w_dims1);
-}
-
 template <typename T>
 class FCFunctor {
  public:
@@ -84,11 +70,11 @@ class FCFunctor {
     // NOTE: here need to mutable_data for temporary Tensor X1 and Y1,
     // the overhead is unmeasured.
     lite::Tensor X1;
-    X1.Resize({M * KK});
+    X1.Resize(std::vector<int64_t>{M * KK});
     T* X1_data = X1.mutable_data<T>();
 
     lite::Tensor Y1;
-    Y1.Resize({M * (N + 4)});
+    Y1.Resize(std::vector<int64_t>{M * NN});
     Y1_data = Y1.mutable_data<T>();
 
     auto parallel_memcpy_x = [&](int64_t begin, int64_t end) {
@@ -115,7 +101,7 @@ class FCFunctor {
     if (!B) {
       auto parallel_memcpy_y = [&](int64_t begin, int64_t end) {
         for (int64_t i = begin; i < end; i++) {
-          memcpy(Y + i * N, Y1_data + i * (N + 4), N * sizeof(T));
+          memcpy(Y + i * N, Y1_data + i * NN, N * sizeof(T));
         }
       };
       lite::x86::RunParallelFor(0, M, parallel_memcpy_y);
@@ -145,22 +131,14 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
     auto* w = param.w;
    auto* bias = param.bias;
     auto* output = param.output;
-    int in_num_col_dims = param.in_num_col_dims;
     bool with_relu = (param.activation_type == "relu") ? true : false;
 
-    auto w_dims = w->dims();
     bool padding_weights = param.padding_weights;
-
-    std::vector<int64_t> output_dims;
-    FCOutputSize(
-        input->dims(), w_dims, output_dims, in_num_col_dims, padding_weights);
-    output->Resize(output_dims);
-    output->set_lod(input->lod());
-
-    auto out_dims = output->dims();
+    const auto& w_dims = w->dims();
     auto w_dims0 = padding_weights ? w_dims[0] - 4 : w_dims[0];
     auto w_dims1 = padding_weights ? w_dims[1] - 4 : w_dims[1];
-    int M = out_dims.production() / w_dims1;
+
+    int M = output->dims().production() / w_dims1;
 
     const T* input_data = input->data<T>();
     const T* w_data = w->data<T>();
diff --git a/lite/kernels/x86/reshape_compute.h b/lite/kernels/x86/reshape_compute.h
index 948c4ec31d..b06eb6eb67 100644
--- a/lite/kernels/x86/reshape_compute.h
+++ b/lite/kernels/x86/reshape_compute.h
@@ -28,8 +28,9 @@ namespace x86 {
 
 template <typename T>
 void Compute(const lite::Tensor* in, lite::Tensor* out) {
+  // In CopyDataFrom, the target tensor's dims will be set to the source
+  // tensor's dims.
   auto out_dims = out->dims();
-  auto in_dims = in->dims();
   out->CopyDataFrom(*in);
   out->Resize(out_dims);
 }
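The reshape_compute.h change relies on a subtlety worth illustrating: CopyDataFrom() overwrites the destination's dims with the source's, so the previously inferred output dims are saved and restored around the copy. A toy analogue — the Tensor struct below is a stand-in, not the lite API:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    struct Tensor {
      std::vector<int64_t> dims;
      std::vector<float> data;
      void CopyDataFrom(const Tensor& other) {
        dims = other.dims;  // dims travel with the data
        data = other.data;
      }
    };

    int main() {
      Tensor in{{2, 8}, std::vector<float>(16, 1.0f)};
      Tensor out{{4, 4}, {}};       // dims already set by InferShape
      auto saved_dims = out.dims;   // save
      out.CopyDataFrom(in);         // clobbers dims with {2, 8}
      out.dims = saved_dims;        // restore {4, 4}
      assert(out.dims == (std::vector<int64_t>{4, 4}));
      return 0;
    }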
diff --git a/lite/operators/concat_op.cc b/lite/operators/concat_op.cc
index 1941a88bbf..b2f7438b64 100644
--- a/lite/operators/concat_op.cc
+++ b/lite/operators/concat_op.cc
@@ -27,11 +27,8 @@ bool ConcatOpLite::CheckShape() const {
 }
 
 bool ConcatOpLite::InferShape() const {
-  std::vector<lite::DDim> input_dims;
-  for (auto p : param_.x) {
-    input_dims.push_back(p->dims());
-  }
-  const size_t n = input_dims.size();
+  const std::vector<lite::Tensor *> &inputs = param_.x;
+  const size_t n = inputs.size();
   CHECK_GT_OR_FALSE(n, 0);
 
   int axis = 0;
@@ -42,17 +39,18 @@ bool ConcatOpLite::InferShape() const {
     axis = axis_tensor_val[0];
   }
   if (axis < 0) {
-    axis += input_dims[0].size();
+    axis += inputs[0]->dims().size();
   }
 
-  auto &out_dims = input_dims[0];
+  auto out_dims = inputs[0]->dims();
   size_t in_zero_dims_size = out_dims.size();
   for (size_t i = 1; i < n; i++) {
+    const auto &input_dims_i = inputs[i]->dims();
     for (size_t j = 0; j < in_zero_dims_size; j++) {
       if (j == static_cast<size_t>(axis)) {
-        out_dims[axis] += input_dims[i][j];
+        out_dims[axis] += input_dims_i[j];
       } else {
-        CHECK_EQ_OR_FALSE(out_dims[j], input_dims[i][j]);
+        CHECK_EQ_OR_FALSE(out_dims[j], input_dims_i[j]);
       }
     }
   }
@@ -60,7 +58,7 @@ bool ConcatOpLite::InferShape() const {
     out_dims[axis] = -1;
   }
   // Set output dims
-  param_.output->Resize(lite::DDim(out_dims));
+  param_.output->Resize(out_dims);
   auto out_lod = param_.output->mutable_lod();
   *out_lod = param_.x[0]->lod();
   return true;
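As a worked example of the concat rule above: inputs of dims {2, 3, 4} and {2, 5, 4} concatenated along axis 1 give {2, 8, 4}; all non-axis dims must match. A self-contained sketch, with std::vector standing in for DDimLite:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Sum the extents along `axis`; require all other dims to be equal.
    std::vector<int64_t> ConcatOutDims(
        const std::vector<std::vector<int64_t>>& in_dims, size_t axis) {
      std::vector<int64_t> out = in_dims[0];
      for (size_t i = 1; i < in_dims.size(); ++i) {
        for (size_t j = 0; j < out.size(); ++j) {
          if (j == axis) {
            out[j] += in_dims[i][j];
          } else {
            assert(out[j] == in_dims[i][j]);
          }
        }
      }
      return out;
    }

    int main() {
      assert((ConcatOutDims({{2, 3, 4}, {2, 5, 4}}, 1) ==
              std::vector<int64_t>{2, 8, 4}));
      return 0;
    }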
diff --git a/lite/operators/fc_op.cc b/lite/operators/fc_op.cc
index 702950ae18..eff9300fea 100644
--- a/lite/operators/fc_op.cc
+++ b/lite/operators/fc_op.cc
@@ -49,23 +49,25 @@ bool FcOpLite::CheckShape() const {
 }
 
 bool FcOpLite::InferShape() const {
-  const auto input_dims = param_.input->dims();
-  const auto w_dims = param_.w->dims();
+  const auto& input_dims = param_.input->dims();
+  const auto& w_dims = param_.w->dims();
+  int in_num_col_dims = param_.in_num_col_dims;
+  int64_t w_dims_1 = param_.padding_weights ? w_dims[1] - 4 : w_dims[1];
 
   // Set output dims
-  std::vector<int64_t> output_dims(param_.in_num_col_dims + 1, 0);
-  for (int i = 0; i < param_.in_num_col_dims; ++i) {
+  std::vector<int64_t> output_dims(in_num_col_dims + 1);
+  for (int i = 0; i < in_num_col_dims; ++i) {
     output_dims[i] = input_dims[i];
   }
-  output_dims.back() = w_dims[1];
-  param_.output->Resize(lite::DDim(output_dims));
+  output_dims[in_num_col_dims] = w_dims_1;
+  param_.output->Resize(output_dims);
 
   // share LoD
   param_.output->set_lod(param_.input->lod());
   return true;
 }
 
-bool FcOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
+bool FcOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
   auto input = op_desc.Input("Input").front();
   auto W = op_desc.Input("W").front();
   auto out = op_desc.Output("Out").front();
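The fc_op.cc fix matters when padding_weights is set: the output's last dim is now w_dims[1] - 4 instead of w_dims[1]. A sketch of the rule, with illustrative names rather than the patch's code:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Output dims: input_dims[0 .. in_num_col_dims) followed by W's second
    // dim, shrunk by 4 when the weights carry padding.
    std::vector<int64_t> FcOutDims(const std::vector<int64_t>& input_dims,
                                   const std::vector<int64_t>& w_dims,
                                   int in_num_col_dims,
                                   bool padding_weights) {
      std::vector<int64_t> out(in_num_col_dims + 1);
      for (int i = 0; i < in_num_col_dims; ++i) out[i] = input_dims[i];
      out[in_num_col_dims] = padding_weights ? w_dims[1] - 4 : w_dims[1];
      return out;
    }

    int main() {
      // x: [8, 16, 32] flattened at in_num_col_dims = 2, W: [32, 64].
      assert((FcOutDims({8, 16, 32}, {32, 64}, 2, false) ==
              std::vector<int64_t>{8, 16, 64}));
      // Padded weights [36, 68] yield the same logical output width.
      assert((FcOutDims({8, 16, 32}, {36, 68}, 2, true) ==
              std::vector<int64_t>{8, 16, 64}));
      return 0;
    }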
diff --git a/lite/operators/gru_op.cc b/lite/operators/gru_op.cc
index 3ddeb5b734..eb97d65a1a 100644
--- a/lite/operators/gru_op.cc
+++ b/lite/operators/gru_op.cc
@@ -28,8 +28,8 @@ bool GRUOpLite::CheckShape() const {
   CHECK_OR_FALSE(param_.batch_hidden)
   CHECK_OR_FALSE(param_.hidden)
 
-  auto input_dims = param_.input->dims();
-  auto weight_dims = param_.weight->dims();
+  const auto& input_dims = param_.input->dims();
+  const auto& weight_dims = param_.weight->dims();
   int input_size = input_dims[1];
   int frame_size = weight_dims[0];
   CHECK_EQ_OR_FALSE(input_size, frame_size * 3)
@@ -52,21 +52,23 @@ bool GRUOpLite::CheckShape() const {
 }
 
 bool GRUOpLite::InferShape() const {
-  auto input_dims = param_.input->dims();
-  auto weight_dims = param_.weight->dims();
+  const auto& input_dims = param_.input->dims();
+  const auto& weight_dims = param_.weight->dims();
   int frame_size = weight_dims[0];
   auto batch_size = input_dims[0];
 
   param_.batch_gate->Resize(input_dims);
-  param_.batch_reset_hidden_prev->Resize(lite::DDim({batch_size, frame_size}));
-  param_.batch_hidden->Resize(lite::DDim({batch_size, frame_size}));
-  param_.hidden->Resize(lite::DDim({batch_size, frame_size}));
+
+  DDim out_dims({batch_size, frame_size});
+  param_.batch_reset_hidden_prev->Resize(out_dims);
+  param_.batch_hidden->Resize(out_dims);
+  param_.hidden->Resize(out_dims);
 
   *(param_.hidden->mutable_lod()) = param_.input->lod();
   return true;
 }
 
-bool GRUOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
+bool GRUOpLite::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
   auto input = op_desc.Input("Input").front();
   auto weight = op_desc.Input("Weight").front();
   auto batch_gate = op_desc.Output("BatchGate").front();
diff --git a/lite/operators/lookup_table_op.cc b/lite/operators/lookup_table_op.cc
index 3d5a71cee9..931894d925 100644
--- a/lite/operators/lookup_table_op.cc
+++ b/lite/operators/lookup_table_op.cc
@@ -25,8 +25,8 @@ bool LookupTableOpLite::CheckShape() const {
   CHECK_OR_FALSE(param_.Ids)
   CHECK_OR_FALSE(param_.Out)
 
-  auto table_dims = param_.W->dims();
-  auto ids_dims = param_.Ids->dims();
+  const auto& table_dims = param_.W->dims();
+  const auto& ids_dims = param_.Ids->dims();
 
   int ids_rank = ids_dims.size();
 
@@ -37,25 +37,20 @@ bool LookupTableOpLite::CheckShape() const {
 }
 
 bool LookupTableOpLite::InferShape() const {
-  auto table_dims = param_.W->dims();
-  auto ids_dims = param_.Ids->dims();
+  const auto& table_dims = param_.W->dims();
+  const auto& ids_dims = param_.Ids->dims();
+  auto out_dims = ids_dims;
   int ids_rank = ids_dims.size();
+  out_dims[ids_rank - 1] = table_dims[1];
 
-  auto output_dims = ids_dims.Slice(0, ids_rank - 1);
-
-  std::vector<int64_t> out_dims;
-  for (int i = 0; i < ids_rank - 1; ++i) {
-    out_dims.push_back(ids_dims[i]);
-  }
-  out_dims.push_back(table_dims[1]);
-  param_.Out->Resize(lite::DDim{out_dims});
+  param_.Out->Resize(out_dims);
   param_.Out->set_lod(param_.Ids->lod());
 
   return true;
 }
 
-bool LookupTableOpLite::AttachImpl(const cpp::OpDesc &op_desc,
-                                   lite::Scope *scope) {
+bool LookupTableOpLite::AttachImpl(const cpp::OpDesc& op_desc,
+                                   lite::Scope* scope) {
   auto input = op_desc.Input("W").front();
   auto ids = op_desc.Input("Ids").front();
   auto out = op_desc.Output("Out").front();
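The lookup_table change replaces the Slice-and-rebuild with a single write: Out keeps Ids' dims except that the trailing dim becomes the embedding width. E.g. Ids {4, 6, 1} with a {10000, 128} table gives Out {4, 6, 128}. A minimal sketch:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Out = Ids dims with the last entry replaced by table_dims[1].
    std::vector<int64_t> LookupOutDims(const std::vector<int64_t>& table_dims,
                                       std::vector<int64_t> ids_dims) {
      ids_dims.back() = table_dims[1];
      return ids_dims;
    }

    int main() {
      assert((LookupOutDims({10000, 128}, {4, 6, 1}) ==
              std::vector<int64_t>{4, 6, 128}));
      return 0;
    }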
diff --git a/lite/operators/reduce_ops.cc b/lite/operators/reduce_ops.cc
index e986b0ca54..3f0de17471 100644
--- a/lite/operators/reduce_ops.cc
+++ b/lite/operators/reduce_ops.cc
@@ -29,39 +29,43 @@ bool ReduceOp::CheckShape() const {
 }
 
 bool ReduceOp::InferShape() const {
-  auto x_dims = param_.x->dims();
+  const auto &x_dims = param_.x->dims();
   auto x_rank = x_dims.size();
   auto dims = param_.dim;
   for (size_t i = 0; i < dims.size(); ++i) {
-    if (dims[i] < 0) dims[i] = x_rank + dims[i];
+    if (dims[i] < 0) {
+      dims[i] = x_rank + dims[i];
+    }
     CHECK_LT(dims[i], x_rank)
         << "The dim should be in the range [-rank(input), rank(input)).";
   }
-  sort(dims.begin(), dims.end());
 
   bool reduce_all = param_.reduce_all;
   bool keep_dim = param_.keep_dim;
   if (reduce_all) {
     if (keep_dim)
-      param_.output->Resize(lite::DDim(std::vector<int64_t>(x_rank, 1)));
+      param_.output->Resize(std::vector<int64_t>(x_rank, 1));
     else
-      param_.output->Resize(lite::DDim(std::vector<int64_t>{1}));
+      param_.output->Resize(std::vector<int64_t>{1});
   } else {
-    auto dims_vector = x_dims.Vectorize();
+    size_t out_rank = keep_dim ? x_rank : x_rank - dims.size();
+    std::vector<int64_t> out_dims(out_rank);
     if (keep_dim) {
       for (size_t i = 0; i < dims.size(); ++i) {
-        dims_vector[dims[i]] = 1;
+        out_dims[dims[i]] = 1;
       }
     } else {
-      const int kDelFlag = -2;
-      for (size_t i = 0; i < dims.size(); ++i) {
-        dims_vector[dims[i]] = kDelFlag;
+      sort(dims.begin(), dims.end());
+      int dim_index = 0;
+      int out_index = 0;
+      for (size_t i = 0; i < x_rank; ++i) {
+        if (dims[dim_index] == static_cast<int>(i)) {
+          dim_index++;
+        } else {
+          out_dims[out_index++] = x_dims[i];
+        }
       }
-      dims_vector.erase(
-          remove(dims_vector.begin(), dims_vector.end(), kDelFlag),
-          dims_vector.end());
     }
-    auto out_dims = lite::DDim(dims_vector);
     param_.output->Resize(out_dims);
     if (dims[0] != 0) {
      param_.output->set_lod(param_.x->lod());
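A standalone sketch of the reduce rule above: with keep_dim the reduced axes collapse to 1, otherwise they are dropped (axes sorted first, as the merge-style loop requires). The sketch initializes the keep_dim result from the input dims and bounds-checks the axis cursor; names are illustrative:

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> ReduceOutDims(const std::vector<int64_t>& x_dims,
                                       std::vector<int> dims, bool keep_dim) {
      std::vector<int64_t> out;
      if (keep_dim) {
        out = x_dims;                   // start from the input extents
        for (int d : dims) out[d] = 1;  // reduced axes collapse to 1
      } else {
        std::sort(dims.begin(), dims.end());
        size_t k = 0;  // next axis to drop
        for (size_t i = 0; i < x_dims.size(); ++i) {
          if (k < dims.size() && dims[k] == static_cast<int>(i)) {
            ++k;                         // skip a reduced axis
          } else {
            out.push_back(x_dims[i]);    // keep the rest
          }
        }
      }
      return out;
    }

    int main() {
      assert((ReduceOutDims({2, 3, 4}, {1}, true) ==
              std::vector<int64_t>{2, 1, 4}));
      assert((ReduceOutDims({2, 3, 4}, {1}, false) ==
              std::vector<int64_t>{2, 4}));
      return 0;
    }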
diff --git a/lite/operators/reshape_op.cc b/lite/operators/reshape_op.cc
index 35f3814859..655ac58bdc 100644
--- a/lite/operators/reshape_op.cc
+++ b/lite/operators/reshape_op.cc
@@ -27,14 +27,15 @@ bool ReshapeOp::CheckShape() const {
 }
 
 bool ReshapeOp::InferShape() const {
-  auto shape_tensor_vct = param_.shape_tensor_vct;
+  const auto &shape_tensor_vct = param_.shape_tensor_vct;
   auto *shape_tensor = param_.shape_tensor;
-  auto shape_vct = param_.shape_vct;
-  std::vector<int> final_shape;
+  const auto &shape_vct = param_.shape_vct;
+  std::vector<int> final_shape;
 
   if (shape_tensor_vct.size() > 0) {
-    for (int i = 0; i < shape_tensor_vct.size(); i++) {
-      final_shape.push_back(shape_tensor_vct[i]->data<int>()[0]);
+    final_shape.resize(shape_tensor_vct.size());
+    for (size_t i = 0; i < shape_tensor_vct.size(); i++) {
+      final_shape[i] = shape_tensor_vct[i]->data<int>()[0];
     }
   } else if (shape_tensor != nullptr) {
     auto *shape_tensor_data = shape_tensor->data<int>();
@@ -46,7 +47,7 @@ bool ReshapeOp::InferShape() const {
     LOG(FATAL) << "input shape error";
   }
 
-  auto x_dims = param_.x->dims();
+  const auto &x_dims = param_.x->dims();
   auto output_dims = ValidateShape(final_shape, x_dims);
   param_.output->Resize(output_dims);
   auto out_lod = param_.output->mutable_lod();
   *out_lod = param_.x->lod();
   return true;
 }
@@ -98,8 +99,9 @@ bool Reshape2Op::CheckShape() const {
 
 bool Reshape2Op::InferShape() const {
   ReshapeOp::InferShape();
-  auto x_dims = param_.x->dims();
-  std::vector<int64_t> xshape_dims(x_dims.size() + 1, 0);
+  const auto &x_dims = param_.x->dims();
+  std::vector<int64_t> xshape_dims(x_dims.size() + 1);
+  xshape_dims[0] = 0;
   for (size_t i = 0; i < x_dims.size(); i++) {
     xshape_dims[i + 1] = x_dims[i];
   }
@@ -116,20 +118,26 @@ bool Reshape2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   return true;
 }
 
-DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
-  const lite::DDim::value_type input_size = input_dims.production();
-  auto input_shape = input_dims.Vectorize();
-  bool all_positive = std::all_of(
-      input_shape.cbegin(), input_shape.cend(), [](lite::DDim::value_type i) {
-        return i > 0;
-      });
+static bool CheckPositive(const DDim &dims) {
+  for (size_t i = 0; i < dims.size(); ++i) {
+    if (dims[i] <= 0) {
+      return false;
+    }
+  }
+  return true;
+}
+
+std::vector<DDim::value_type> ValidateShape(const std::vector<int> &shape,
+                                            const DDim &input_dims) {
+  const DDim::value_type input_size = input_dims.production();
+
   // only one dimension can be set to -1, whose size will be automatically
   // inferred.
   const int unk_dim_val = -1;
   const int copy_dim_val = 0;
 
-  std::vector<lite::DDim::value_type> output_shape(shape.size(), 0);
-  lite::DDim::value_type capacity = 1;
+  std::vector<DDim::value_type> output_dims(shape.size());
+  DDim::value_type capacity = 1;
   int unk_dim_idx = -1;
   for (size_t i = 0; i < shape.size(); ++i) {
     if (shape[i] == unk_dim_val) {
@@ -137,7 +145,7 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
           << "Only one input dimension of Attr(shape) can be unknown.";
       unk_dim_idx = i;
     } else if (shape[i] == copy_dim_val) {
-      CHECK_LT(static_cast<int>(i), input_shape.size())
+      CHECK_LT(static_cast<int>(i), input_dims.size())
          << "The index of dimension to copy from input shape must be less "
             "than the size of input shape.";
     } else {
@@ -145,24 +153,26 @@ DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims) {
                                "be negative except one unknown dimension.";
     }
 
-    capacity *= (shape[i] ? static_cast<lite::DDim::value_type>(shape[i]) : input_shape[i]);
-    output_shape[i] = (shape[i] ? static_cast<lite::DDim::value_type>(shape[i]) : input_shape[i]);
+    DDim::value_type output_dim_i =
+        shape[i] ? static_cast<DDim::value_type>(shape[i]) : input_dims[i];
+    output_dims[i] = output_dim_i;
+    capacity *= output_dim_i;
   }
 
   if (unk_dim_idx != -1) {
-    if (all_positive) {
+    if (CheckPositive(input_dims)) {
       // input_size < 0 and is indeterminate at compile time, skip the check,
       // for example, input_dims = [-1, 8, 1, 1], shape = [-1, 3, 8],
       // capacity = -24, input_size = -8, output_shape[0] = 0
       // the following check will fail.
-      output_shape[unk_dim_idx] = -input_size / capacity;
-      CHECK_EQ(output_shape[unk_dim_idx] * capacity, -input_size)
+      output_dims[unk_dim_idx] = -input_size / capacity;
+      CHECK_EQ(output_dims[unk_dim_idx] * capacity, -input_size)
          << "Invalid shape is given.";
     } else {
-      output_shape[unk_dim_idx] = -1;
+      output_dims[unk_dim_idx] = -1;
    }
   } else {
     CHECK_EQ(capacity, input_size) << "Invalid shape is given.";
   }
-  return lite::DDim(output_shape);
+  return output_dims;
 }

 }  // namespace operators
diff --git a/lite/operators/reshape_op.h b/lite/operators/reshape_op.h
index bd31f7f73f..1df49fb5f4 100644
--- a/lite/operators/reshape_op.h
+++ b/lite/operators/reshape_op.h
@@ -56,7 +56,8 @@ class Reshape2Op : public ReshapeOp {
   std::string DebugString() const override { return "reshape2"; }
 };
 
-DDim ValidateShape(const std::vector<int> &shape, const DDim &input_dims);
+std::vector<DDim::value_type> ValidateShape(const std::vector<int> &shape,
+                                            const DDim &input_dims);
 
 }  // namespace operators
 }  // namespace lite
diff --git a/lite/tests/kernels/fc_compute_test.cc b/lite/tests/kernels/fc_compute_test.cc
index bd6d86b599..1d5adaa6cc 100644
--- a/lite/tests/kernels/fc_compute_test.cc
+++ b/lite/tests/kernels/fc_compute_test.cc
@@ -47,7 +47,6 @@ void Relu(float* out, int num, int channel) {
 DDim ComputeOutDim(const DDim& dim_in, const DDim& wdim, int in_num_col_dim) {
   std::vector<int64_t> out_dim;
   out_dim.resize(in_num_col_dim + 1);
-  auto in_mat_dims = dim_in.Flatten2D(in_num_col_dim);
   for (int i = 0; i < in_num_col_dim; ++i) {
     out_dim[i] = dim_in[i];
   }
--
GitLab
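For reference, a self-contained sketch of the shape rule ValidateShape implements: a 0 entry copies the corresponding input dim, and at most one -1 entry is inferred so the element count is preserved. The sign bookkeeping around the -1 entry is simplified here relative to the patch, and the sketch assumes fully known (positive) input dims:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> ResolveShape(const std::vector<int>& shape,
                                      const std::vector<int64_t>& input_dims) {
      int64_t input_size = 1;
      for (int64_t d : input_dims) input_size *= d;

      std::vector<int64_t> out(shape.size());
      int64_t capacity = 1;  // product of all resolved (non -1) entries
      int unk_idx = -1;
      for (size_t i = 0; i < shape.size(); ++i) {
        if (shape[i] == -1) {
          assert(unk_idx == -1);  // at most one -1 allowed
          unk_idx = static_cast<int>(i);
        } else {
          out[i] = shape[i] == 0 ? input_dims[i] : shape[i];
          capacity *= out[i];
        }
      }
      if (unk_idx != -1) {
        out[unk_idx] = input_size / capacity;  // infer the unknown dim
        assert(out[unk_idx] * capacity == input_size);
      } else {
        assert(capacity == input_size);
      }
      return out;
    }

    int main() {
      // {0, -1} against a [4, 3, 8] input: keep 4, infer 24.
      assert((ResolveShape({0, -1}, {4, 3, 8}) ==
              std::vector<int64_t>{4, 24}));
      return 0;
    }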