未验证 提交 9c81a9bb 编写于 作者: Z Zeng Jinle 提交者: GitHub

Fix PTen thread safety error (#36960)

* fix pten thread safety error

* improve coverage
上级 2479664a
...@@ -399,7 +399,7 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer) ...@@ -399,7 +399,7 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer) cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
cc_library(generator SRCS generator.cc DEPS enforce place) cc_library(generator SRCS generator.cc DEPS enforce place)
cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten_hapi_utils) cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten_hapi_utils op_info)
# Get the current working branch # Get the current working branch
execute_process( execute_process(
......
...@@ -1762,12 +1762,6 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( ...@@ -1762,12 +1762,6 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs( KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
if (!KernelSignatureMap::Instance().Has(Type())) {
// TODO(chenweihang): we can generate this map by proto info in compile time
KernelArgsNameMakerByOpProto maker(Info().proto_);
KernelSignatureMap::Instance().Emplace(
Type(), std::move(maker.GetKernelSignature()));
}
return KernelSignatureMap::Instance().Get(Type()); return KernelSignatureMap::Instance().Get(Type());
} }
......
...@@ -15,8 +15,10 @@ limitations under the License. */ ...@@ -15,8 +15,10 @@ limitations under the License. */
#include <sstream> #include <sstream>
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/pten_utils.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
...@@ -24,6 +26,34 @@ limitations under the License. */ ...@@ -24,6 +26,34 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
public:
explicit KernelArgsNameMakerByOpProto(
const framework::proto::OpProto* op_proto)
: op_proto_(op_proto) {
PADDLE_ENFORCE_NOT_NULL(op_proto_, platform::errors::InvalidArgument(
"Op proto cannot be nullptr."));
}
~KernelArgsNameMakerByOpProto() {}
const paddle::SmallVector<std::string>& GetInputArgsNames() override;
const paddle::SmallVector<std::string>& GetOutputArgsNames() override;
const paddle::SmallVector<std::string>& GetAttrsArgsNames() override;
KernelSignature GetKernelSignature();
private:
DISABLE_COPY_AND_ASSIGN(KernelArgsNameMakerByOpProto);
private:
const framework::proto::OpProto* op_proto_;
paddle::SmallVector<std::string> input_names_;
paddle::SmallVector<std::string> output_names_;
paddle::SmallVector<std::string> attr_names_;
};
OpKernelType TransPtenKernelKeyToOpKernelType( OpKernelType TransPtenKernelKeyToOpKernelType(
const pten::KernelKey& kernel_key) { const pten::KernelKey& kernel_key) {
proto::VarType::Type data_type = proto::VarType::Type data_type =
...@@ -60,15 +90,29 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey( ...@@ -60,15 +90,29 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(
} }
KernelSignatureMap* KernelSignatureMap::kernel_signature_map_ = nullptr; KernelSignatureMap* KernelSignatureMap::kernel_signature_map_ = nullptr;
std::mutex KernelSignatureMap::mutex_; std::once_flag KernelSignatureMap::init_flag_;
KernelSignatureMap& KernelSignatureMap::Instance() { KernelSignatureMap& KernelSignatureMap::Instance() {
if (kernel_signature_map_ == nullptr) { std::call_once(init_flag_, [] {
std::unique_lock<std::mutex> lock(mutex_); kernel_signature_map_ = new KernelSignatureMap();
if (kernel_signature_map_ == nullptr) { for (const auto& pair : OpInfoMap::Instance().map()) {
kernel_signature_map_ = new KernelSignatureMap; const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
auto success =
kernel_signature_map_->map_
.emplace(op_type, std::move(maker.GetKernelSignature()))
.second;
PADDLE_ENFORCE_EQ(
success, true,
platform::errors::PermissionDenied(
"Kernel signature of the operator %s has been registered.",
op_type));
}
} }
} });
return *kernel_signature_map_; return *kernel_signature_map_;
} }
...@@ -76,16 +120,6 @@ bool KernelSignatureMap::Has(const std::string& op_type) const { ...@@ -76,16 +120,6 @@ bool KernelSignatureMap::Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end(); return map_.find(op_type) != map_.end();
} }
void KernelSignatureMap::Emplace(const std::string& op_type,
KernelSignature&& signature) {
if (!Has(op_type)) {
std::unique_lock<std::mutex> lock(mutex_);
if (!Has(op_type)) {
map_.emplace(op_type, signature);
}
}
}
const KernelSignature& KernelSignatureMap::Get( const KernelSignature& KernelSignatureMap::Get(
const std::string& op_type) const { const std::string& op_type) const {
auto it = map_.find(op_type); auto it = map_.find(op_type);
......
...@@ -67,8 +67,6 @@ class KernelSignatureMap { ...@@ -67,8 +67,6 @@ class KernelSignatureMap {
bool Has(const std::string& op_type) const; bool Has(const std::string& op_type) const;
void Emplace(const std::string& op_type, KernelSignature&& signature);
const KernelSignature& Get(const std::string& op_type) const; const KernelSignature& Get(const std::string& op_type) const;
private: private:
...@@ -77,7 +75,7 @@ class KernelSignatureMap { ...@@ -77,7 +75,7 @@ class KernelSignatureMap {
private: private:
static KernelSignatureMap* kernel_signature_map_; static KernelSignatureMap* kernel_signature_map_;
static std::mutex mutex_; static std::once_flag init_flag_;
paddle::flat_hash_map<std::string, KernelSignature> map_; paddle::flat_hash_map<std::string, KernelSignature> map_;
}; };
...@@ -90,27 +88,6 @@ class KernelArgsNameMaker { ...@@ -90,27 +88,6 @@ class KernelArgsNameMaker {
virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0; virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
}; };
class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
public:
explicit KernelArgsNameMakerByOpProto(framework::proto::OpProto* op_proto)
: op_proto_(op_proto) {}
~KernelArgsNameMakerByOpProto() {}
const paddle::SmallVector<std::string>& GetInputArgsNames() override;
const paddle::SmallVector<std::string>& GetOutputArgsNames() override;
const paddle::SmallVector<std::string>& GetAttrsArgsNames() override;
KernelSignature GetKernelSignature();
private:
framework::proto::OpProto* op_proto_;
paddle::SmallVector<std::string> input_names_;
paddle::SmallVector<std::string> output_names_;
paddle::SmallVector<std::string> attr_names_;
};
std::string KernelSignatureToString(const KernelSignature& signature); std::string KernelSignatureToString(const KernelSignature& signature);
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册