未验证 提交 71d375bb 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add lock to kernel signature map init (#36923)

* add lock to kernel sig map

* add lock for map emplace
上级 e512aa9a
......@@ -59,6 +59,43 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(
return pten::KernelKey(backend, layout, dtype);
}
KernelSignatureMap* KernelSignatureMap::kernel_signature_map_ = nullptr;
std::mutex KernelSignatureMap::mutex_;
KernelSignatureMap& KernelSignatureMap::Instance() {
if (kernel_signature_map_ == nullptr) {
std::unique_lock<std::mutex> lock(mutex_);
if (kernel_signature_map_ == nullptr) {
kernel_signature_map_ = new KernelSignatureMap;
}
}
return *kernel_signature_map_;
}
bool KernelSignatureMap::Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end();
}
void KernelSignatureMap::Emplace(const std::string& op_type,
KernelSignature&& signature) {
if (!Has(op_type)) {
std::unique_lock<std::mutex> lock(mutex_);
if (!Has(op_type)) {
map_.emplace(op_type, signature);
}
}
}
const KernelSignature& KernelSignatureMap::Get(
const std::string& op_type) const {
auto it = map_.find(op_type);
PADDLE_ENFORCE_NE(
it, map_.end(),
platform::errors::NotFound(
"Operator `%s`'s kernel signature is not registered.", op_type));
return it->second;
}
const paddle::SmallVector<std::string>&
KernelArgsNameMakerByOpProto::GetInputArgsNames() {
for (int i = 0; i < op_proto_->inputs_size(); ++i) {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
......@@ -62,35 +63,23 @@ struct KernelSignature {
// TODO(chenweihang): we can generate this map by proto info in compile time
class KernelSignatureMap {
public:
static KernelSignatureMap& Instance() {
static KernelSignatureMap g_kernel_signature_map;
return g_kernel_signature_map;
}
bool Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end();
}
void Emplace(const std::string& op_type, KernelSignature&& signature) {
if (!Has(op_type)) {
map_.emplace(op_type, signature);
}
}
const KernelSignature& Get(const std::string& op_type) const {
auto it = map_.find(op_type);
PADDLE_ENFORCE_NE(
it, map_.end(),
platform::errors::NotFound(
"Operator `%s`'s kernel signature is not registered.", op_type));
return it->second;
}
static KernelSignatureMap& Instance();
bool Has(const std::string& op_type) const;
void Emplace(const std::string& op_type, KernelSignature&& signature);
const KernelSignature& Get(const std::string& op_type) const;
private:
KernelSignatureMap() = default;
paddle::flat_hash_map<std::string, KernelSignature> map_;
DISABLE_COPY_AND_ASSIGN(KernelSignatureMap);
private:
static KernelSignatureMap* kernel_signature_map_;
static std::mutex mutex_;
paddle::flat_hash_map<std::string, KernelSignature> map_;
};
class KernelArgsNameMaker {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册