未验证 提交 9c81a9bb 编写于 作者: Z Zeng Jinle 提交者: GitHub

Fix PTen thread safety error (#36960)

* fix pten thread safety error

* improve coverage
上级 2479664a
......@@ -399,7 +399,7 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer)
cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer)
cc_library(generator SRCS generator.cc DEPS enforce place)
cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten_hapi_utils)
cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten_hapi_utils op_info)
# Get the current working branch
execute_process(
......
......@@ -1762,12 +1762,6 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
const ExecutionContext& ctx) const {
if (!KernelSignatureMap::Instance().Has(Type())) {
// TODO(chenweihang): we can generate this map by proto info in compile time
KernelArgsNameMakerByOpProto maker(Info().proto_);
KernelSignatureMap::Instance().Emplace(
Type(), std::move(maker.GetKernelSignature()));
}
return KernelSignatureMap::Instance().Get(Type());
}
......
......@@ -15,8 +15,10 @@ limitations under the License. */
#include <sstream>
#include "paddle/fluid/framework/pten_utils.h"
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/string/string_helper.h"
......@@ -24,6 +26,34 @@ limitations under the License. */
namespace paddle {
namespace framework {
class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
public:
explicit KernelArgsNameMakerByOpProto(
const framework::proto::OpProto* op_proto)
: op_proto_(op_proto) {
PADDLE_ENFORCE_NOT_NULL(op_proto_, platform::errors::InvalidArgument(
"Op proto cannot be nullptr."));
}
~KernelArgsNameMakerByOpProto() {}
const paddle::SmallVector<std::string>& GetInputArgsNames() override;
const paddle::SmallVector<std::string>& GetOutputArgsNames() override;
const paddle::SmallVector<std::string>& GetAttrsArgsNames() override;
KernelSignature GetKernelSignature();
private:
DISABLE_COPY_AND_ASSIGN(KernelArgsNameMakerByOpProto);
private:
const framework::proto::OpProto* op_proto_;
paddle::SmallVector<std::string> input_names_;
paddle::SmallVector<std::string> output_names_;
paddle::SmallVector<std::string> attr_names_;
};
OpKernelType TransPtenKernelKeyToOpKernelType(
const pten::KernelKey& kernel_key) {
proto::VarType::Type data_type =
......@@ -60,15 +90,29 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey(
}
KernelSignatureMap* KernelSignatureMap::kernel_signature_map_ = nullptr;
std::mutex KernelSignatureMap::mutex_;
std::once_flag KernelSignatureMap::init_flag_;
KernelSignatureMap& KernelSignatureMap::Instance() {
if (kernel_signature_map_ == nullptr) {
std::unique_lock<std::mutex> lock(mutex_);
if (kernel_signature_map_ == nullptr) {
kernel_signature_map_ = new KernelSignatureMap;
std::call_once(init_flag_, [] {
kernel_signature_map_ = new KernelSignatureMap();
for (const auto& pair : OpInfoMap::Instance().map()) {
const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
auto success =
kernel_signature_map_->map_
.emplace(op_type, std::move(maker.GetKernelSignature()))
.second;
PADDLE_ENFORCE_EQ(
success, true,
platform::errors::PermissionDenied(
"Kernel signature of the operator %s has been registered.",
op_type));
}
}
}
});
return *kernel_signature_map_;
}
......@@ -76,16 +120,6 @@ bool KernelSignatureMap::Has(const std::string& op_type) const {
return map_.find(op_type) != map_.end();
}
void KernelSignatureMap::Emplace(const std::string& op_type,
KernelSignature&& signature) {
if (!Has(op_type)) {
std::unique_lock<std::mutex> lock(mutex_);
if (!Has(op_type)) {
map_.emplace(op_type, signature);
}
}
}
const KernelSignature& KernelSignatureMap::Get(
const std::string& op_type) const {
auto it = map_.find(op_type);
......
......@@ -67,8 +67,6 @@ class KernelSignatureMap {
bool Has(const std::string& op_type) const;
void Emplace(const std::string& op_type, KernelSignature&& signature);
const KernelSignature& Get(const std::string& op_type) const;
private:
......@@ -77,7 +75,7 @@ class KernelSignatureMap {
private:
static KernelSignatureMap* kernel_signature_map_;
static std::mutex mutex_;
static std::once_flag init_flag_;
paddle::flat_hash_map<std::string, KernelSignature> map_;
};
......@@ -90,27 +88,6 @@ class KernelArgsNameMaker {
virtual const paddle::SmallVector<std::string>& GetAttrsArgsNames() = 0;
};
class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker {
public:
explicit KernelArgsNameMakerByOpProto(framework::proto::OpProto* op_proto)
: op_proto_(op_proto) {}
~KernelArgsNameMakerByOpProto() {}
const paddle::SmallVector<std::string>& GetInputArgsNames() override;
const paddle::SmallVector<std::string>& GetOutputArgsNames() override;
const paddle::SmallVector<std::string>& GetAttrsArgsNames() override;
KernelSignature GetKernelSignature();
private:
framework::proto::OpProto* op_proto_;
paddle::SmallVector<std::string> input_names_;
paddle::SmallVector<std::string> output_names_;
paddle::SmallVector<std::string> attr_names_;
};
std::string KernelSignatureToString(const KernelSignature& signature);
} // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册