未验证 提交 b42f3d49 编写于 作者: Y Yanzhan Yang 提交者: GitHub

only InferShape when input shape has been changed (#1748)

* only InferShape when input shape has been changed

* apply input_dim_has_changed_ criterion only when model has one input vector
上级 16a0bd75
...@@ -409,6 +409,11 @@ void Executor<Device, T>::SetInput(const Tensor &input, ...@@ -409,6 +409,11 @@ void Executor<Device, T>::SetInput(const Tensor &input,
target.Resize(input.dims()); target.Resize(input.dims());
target.ShareDataWith(input); target.ShareDataWith(input);
if (feed_indices_.size() == 1) {
auto &dim = input.dims();
input_dim_has_changed_ = input_dim_last_ != dim;
input_dim_last_ = static_cast<DDim>(dim);
}
} }
template <typename Device, typename T> template <typename Device, typename T>
...@@ -425,6 +430,11 @@ void Executor<Device, T>::SetInput(const LoDTensor &input, ...@@ -425,6 +430,11 @@ void Executor<Device, T>::SetInput(const LoDTensor &input,
target.Resize(input.dims()); target.Resize(input.dims());
target.ShareDataWith(input); target.ShareDataWith(input);
target.set_lod(input.lod()); target.set_lod(input.lod());
if (feed_indices_.size() == 1) {
auto &dim = input.dims();
input_dim_has_changed_ = input_dim_last_ != dim;
input_dim_last_ = static_cast<DDim>(dim);
}
} }
template <typename Device, typename T> template <typename Device, typename T>
...@@ -469,7 +479,7 @@ PMStatus Executor<Device, T>::Predict() { ...@@ -469,7 +479,7 @@ PMStatus Executor<Device, T>::Predict() {
profile[op_index].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec; profile[op_index].runBegin = (uint64_t)ts.tv_sec * 1e9 + ts.tv_nsec;
#endif #endif
DLOG << "run op: " << op_handler->Type(); DLOG << "run op: " << op_handler->Type();
if (lod_mode_) { if (lod_mode_ && input_dim_has_changed_) {
op_handler->InferShape(); op_handler->InferShape();
} }
op_handler->Run(); op_handler->Run();
...@@ -479,6 +489,9 @@ PMStatus Executor<Device, T>::Predict() { ...@@ -479,6 +489,9 @@ PMStatus Executor<Device, T>::Predict() {
++op_index; ++op_index;
#endif #endif
} }
if (feed_indices_.size() == 1) {
input_dim_has_changed_ = false;
}
#ifdef PADDLE_MOBILE_PROFILE #ifdef PADDLE_MOBILE_PROFILE
PrintProfile(profile); PrintProfile(profile);
...@@ -793,6 +806,9 @@ void Executor<GPU_CL, float>::SetInput(const Tensor &input, ...@@ -793,6 +806,9 @@ void Executor<GPU_CL, float>::SetInput(const Tensor &input,
DLOG << "SetInput ---- > ShareDataWith"; DLOG << "SetInput ---- > ShareDataWith";
} }
target_tensor->ShareDataWith(input); target_tensor->ShareDataWith(input);
if (feed_indices_.size() == 1) {
input_dim_has_changed_ = input_dim_last_ != input.dims();
}
auto &dim = input.dims(); auto &dim = input.dims();
input_dim_last_ = static_cast<DDim>(dim); input_dim_last_ = static_cast<DDim>(dim);
} }
......
...@@ -99,6 +99,7 @@ class Executor { ...@@ -99,6 +99,7 @@ class Executor {
// for super resolution // for super resolution
DDim input_dim_last_; DDim input_dim_last_;
bool input_dim_has_changed_ = true;
#ifdef PADDLE_MOBILE_PROFILE #ifdef PADDLE_MOBILE_PROFILE
typedef typename DtypeTensorTrait<Device>::gtype ProfileTensorType; typedef typename DtypeTensorTrait<Device>::gtype ProfileTensorType;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册