diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index f037a83bcab48481178deaabe68b6860bb69ee90..d0cbd7801ddaffd4b77e521c8e208d4a8b14ebf2 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -399,7 +399,7 @@ cc_library(save_load_util SRCS save_load_util.cc DEPS tensor scope layer) cc_test(save_load_util_test SRCS save_load_util_test.cc DEPS save_load_util tensor scope layer) cc_library(generator SRCS generator.cc DEPS enforce place) -cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten_hapi_utils) +cc_library(pten_utils SRCS pten_utils.cc DEPS lod_tensor selected_rows place pten var_type_traits pten_hapi_utils op_info) # Get the current working branch execute_process( diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index d317aac8594b4c9045f06a596fb3f9292604fcd6..e75fb4e36336abdcffd5fe1ba81eeeecd2d8c3db 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1762,12 +1762,6 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs( const ExecutionContext& ctx) const { - if (!KernelSignatureMap::Instance().Has(Type())) { - // TODO(chenweihang): we can generate this map by proto info in compile time - KernelArgsNameMakerByOpProto maker(Info().proto_); - KernelSignatureMap::Instance().Emplace( - Type(), std::move(maker.GetKernelSignature())); - } return KernelSignatureMap::Instance().Get(Type()); } diff --git a/paddle/fluid/framework/pten_utils.cc b/paddle/fluid/framework/pten_utils.cc index ff24a8c73f705f676a2fe45cd13187ba929d5bf7..b423d0e05e174402d5485d1aefaa3643a9a28c9f 100644 --- a/paddle/fluid/framework/pten_utils.cc +++ b/paddle/fluid/framework/pten_utils.cc @@ -15,8 +15,10 @@ limitations under the License. */ #include #include "paddle/fluid/framework/pten_utils.h" +#include "paddle/pten/core/kernel_factory.h" #include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/string/string_helper.h" @@ -24,6 +26,34 @@ limitations under the License. */ namespace paddle { namespace framework { +class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker { + public: + explicit KernelArgsNameMakerByOpProto( + const framework::proto::OpProto* op_proto) + : op_proto_(op_proto) { + PADDLE_ENFORCE_NOT_NULL(op_proto_, platform::errors::InvalidArgument( + "Op proto cannot be nullptr.")); + } + + ~KernelArgsNameMakerByOpProto() {} + + const paddle::SmallVector& GetInputArgsNames() override; + const paddle::SmallVector& GetOutputArgsNames() override; + const paddle::SmallVector& GetAttrsArgsNames() override; + + KernelSignature GetKernelSignature(); + + private: + DISABLE_COPY_AND_ASSIGN(KernelArgsNameMakerByOpProto); + + private: + const framework::proto::OpProto* op_proto_; + + paddle::SmallVector input_names_; + paddle::SmallVector output_names_; + paddle::SmallVector attr_names_; +}; + OpKernelType TransPtenKernelKeyToOpKernelType( const pten::KernelKey& kernel_key) { proto::VarType::Type data_type = @@ -60,15 +90,29 @@ pten::KernelKey TransOpKernelTypeToPtenKernelKey( } KernelSignatureMap* KernelSignatureMap::kernel_signature_map_ = nullptr; -std::mutex KernelSignatureMap::mutex_; +std::once_flag KernelSignatureMap::init_flag_; KernelSignatureMap& KernelSignatureMap::Instance() { - if (kernel_signature_map_ == nullptr) { - std::unique_lock lock(mutex_); - if (kernel_signature_map_ == nullptr) { - kernel_signature_map_ = new KernelSignatureMap; + std::call_once(init_flag_, [] { + kernel_signature_map_ = new KernelSignatureMap(); + for (const auto& pair : OpInfoMap::Instance().map()) { + const auto& op_type = pair.first; + const auto* op_proto = pair.second.proto_; + if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) { + KernelArgsNameMakerByOpProto maker(op_proto); + VLOG(10) << "Register kernel signature for " << op_type; + auto success = + kernel_signature_map_->map_ + .emplace(op_type, std::move(maker.GetKernelSignature())) + .second; + PADDLE_ENFORCE_EQ( + success, true, + platform::errors::PermissionDenied( + "Kernel signature of the operator %s has been registered.", + op_type)); + } } - } + }); return *kernel_signature_map_; } @@ -76,16 +120,6 @@ bool KernelSignatureMap::Has(const std::string& op_type) const { return map_.find(op_type) != map_.end(); } -void KernelSignatureMap::Emplace(const std::string& op_type, - KernelSignature&& signature) { - if (!Has(op_type)) { - std::unique_lock lock(mutex_); - if (!Has(op_type)) { - map_.emplace(op_type, signature); - } - } -} - const KernelSignature& KernelSignatureMap::Get( const std::string& op_type) const { auto it = map_.find(op_type); diff --git a/paddle/fluid/framework/pten_utils.h b/paddle/fluid/framework/pten_utils.h index 6fe02ad4a4ae4a87c4d053e353da2474186b5d53..fd893e04d3ca45164b13a6f1d01a203c200f4cc5 100644 --- a/paddle/fluid/framework/pten_utils.h +++ b/paddle/fluid/framework/pten_utils.h @@ -67,8 +67,6 @@ class KernelSignatureMap { bool Has(const std::string& op_type) const; - void Emplace(const std::string& op_type, KernelSignature&& signature); - const KernelSignature& Get(const std::string& op_type) const; private: @@ -77,7 +75,7 @@ class KernelSignatureMap { private: static KernelSignatureMap* kernel_signature_map_; - static std::mutex mutex_; + static std::once_flag init_flag_; paddle::flat_hash_map map_; }; @@ -90,27 +88,6 @@ class KernelArgsNameMaker { virtual const paddle::SmallVector& GetAttrsArgsNames() = 0; }; -class KernelArgsNameMakerByOpProto : public KernelArgsNameMaker { - public: - explicit KernelArgsNameMakerByOpProto(framework::proto::OpProto* op_proto) - : op_proto_(op_proto) {} - - ~KernelArgsNameMakerByOpProto() {} - - const paddle::SmallVector& GetInputArgsNames() override; - const paddle::SmallVector& GetOutputArgsNames() override; - const paddle::SmallVector& GetAttrsArgsNames() override; - - KernelSignature GetKernelSignature(); - - private: - framework::proto::OpProto* op_proto_; - - paddle::SmallVector input_names_; - paddle::SmallVector output_names_; - paddle::SmallVector attr_names_; -}; - std::string KernelSignatureToString(const KernelSignature& signature); } // namespace framework