diff --git a/paddle/pten/core/kernel_factory.cc b/paddle/pten/core/kernel_factory.cc index 799b860859762b9736bd381d7f8c939dff2cd786..f10b58506f728ed39b62ec6c6efad621ab8ce926 100644 --- a/paddle/pten/core/kernel_factory.cc +++ b/paddle/pten/core/kernel_factory.cc @@ -50,6 +50,15 @@ Kernel KernelFactory::SelectKernel(const std::string& kernel_name, return kernel_iter->second; } +paddle::flat_hash_map +KernelFactory::SelectKernelMap(const std::string& kernel_name) const { + auto iter = kernels_.find(kernel_name); + if (iter == kernels_.end()) { + return paddle::flat_hash_map(); + } + return iter->second; +} + const Kernel& KernelFactory::SelectKernelOrThrowError( const std::string& kernel_name, const KernelKey& kernel_key) const { auto iter = kernels_.find(kernel_name); diff --git a/paddle/pten/core/kernel_factory.h b/paddle/pten/core/kernel_factory.h index e0585aea7f3db7aa6a310eadf6c62e3f51a897ff..bd26d86a34a0942da61f08c040fdd6a0ec47a2cf 100644 --- a/paddle/pten/core/kernel_factory.h +++ b/paddle/pten/core/kernel_factory.h @@ -232,6 +232,9 @@ class KernelFactory { Kernel SelectKernel(const std::string& kernel_name, const KernelKey& kernel_key) const; + paddle::flat_hash_map SelectKernelMap( + const std::string& kernel_name) const; + private: KernelFactory() = default; diff --git a/paddle/pten/tests/core/CMakeLists.txt b/paddle/pten/tests/core/CMakeLists.txt index 07554f02d999222dd6cf41c6462c14e9a924b4db..2d4ee7f6d6a477bb7199b36901d4366d39e9abcb 100644 --- a/paddle/pten/tests/core/CMakeLists.txt +++ b/paddle/pten/tests/core/CMakeLists.txt @@ -2,4 +2,4 @@ cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor) cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) cc_test(test_type_info SRCS test_type_info.cc) cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils) -cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory) +cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel) diff --git a/paddle/pten/tests/core/test_kernel_factory.cc b/paddle/pten/tests/core/test_kernel_factory.cc index 3f271b2a8f0d0d2e8360c62ee5a48be00b9575a4..5355921ddbe018b24cf874d6bb7b54da48c967f3 100644 --- a/paddle/pten/tests/core/test_kernel_factory.cc +++ b/paddle/pten/tests/core/test_kernel_factory.cc @@ -16,9 +16,12 @@ limitations under the License. */ #include #include "paddle/pten/core/kernel_factory.h" +#include "paddle/pten/core/kernel_registry.h" #include "gtest/gtest.h" +PT_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT); + namespace pten { namespace tests { @@ -33,9 +36,16 @@ TEST(KernelKey, ConstructAndOStream) { std::ostringstream oss; oss << key; std::cout << oss.str(); - // EXPECT_EQ(oss.str(), "scale.host"); oss.flush(); } +TEST(KernelFactory, SelectedKernelMap) { + auto kernel_map = pten::KernelFactory::Instance().SelectKernelMap("scale"); + EXPECT_GT(kernel_map.size(), 1UL); + for (auto& iter : kernel_map) { + std::cout << iter.first << ": " << iter.second; + } +} + } // namespace tests } // namespace pten