diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 7756fe4f8f2289c6523f3c5c788c204b4ad7c6be..fa6de326bc111c7872e707d1000554e41bcc8775 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -884,8 +884,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, // result of HasAttr. if (!enable_cache_runtime_context && HasAttr(kEnableCacheRuntimeContext)) enable_cache_runtime_context = true; - if (!enable_cache_expected_kernel && HasAttr(kEnableCacheExpectedKernel)) - enable_cache_expected_kernel = true; if (!all_kernels_must_compute_runtime_shape && HasAttr(kAllKernelsMustComputeRuntimeShape)) all_kernels_must_compute_runtime_shape = true; @@ -894,9 +892,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope, RunImpl(scope, place, &ctx); } else { const Scope* cur_scope = &scope; - if (!runtime_ctx_ || pre_scope_ != cur_scope) { - runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); - pre_scope_ = cur_scope; + if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) { + std::lock_guard lock(cache_update_mutex_); + if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) { + runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); + pre_scope_ = cur_scope; + } } RunImpl(scope, place, runtime_ctx_.get()); } @@ -908,7 +909,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); - if (!enable_cache_expected_kernel || !kernel_type_) { + if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) { ChooseKernel(*runtime_ctx, scope, place); } @@ -996,8 +997,11 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, KernelTypeToString(expected_kernel_key)); } - kernel_type_.reset(new OpKernelType(expected_kernel_key)); - kernel_func_.reset(new OpKernelFunc(kernel_iter->second)); + std::lock_guard lock(cache_update_mutex_); + if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) { + kernel_type_.reset(new OpKernelType(expected_kernel_key)); + kernel_func_.reset(new OpKernelFunc(kernel_iter->second)); + } } void OperatorWithKernel::TransferInplaceVarsBack( diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 3332530b3c0be1e7e91a741d4ff3731b22fdb4da..9090b7e1d13bdd65225bdde60549e5259b51116f 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include // NOLINT #include #include #include @@ -508,8 +509,8 @@ class OperatorWithKernel : public OperatorBase { mutable std::unique_ptr runtime_ctx_; mutable const Scope* pre_scope_ = nullptr; mutable bool enable_cache_runtime_context = false; - mutable bool enable_cache_expected_kernel = false; mutable bool all_kernels_must_compute_runtime_shape = false; + mutable std::mutex cache_update_mutex_; }; extern bool OpSupportGPU(const std::string& op_type);