From 83fa6f827c4e00bbf5450f0dd2b13202603aef06 Mon Sep 17 00:00:00 2001 From: Yanzhan Yang Date: Wed, 17 Jul 2019 16:34:42 +0800 Subject: [PATCH] only InferShape when input shape has been changed (#1748) * only InferShape when input shape has been changed * apply input_dim_has_changed_ criterion only when model has one input vector --- src/framework/executor.cpp | 18 +++++++++++++++++- src/framework/executor.h | 1 + 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index c06abc8416..bf9ec6a87f 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -409,6 +409,11 @@ void Executor::SetInput(const Tensor &input, target.Resize(input.dims()); target.ShareDataWith(input); + if (feed_indices_.size() == 1) { + auto &dim = input.dims(); + input_dim_has_changed_ = input_dim_last_ != dim; + input_dim_last_ = static_cast(dim); + } } template @@ -425,6 +430,11 @@ void Executor::SetInput(const LoDTensor &input, target.Resize(input.dims()); target.ShareDataWith(input); target.set_lod(input.lod()); + if (feed_indices_.size() == 1) { + auto &dim = input.dims(); + input_dim_has_changed_ = input_dim_last_ != dim; + input_dim_last_ = static_cast(dim); + } } template @@ -469,7 +479,7 @@ PMStatus Executor::Predict() { profile[op_index].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; #endif DLOG << "run op: " << op_handler->Type(); - if (lod_mode_) { + if (lod_mode_ && input_dim_has_changed_) { op_handler->InferShape(); } op_handler->Run(); @@ -479,6 +489,9 @@ PMStatus Executor::Predict() { ++op_index; #endif } + if (feed_indices_.size() == 1) { + input_dim_has_changed_ = false; + } #ifdef PADDLE_MOBILE_PROFILE PrintProfile(profile); @@ -793,6 +806,9 @@ void Executor::SetInput(const Tensor &input, DLOG << "SetInput ---- > ShareDataWith"; } target_tensor->ShareDataWith(input); + if (feed_indices_.size() == 1) { + input_dim_has_changed_ = input_dim_last_ != input.dims(); + } auto &dim = input.dims(); input_dim_last_ = static_cast(dim); } diff --git a/src/framework/executor.h b/src/framework/executor.h index c2d096182d..66af8d3bda 100644 --- a/src/framework/executor.h +++ b/src/framework/executor.h @@ -99,6 +99,7 @@ class Executor { // for super resoltion DDim input_dim_last_; + bool input_dim_has_changed_ = true; #ifdef PADDLE_MOBILE_PROFILE typedef typename DtypeTensorTrait::gtype ProfileTensorType; -- GitLab