未验证 提交 728bbaa4 编写于 作者: Q Qiao Longfei 提交者: GitHub

add cache_update_mutex_ for operator test=develop (#17124)

* add cache_update_mutex_ for operator 
上级 15453d05
...@@ -884,8 +884,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -884,8 +884,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// result of HasAttr. // result of HasAttr.
if (!enable_cache_runtime_context && HasAttr(kEnableCacheRuntimeContext)) if (!enable_cache_runtime_context && HasAttr(kEnableCacheRuntimeContext))
enable_cache_runtime_context = true; enable_cache_runtime_context = true;
if (!enable_cache_expected_kernel && HasAttr(kEnableCacheExpectedKernel))
enable_cache_expected_kernel = true;
if (!all_kernels_must_compute_runtime_shape && if (!all_kernels_must_compute_runtime_shape &&
HasAttr(kAllKernelsMustComputeRuntimeShape)) HasAttr(kAllKernelsMustComputeRuntimeShape))
all_kernels_must_compute_runtime_shape = true; all_kernels_must_compute_runtime_shape = true;
...@@ -894,10 +892,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -894,10 +892,13 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
RunImpl(scope, place, &ctx); RunImpl(scope, place, &ctx);
} else { } else {
const Scope* cur_scope = &scope; const Scope* cur_scope = &scope;
if (!runtime_ctx_ || pre_scope_ != cur_scope) { if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
std::lock_guard<std::mutex> lock(cache_update_mutex_);
if (runtime_ctx_.get() == nullptr || pre_scope_ != cur_scope) {
runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope)); runtime_ctx_.reset(new RuntimeContext(Inputs(), Outputs(), scope));
pre_scope_ = cur_scope; pre_scope_ = cur_scope;
} }
}
RunImpl(scope, place, runtime_ctx_.get()); RunImpl(scope, place, runtime_ctx_.get());
} }
} }
...@@ -908,7 +909,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -908,7 +909,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place); auto* dev_ctx = pool.Get(place);
if (!enable_cache_expected_kernel || !kernel_type_) { if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
ChooseKernel(*runtime_ctx, scope, place); ChooseKernel(*runtime_ctx, scope, place);
} }
...@@ -996,8 +997,11 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx, ...@@ -996,8 +997,11 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
KernelTypeToString(expected_kernel_key)); KernelTypeToString(expected_kernel_key));
} }
std::lock_guard<std::mutex> lock(cache_update_mutex_);
if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
kernel_type_.reset(new OpKernelType(expected_kernel_key)); kernel_type_.reset(new OpKernelType(expected_kernel_key));
kernel_func_.reset(new OpKernelFunc(kernel_iter->second)); kernel_func_.reset(new OpKernelFunc(kernel_iter->second));
}
} }
void OperatorWithKernel::TransferInplaceVarsBack( void OperatorWithKernel::TransferInplaceVarsBack(
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <atomic> #include <atomic>
#include <memory> #include <memory>
#include <mutex> // NOLINT
#include <string> #include <string>
#include <tuple> #include <tuple>
#include <unordered_map> #include <unordered_map>
...@@ -508,8 +509,8 @@ class OperatorWithKernel : public OperatorBase { ...@@ -508,8 +509,8 @@ class OperatorWithKernel : public OperatorBase {
mutable std::unique_ptr<RuntimeContext> runtime_ctx_; mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
mutable const Scope* pre_scope_ = nullptr; mutable const Scope* pre_scope_ = nullptr;
mutable bool enable_cache_runtime_context = false; mutable bool enable_cache_runtime_context = false;
mutable bool enable_cache_expected_kernel = false;
mutable bool all_kernels_must_compute_runtime_shape = false; mutable bool all_kernels_must_compute_runtime_shape = false;
mutable std::mutex cache_update_mutex_;
}; };
extern bool OpSupportGPU(const std::string& op_type); extern bool OpSupportGPU(const std::string& op_type);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册