diff --git a/paddle/fluid/framework/pten_utils.cc b/paddle/fluid/framework/pten_utils.cc index 8bd9b87a47847516d8f092ee9f46ffa098a50799..ff24a8c73f705f676a2fe45cd13187ba929d5bf7 100644 --- a/paddle/fluid/framework/pten_utils.cc +++ b/paddle/fluid/framework/pten_utils.cc @@ -59,6 +59,43 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey( return pten::KernelKey(backend, layout, dtype); } +KernelSignatureMap* KernelSignatureMap::kernel_signature_map_ = nullptr; +std::mutex KernelSignatureMap::mutex_; + +KernelSignatureMap& KernelSignatureMap::Instance() { + if (kernel_signature_map_ == nullptr) { + std::unique_lock lock(mutex_); + if (kernel_signature_map_ == nullptr) { + kernel_signature_map_ = new KernelSignatureMap; + } + } + return *kernel_signature_map_; +} + +bool KernelSignatureMap::Has(const std::string& op_type) const { + return map_.find(op_type) != map_.end(); +} + +void KernelSignatureMap::Emplace(const std::string& op_type, + KernelSignature&& signature) { + if (!Has(op_type)) { + std::unique_lock lock(mutex_); + if (!Has(op_type)) { + map_.emplace(op_type, signature); + } + } +} + +const KernelSignature& KernelSignatureMap::Get( + const std::string& op_type) const { + auto it = map_.find(op_type); + PADDLE_ENFORCE_NE( + it, map_.end(), + platform::errors::NotFound( + "Operator `%s`'s kernel signature is not registered.", op_type)); + return it->second; +} + const paddle::SmallVector& KernelArgsNameMakerByOpProto::GetInputArgsNames() { for (int i = 0; i < op_proto_->inputs_size(); ++i) { diff --git a/paddle/fluid/framework/pten_utils.h b/paddle/fluid/framework/pten_utils.h index 30000ab62d9f73f3b6f4bbe75806eacea46819ae..aab756f50e395aa37942ceb7d42655969661b354 100644 --- a/paddle/fluid/framework/pten_utils.h +++ b/paddle/fluid/framework/pten_utils.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include #include #include #include @@ -62,35 +63,23 @@ struct KernelSignature { // TODO(chenweihang): we can generate this map by proto info in compile time class KernelSignatureMap { public: - static KernelSignatureMap& Instance() { - static KernelSignatureMap g_kernel_signature_map; - return g_kernel_signature_map; - } - - bool Has(const std::string& op_type) const { - return map_.find(op_type) != map_.end(); - } - - void Emplace(const std::string& op_type, KernelSignature&& signature) { - if (!Has(op_type)) { - map_.emplace(op_type, signature); - } - } - - const KernelSignature& Get(const std::string& op_type) const { - auto it = map_.find(op_type); - PADDLE_ENFORCE_NE( - it, map_.end(), - platform::errors::NotFound( - "Operator `%s`'s kernel signature is not registered.", op_type)); - return it->second; - } + static KernelSignatureMap& Instance(); + + bool Has(const std::string& op_type) const; + + void Emplace(const std::string& op_type, KernelSignature&& signature); + + const KernelSignature& Get(const std::string& op_type) const; private: KernelSignatureMap() = default; - paddle::flat_hash_map map_; - DISABLE_COPY_AND_ASSIGN(KernelSignatureMap); + + private: + static KernelSignatureMap* kernel_signature_map_; + static std::mutex mutex_; + + paddle::flat_hash_map map_; }; class KernelArgsNameMaker {