diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 265ce01d814361818f508c4a46245f44e14652cf..3b41eb6c1938398b8a444fd1c4704e5fc858a547 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1275,7 +1275,7 @@ void OperatorWithKernel::ChoosePtenKernel(const ExecutionContext& ctx) const { kernel_type_.reset( new OpKernelType(std::move(InnerGetExpectedKernelType(ctx)))); - auto pt_kernel_name = pten::KernelName(pt_kernel_signature_->name); + auto pt_kernel_name = pt_kernel_signature_->name; auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(*kernel_type_.get()); pt_kernel_.reset( new pten::Kernel(pten::KernelFactory::Instance().SelectKernel( diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 055e02b0cb258def0e8820c73cf9058721272c77..c4092a33aa332da68122cd79fbeef4898a44a9ac 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -165,7 +165,7 @@ PreparedOp PrepareImpl(const NameVarMap& ins, auto pt_kernel_signature = op.GetExpectedPtenKernelArgs(dygraph_exe_ctx); VLOG(6) << framework::KernelSignatureToString(pt_kernel_signature); - auto pt_kernel_name = pten::KernelName(pt_kernel_signature.name); + auto pt_kernel_name = pt_kernel_signature.name; auto pt_kernel_key = TransOpKernelTypeToPtenKernelKey(expected_kernel_key); auto pt_kernel = pten::KernelFactory::Instance().SelectKernel( pt_kernel_name, pt_kernel_key); diff --git a/paddle/pten/api/lib/kernel_dispatch.h b/paddle/pten/api/lib/kernel_dispatch.h index e78e79f27c28bf1e859bb00655910b90a0fe1467..1bba16d107d48789b890d00e10ec942aa416b80d 100644 --- a/paddle/pten/api/lib/kernel_dispatch.h +++ b/paddle/pten/api/lib/kernel_dispatch.h @@ -24,7 +24,7 @@ limitations under the License. */ #include "paddle/pten/common/data_type.h" #include "paddle/pten/common/layout.h" -// TODO(chenweihang): split KernelName, Key, Kernel, Factory into diff files +// TODO(chenweihang): split Key, Kernel, Factory into diff files #include "paddle/pten/core/kernel_factory.h" // See Note [ Why still include the fluid headers? ] diff --git a/paddle/pten/core/kernel_factory.cc b/paddle/pten/core/kernel_factory.cc index aeefb7cfefb78da590f9ec75b19562e382b157c5..799b860859762b9736bd381d7f8c939dff2cd786 100644 --- a/paddle/pten/core/kernel_factory.cc +++ b/paddle/pten/core/kernel_factory.cc @@ -37,7 +37,7 @@ KernelFactory& KernelFactory::Instance() { return g_op_kernel_factory; } -Kernel KernelFactory::SelectKernel(const KernelName& kernel_name, +Kernel KernelFactory::SelectKernel(const std::string& kernel_name, const KernelKey& kernel_key) const { auto iter = kernels_.find(kernel_name); if (iter == kernels_.end()) { @@ -51,7 +51,7 @@ Kernel KernelFactory::SelectKernel(const KernelName& kernel_name, } const Kernel& KernelFactory::SelectKernelOrThrowError( - const KernelName& kernel_name, const KernelKey& kernel_key) const { + const std::string& kernel_name, const KernelKey& kernel_key) const { auto iter = kernels_.find(kernel_name); PADDLE_ENFORCE_NE(iter, kernels_.end(), @@ -78,7 +78,7 @@ const Kernel& KernelFactory::SelectKernelOrThrowError( } const Kernel& KernelFactory::SelectKernelOrThrowError( - const KernelName& kernel_name, + const std::string& kernel_name, Backend backend, DataLayout layout, DataType dtype) const { diff --git a/paddle/pten/core/kernel_factory.h b/paddle/pten/core/kernel_factory.h index e61143bf142b3f40733109bb2180f408f5ef15c5..e0585aea7f3db7aa6a310eadf6c62e3f51a897ff 100644 --- a/paddle/pten/core/kernel_factory.h +++ b/paddle/pten/core/kernel_factory.h @@ -51,61 +51,6 @@ class KernelContext; using KernelFn = void (*)(KernelContext* ctx); -class KernelName final { - public: - KernelName(std::string name, std::string overload_name) - : name_(std::move(name)), overload_name_(std::move(overload_name)) {} - - KernelName(const std::string& kernel_name) { - ParseNameAndOverloadNameFromString(kernel_name); - } - - KernelName(const char* kernel_name) { - std::string kernel_name_str(kernel_name); - ParseNameAndOverloadNameFromString(kernel_name_str); - } - - const std::string& name() const { return name_; } - const std::string& overload_name() const { return overload_name_; } - - struct Hash { - size_t operator()(const KernelName& kernel_name) const { - return std::hash()(kernel_name.name()) ^ - (std::hash()(kernel_name.overload_name()) << 1); - } - }; - - size_t hash_value() const { return Hash()(*this); } - - bool operator<(const KernelName& kernel_name) const { - return hash_value() < kernel_name.hash_value(); - } - - bool operator==(const KernelName& kernel_name) const { - return hash_value() == kernel_name.hash_value(); - } - - bool operator!=(const KernelName& kernel_name) const { - return hash_value() != kernel_name.hash_value(); - } - - private: - void ParseNameAndOverloadNameFromString(const std::string& kernel_name) { - size_t pos = kernel_name.find_first_of('.'); - if (pos == std::string::npos) { - name_ = kernel_name; - overload_name_ = ""; - } else { - name_ = kernel_name.substr(0, pos); - overload_name_ = kernel_name.substr(pos + 1, kernel_name.size()); - } - } - - // TODO(chenweihang): use string_view to improve performance later - std::string name_; - std::string overload_name_; -}; - class KernelKey { public: KernelKey() = default; @@ -265,9 +210,8 @@ class KernelFactory { public: // replaced by paddle::flat_hash_map later using KernelMap = paddle::flat_hash_map< - KernelName, - paddle::flat_hash_map, - KernelName::Hash>; + std::string, + paddle::flat_hash_map>; static KernelFactory& Instance(); @@ -277,15 +221,15 @@ class KernelFactory { return kernels_.find(TransToPtenKernelName(op_type)) != kernels_.end(); } - const Kernel& SelectKernelOrThrowError(const KernelName& kernel_name, + const Kernel& SelectKernelOrThrowError(const std::string& kernel_name, const KernelKey& kernel_key) const; - const Kernel& SelectKernelOrThrowError(const KernelName& kernel_name, + const Kernel& SelectKernelOrThrowError(const std::string& kernel_name, Backend backend, DataLayout layout, DataType dtype) const; - Kernel SelectKernel(const KernelName& kernel_name, + Kernel SelectKernel(const std::string& kernel_name, const KernelKey& kernel_key) const; private: @@ -294,18 +238,6 @@ class KernelFactory { KernelMap kernels_; }; -/** operator << overload **/ - -inline std::ostream& operator<<(std::ostream& os, - const KernelName& kernel_name) { - if (kernel_name.overload_name().empty()) { - os << kernel_name.name(); - } else { - os << kernel_name.name() << "." << kernel_name.overload_name(); - } - return os; -} - inline std::ostream& operator<<(std::ostream& os, const KernelKey& kernel_key) { os << "(" << kernel_key.backend() << ", " << kernel_key.layout() << ", " << kernel_key.dtype() << ")"; diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index 83ee8fd94b67d0891dcf5c6b59c294ce79612011..80ebd5b832a393d9a3b15bf682fbf915c37354df 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -143,7 +143,7 @@ struct KernelRegistrar { KernelArgsDefFn args_def_fn, KernelFn kernel_fn, void* variadic_kernel_fn) { - KernelName kernel_name(kernel_name_cstr); + std::string kernel_name(kernel_name_cstr); KernelKey kernel_key(backend, layout, dtype); Kernel kernel(kernel_fn, variadic_kernel_fn); args_parse_fn(kernel_key, kernel.mutable_args_def()); diff --git a/paddle/pten/tests/core/test_kernel_factory.cc b/paddle/pten/tests/core/test_kernel_factory.cc index 5ee4b17c9d676ef3bdafb761208e93b8e6e7a008..3f271b2a8f0d0d2e8360c62ee5a48be00b9575a4 100644 --- a/paddle/pten/tests/core/test_kernel_factory.cc +++ b/paddle/pten/tests/core/test_kernel_factory.cc @@ -24,18 +24,6 @@ namespace tests { // TODO(chenweihang): add more unittests later -TEST(KernelName, ConstructAndOStream) { - std::ostringstream oss; - oss << pten::KernelName("scale", "host"); - EXPECT_EQ(oss.str(), "scale.host"); - pten::KernelName kernel_name1("scale.host"); - EXPECT_EQ(kernel_name1.name(), "scale"); - EXPECT_EQ(kernel_name1.overload_name(), "host"); - pten::KernelName kernel_name2("scale.host"); - EXPECT_EQ(kernel_name2.name(), "scale"); - EXPECT_EQ(kernel_name2.overload_name(), "host"); -} - TEST(KernelKey, ConstructAndOStream) { pten::KernelKey key( pten::Backend::CPU, pten::DataLayout::NCHW, pten::DataType::FLOAT32);