提交 ddc57565 编写于 作者: H huzhiqiang 提交者: GitHub

[Framework][InferShape]Improve InferShape period (#3601)

上级 72f97cae
......@@ -36,25 +36,21 @@ bool OpLite::InferShapeWithCache() {
// 1. Get vector of current input tensors
auto *current_inputs = op_param_->input_tensor_ptrs();
// 2. Get hash value of current inputs shape and lod
size_t new_hash = 0;
for (auto iter = current_inputs->begin(); iter != current_inputs->end();
iter++) {
// combined dims value into new_hash value.
auto &element_dims = (*iter)->dims();
for (size_t i = 0; i < element_dims.size(); i++) {
lite::CombineHash(static_cast<int64_t>(element_dims[i]), &new_hash);
}
// combine lod value into new_hash valud.
auto &emement_lods = (*iter)->lod();
for (auto lod_iter = emement_lods.begin(); lod_iter != emement_lods.end();
lod_iter++) {
for (size_t i = 0; i < lod_iter->size(); i++) {
lite::CombineHash(static_cast<int64_t>(lod_iter->at(i)), &new_hash);
bool use_cache = true;
if (last_input_shapes.size() == current_inputs->size()) {
for (int i = 0; i < current_inputs->size(); i++) {
if (last_input_shapes[i] != current_inputs->at(i)->dims() ||
last_input_lods[i] != current_inputs->at(i)->lod()) {
use_cache = false;
break;
}
}
} else {
use_cache = false;
}
// 3. infer shapes of output tensors
if (new_hash == io_shape_lod_hash_ && new_hash != 0) {
if (use_cache) {
// if current hash value is consistent with io_shape_lod_hash_,
// previous outputs shape and lod are reused.
auto *current_outputs = op_param_->output_tensor_ptrs();
......@@ -64,7 +60,6 @@ bool OpLite::InferShapeWithCache() {
}
} else {
// otherwise, current hash value is changed, InferShapeImpl will apply.
io_shape_lod_hash_ = new_hash;
this->InferShapeImpl();
auto *current_outputs = op_param_->output_tensor_ptrs();
last_output_shapes.clear();
......@@ -73,6 +68,12 @@ bool OpLite::InferShapeWithCache() {
last_output_shapes.push_back(current_outputs->at(i)->dims());
last_output_lods.push_back(current_outputs->at(i)->lod());
}
last_input_shapes.clear();
last_input_lods.clear();
for (size_t i = 0; i < current_inputs->size(); i++) {
last_input_shapes.push_back(current_inputs->at(i)->dims());
last_input_lods.push_back(current_inputs->at(i)->lod());
}
}
return true;
}
......
......@@ -172,9 +172,13 @@ class OpLite : public Registry {
std::vector<Place> valid_places_;
Place kernel_place_{TARGET(kHost), PRECISION(kFloat)};
std::unique_ptr<OpInfo> op_info_;
  // TODO: it is preferred to combine last_input_shapes and
  // last_input_lods into a single hash value to decrease
  // memory usage.
std::vector<DDimLite> last_input_shapes{};
std::vector<std::vector<std::vector<uint64_t>>> last_input_lods{};
std::vector<DDimLite> last_output_shapes{};
std::vector<std::vector<std::vector<uint64_t>>> last_output_lods{};
size_t io_shape_lod_hash_{};
mutable operators::ParamBase *op_param_{nullptr};
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册