diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 0dfac96bfee868ad395366f4f8dd95e2c7796eb5..1723a9a78a0da6e3eac7f823f79fe802a916e5b3 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -880,7 +880,16 @@ std::vector* OperatorWithKernel::GetKernelConfig( void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { - if (!HasAttr(kEnableCacheRuntimeContext)) { + // To reduce the elapsed time of HasAttr, we use bool variable to record the + // result of HasAttr. + if (!enable_cache_runtime_context && HasAttr(kEnableCacheRuntimeContext)) + enable_cache_runtime_context = true; + if (!enable_cache_expected_kernel && HasAttr(kEnableCacheExpectedKernel)) + enable_cache_expected_kernel = true; + if (!all_kernels_must_compute_runtime_shape && + HasAttr(kAllKernelsMustComputeRuntimeShape)) + all_kernels_must_compute_runtime_shape = true; + if (!enable_cache_runtime_context) { RuntimeContext ctx(Inputs(), Outputs(), scope); RunImpl(scope, place, &ctx); } else { @@ -899,7 +908,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); - if (!HasAttr(kEnableCacheExpectedKernel) || !kernel_type_) { + if (!enable_cache_expected_kernel || !kernel_type_) { ChooseKernel(*runtime_ctx, scope, place); } @@ -918,7 +927,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, dev_ctx = pool.Get(kernel_type_->place_); } - if (!HasAttr(kAllKernelsMustComputeRuntimeShape)) { + if (!all_kernels_must_compute_runtime_shape) { RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, *runtime_ctx); this->InferShape(&infer_shape_ctx); } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 8c5649deaa8c2c0ed1e976a8453730541adbdb88..489b66099658d522fe1f1adaad763b66bdd22c91 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -506,6 +506,9 @@ class OperatorWithKernel : public OperatorBase { mutable std::unique_ptr kernel_func_; mutable std::unique_ptr runtime_ctx_; mutable const Scope* pre_scope_ = nullptr; + mutable bool enable_cache_runtime_context = false; + mutable bool enable_cache_expected_kernel = false; + mutable bool all_kernels_must_compute_runtime_shape = false; }; extern bool OpSupportGPU(const std::string& op_type);