From 71d375bb9a5878486159526f8e41816276c3d09b Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 2 Nov 2021 20:12:28 +0800 Subject: [PATCH] [PTen] Add lock to kernel signature map init (#36923) * add lock to kernel sig map * add lock for map emplace --- paddle/fluid/framework/pten_utils.cc | 37 ++++++++++++++++++++++++++ paddle/fluid/framework/pten_utils.h | 39 ++++++++++------------------ 2 files changed, 51 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/framework/pten_utils.cc b/paddle/fluid/framework/pten_utils.cc index 8bd9b87a47..ff24a8c73f 100644 --- a/paddle/fluid/framework/pten_utils.cc +++ b/paddle/fluid/framework/pten_utils.cc @@ -59,6 +59,43 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey( return pten::KernelKey(backend, layout, dtype); } +KernelSignatureMap* KernelSignatureMap::kernel_signature_map_ = nullptr; +std::mutex KernelSignatureMap::mutex_; + +KernelSignatureMap& KernelSignatureMap::Instance() { + if (kernel_signature_map_ == nullptr) { + std::unique_lock lock(mutex_); + if (kernel_signature_map_ == nullptr) { + kernel_signature_map_ = new KernelSignatureMap; + } + } + return *kernel_signature_map_; +} + +bool KernelSignatureMap::Has(const std::string& op_type) const { + return map_.find(op_type) != map_.end(); +} + +void KernelSignatureMap::Emplace(const std::string& op_type, + KernelSignature&& signature) { + if (!Has(op_type)) { + std::unique_lock lock(mutex_); + if (!Has(op_type)) { + map_.emplace(op_type, signature); + } + } +} + +const KernelSignature& KernelSignatureMap::Get( + const std::string& op_type) const { + auto it = map_.find(op_type); + PADDLE_ENFORCE_NE( + it, map_.end(), + platform::errors::NotFound( + "Operator `%s`'s kernel signature is not registered.", op_type)); + return it->second; +} + const paddle::SmallVector& KernelArgsNameMakerByOpProto::GetInputArgsNames() { for (int i = 0; i < op_proto_->inputs_size(); ++i) { diff --git a/paddle/fluid/framework/pten_utils.h b/paddle/fluid/framework/pten_utils.h index 30000ab62d..aab756f50e 100644 --- a/paddle/fluid/framework/pten_utils.h +++ b/paddle/fluid/framework/pten_utils.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -62,35 +63,23 @@ struct KernelSignature { // TODO(chenweihang): we can generate this map by proto info in compile time class KernelSignatureMap { public: - static KernelSignatureMap& Instance() { - static KernelSignatureMap g_kernel_signature_map; - return g_kernel_signature_map; - } - - bool Has(const std::string& op_type) const { - return map_.find(op_type) != map_.end(); - } - - void Emplace(const std::string& op_type, KernelSignature&& signature) { - if (!Has(op_type)) { - map_.emplace(op_type, signature); - } - } - - const KernelSignature& Get(const std::string& op_type) const { - auto it = map_.find(op_type); - PADDLE_ENFORCE_NE( - it, map_.end(), - platform::errors::NotFound( - "Operator `%s`'s kernel signature is not registered.", op_type)); - return it->second; - } + static KernelSignatureMap& Instance(); + + bool Has(const std::string& op_type) const; + + void Emplace(const std::string& op_type, KernelSignature&& signature); + + const KernelSignature& Get(const std::string& op_type) const; private: KernelSignatureMap() = default; - paddle::flat_hash_map map_; - DISABLE_COPY_AND_ASSIGN(KernelSignatureMap); + + private: + static KernelSignatureMap* kernel_signature_map_; + static std::mutex mutex_; + + paddle::flat_hash_map map_; }; class KernelArgsNameMaker { -- GitLab