未验证 提交 71d375bb 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add lock to kernel signature map init (#36923)

* add lock to kernel sig map

* add lock for map emplace
上级 e512aa9a
...@@ -59,6 +59,43 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey( ...@@ -59,6 +59,43 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(
return pten::KernelKey(backend, layout, dtype); return pten::KernelKey(backend, layout, dtype);
} }
KernelSignatureMap* KernelSignatureMap::kernel_signature_map_ = nullptr;
std::mutex KernelSignatureMap::mutex_;
KernelSignatureMap& KernelSignatureMap::Instance() {
if (kernel_signature_map_ == nullptr) {
std::unique_lock<std::mutex> lock(mutex_);
if (kernel_signature_map_ == nullptr) {
kernel_signature_map_ = new KernelSignatureMap;
}
}
return *kernel_signature_map_;
}
bool KernelSignatureMap::Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end();
}
void KernelSignatureMap::Emplace(const std::string& op_type,
KernelSignature&& signature) {
if (!Has(op_type)) {
std::unique_lock<std::mutex> lock(mutex_);
if (!Has(op_type)) {
map_.emplace(op_type, signature);
}
}
}
const KernelSignature& KernelSignatureMap::Get(
const std::string& op_type) const {
auto it = map_.find(op_type);
PADDLE_ENFORCE_NE(
it, map_.end(),
platform::errors::NotFound(
"Operator `%s`'s kernel signature is not registered.", op_type));
return it->second;
}
const paddle::SmallVector<std::string>& const paddle::SmallVector<std::string>&
KernelArgsNameMakerByOpProto::GetInputArgsNames() { KernelArgsNameMakerByOpProto::GetInputArgsNames() {
for (int i = 0; i < op_proto_->inputs_size(); ++i) { for (int i = 0; i < op_proto_->inputs_size(); ++i) {
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <mutex>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -62,35 +63,23 @@ struct KernelSignature { ...@@ -62,35 +63,23 @@ struct KernelSignature {
// TODO(chenweihang): we can generate this map by proto info in compile time // TODO(chenweihang): we can generate this map by proto info in compile time
class KernelSignatureMap { class KernelSignatureMap {
public: public:
static KernelSignatureMap& Instance() { static KernelSignatureMap& Instance();
static KernelSignatureMap g_kernel_signature_map;
return g_kernel_signature_map; bool Has(const std::string& op_type) const;
}
void Emplace(const std::string& op_type, KernelSignature&& signature);
bool Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end(); const KernelSignature& Get(const std::string& op_type) const;
}
void Emplace(const std::string& op_type, KernelSignature&& signature) {
if (!Has(op_type)) {
map_.emplace(op_type, signature);
}
}
const KernelSignature& Get(const std::string& op_type) const {
auto it = map_.find(op_type);
PADDLE_ENFORCE_NE(
it, map_.end(),
platform::errors::NotFound(
"Operator `%s`'s kernel signature is not registered.", op_type));
return it->second;
}
private: private:
KernelSignatureMap() = default; KernelSignatureMap() = default;
paddle::flat_hash_map<std::string, KernelSignature> map_;
DISABLE_COPY_AND_ASSIGN(KernelSignatureMap); DISABLE_COPY_AND_ASSIGN(KernelSignatureMap);
private:
static KernelSignatureMap* kernel_signature_map_;
static std::mutex mutex_;
paddle::flat_hash_map<std::string, KernelSignature> map_;
}; };
class KernelArgsNameMaker { class KernelArgsNameMaker {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册