diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 6905e6d80d461ea1f345dd2a30c388c3b321973f..f4440e44124df33f7fea1e0682e5f4974b8629c4 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -924,11 +924,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope, if (!all_kernels_must_compute_runtime_shape_ && HasAttr(kAllKernelsMustComputeRuntimeShape)) all_kernels_must_compute_runtime_shape_ = true; + const Scope* cur_scope = &scope; if (!enable_cache_runtime_context_) { RuntimeContext ctx(Inputs(), Outputs(), scope); RunImpl(scope, place, &ctx); + pre_scope_ = cur_scope; } else { - const Scope* cur_scope = &scope; if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) { std::lock_guard lock(cache_update_mutex_); if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) { @@ -958,8 +959,10 @@ void OperatorWithKernel::RunImpl(const Scope& scope, { platform::RecordEvent record_event("prepare_data", platform::EventRole::kInnerOp); - transfer_scope = PrepareData(scope, *kernel_type_, &transfered_inplace_vars, - runtime_ctx); + if (need_prepare_data_) { + transfer_scope = PrepareData(scope, *kernel_type_, + &transfered_inplace_vars, runtime_ctx); + } } // exec scope is the scope that kernel actually executed on. const Scope& exec_scope = @@ -1252,6 +1255,15 @@ Scope* OperatorWithKernel::PrepareData( SetTensorToVariable(*var, out, trans_var); } } + // If pre_scope = &scope, it means that scope is cached and the op is not in + // while block. If new_scope = nullptr, it means that for each input of this + // Op, there is no need to do PrepareData. So PrepareData could be skipped at + // the rest iterations to save the elapsed time. + // We do not support skipping PrepareData in while block, because the Op's + // input may be changed by subsequent Ops, which may cause an error. + if (pre_scope_ == &scope && new_scope == nullptr) { + need_prepare_data_ = false; + } return new_scope; } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 7fdc3b033eea6b98f47825cd78ffd92a0486e454..b58ad71b8da424f382ed92c7f29b1b6a959c1328 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -185,6 +185,7 @@ class OperatorBase { virtual std::vector OutputVars(bool has_intermediate) const; void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; } + virtual void RuntimeInferShape(const Scope& scope, const platform::Place& place, const RuntimeContext& ctx) const {} @@ -541,6 +542,7 @@ class OperatorWithKernel : public OperatorBase { mutable std::unique_ptr kernel_func_; mutable std::unique_ptr runtime_ctx_; mutable const Scope* pre_scope_ = nullptr; + mutable bool need_prepare_data_ = true; mutable bool enable_cache_runtime_context_ = false; mutable bool all_kernels_must_compute_runtime_shape_ = false; mutable std::mutex cache_update_mutex_;