From ddc575654f3a6d727bc833dd83ef0486583e1497 Mon Sep 17 00:00:00 2001
From: huzhiqiang <912790387@qq.com>
Date: Tue, 12 May 2020 10:47:00 +0800
Subject: [PATCH] [Framework][InferShape] Improve InferShape period (#3601)

---
 lite/core/op_lite.cc | 33 +++++++++++++++++----------------
 lite/core/op_lite.h  |  6 +++++-
 2 files changed, 22 insertions(+), 17 deletions(-)

diff --git a/lite/core/op_lite.cc b/lite/core/op_lite.cc
index 7428a16d0a..537636065d 100644
--- a/lite/core/op_lite.cc
+++ b/lite/core/op_lite.cc
@@ -36,25 +36,21 @@ bool OpLite::InferShapeWithCache() {
   // 1. Get vector of current input tensors
   auto *current_inputs = op_param_->input_tensor_ptrs();
   // 2. Get hash value of current inputs shape and lod
-  size_t new_hash = 0;
-  for (auto iter = current_inputs->begin(); iter != current_inputs->end();
-       iter++) {
-    // combined dims value into new_hash value.
-    auto &element_dims = (*iter)->dims();
-    for (size_t i = 0; i < element_dims.size(); i++) {
-      lite::CombineHash(static_cast<int64_t>(element_dims[i]), &new_hash);
-    }
-    // combine lod value into new_hash valud.
-    auto &emement_lods = (*iter)->lod();
-    for (auto lod_iter = emement_lods.begin(); lod_iter != emement_lods.end();
-         lod_iter++) {
-      for (size_t i = 0; i < lod_iter->size(); i++) {
-        lite::CombineHash(static_cast<int64_t>(lod_iter->at(i)), &new_hash);
+  bool use_cache = true;
+  if (last_input_shapes.size() == current_inputs->size()) {
+    for (int i = 0; i < current_inputs->size(); i++) {
+      if (last_input_shapes[i] != current_inputs->at(i)->dims() ||
+          last_input_lods[i] != current_inputs->at(i)->lod()) {
+        use_cache = false;
+        break;
       }
     }
+  } else {
+    use_cache = false;
   }
+
   // 3. infer shapes of output tensors
-  if (new_hash == io_shape_lod_hash_ && new_hash != 0) {
+  if (use_cache) {
     // if current hash value is consistent with io_shape_lod_hash_,
     // previous outputs shape and lod are reused.
     auto *current_outputs = op_param_->output_tensor_ptrs();
@@ -64,7 +60,6 @@ bool OpLite::InferShapeWithCache() {
     }
   } else {
     // otherwise, current hash value is changed, InferShapeImpl will apply.
-    io_shape_lod_hash_ = new_hash;
     this->InferShapeImpl();
     auto *current_outputs = op_param_->output_tensor_ptrs();
     last_output_shapes.clear();
@@ -73,6 +68,12 @@ bool OpLite::InferShapeWithCache() {
       last_output_shapes.push_back(current_outputs->at(i)->dims());
       last_output_lods.push_back(current_outputs->at(i)->lod());
     }
+    last_input_shapes.clear();
+    last_input_lods.clear();
+    for (size_t i = 0; i < current_inputs->size(); i++) {
+      last_input_shapes.push_back(current_inputs->at(i)->dims());
+      last_input_lods.push_back(current_inputs->at(i)->lod());
+    }
   }
   return true;
 }
diff --git a/lite/core/op_lite.h b/lite/core/op_lite.h
index 656f992b17..7fb74a3ca3 100644
--- a/lite/core/op_lite.h
+++ b/lite/core/op_lite.h
@@ -172,9 +172,13 @@ class OpLite : public Registry {
   std::vector<Place> valid_places_;
   Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
   std::unique_ptr<OpInfo> op_info_;
+  // TODO: it's preferred to combine last_input_shapes and
+  // last_input_lods into a single hash value to decrease
+  // memory usage.
+  std::vector<DDimLite> last_input_shapes{};
+  std::vector<std::vector<std::vector<uint64_t>>> last_input_lods{};
   std::vector<DDimLite> last_output_shapes{};
   std::vector<std::vector<std::vector<uint64_t>>> last_output_lods{};
-  size_t io_shape_lod_hash_{};
   mutable operators::ParamBase *op_param_{nullptr};

 private:
-- 
GitLab
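
For context, the caching strategy this patch switches to can be reduced to a small self-contained C++ sketch: remember the previous inputs' dims and LoDs verbatim and rerun shape inference only when they differ. The names below (Tensor, Dims, LoD, CachedInferShape, and the snapshot stand-in for InferShapeImpl) are simplified assumptions for illustration, not Paddle Lite's actual classes.

#include <cstdint>
#include <iostream>
#include <vector>

// Simplified stand-ins for Paddle Lite's DDimLite and LoD types.
using Dims = std::vector<int64_t>;
using LoD = std::vector<std::vector<uint64_t>>;

struct Tensor {
  Dims dims;
  LoD lod;
};

class CachedInferShape {
 public:
  // Returns true on a cache hit (inputs unchanged, cached outputs reusable);
  // returns false when the InferShapeImpl() analogue had to run again.
  bool Run(const std::vector<Tensor*>& inputs) {
    bool use_cache = last_dims_.size() == inputs.size();
    if (use_cache) {
      for (size_t i = 0; i < inputs.size(); ++i) {
        if (last_dims_[i] != inputs[i]->dims ||
            last_lods_[i] != inputs[i]->lod) {
          use_cache = false;
          break;
        }
      }
    }
    if (!use_cache) {
      // Recompute output shapes here (elided), then snapshot the inputs
      // so the next call with identical dims/LoDs hits the cache.
      last_dims_.clear();
      last_lods_.clear();
      for (const Tensor* t : inputs) {
        last_dims_.push_back(t->dims);
        last_lods_.push_back(t->lod);
      }
    }
    return use_cache;
  }

 private:
  std::vector<Dims> last_dims_;
  std::vector<LoD> last_lods_;
};

int main() {
  Tensor a{{1, 3, 224, 224}, {}};
  CachedInferShape cache;
  std::cout << cache.Run({&a}) << "\n";  // 0: first call, cache miss
  std::cout << cache.Run({&a}) << "\n";  // 1: unchanged inputs, cache hit
  a.dims[0] = 2;  // batch size changed
  std::cout << cache.Run({&a}) << "\n";  // 0: shape changed, miss again
}

Relative to the removed CombineHash fingerprint, the verbatim snapshot uses more memory (hence the TODO left in op_lite.h) but is presumably immune to hash collisions, and it also drops the old "new_hash != 0" guard, which kept a legitimately zero hash value from ever hitting the cache.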