From 192184e8f3b36ca0b7843f765b2e004becf05e43 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sun, 16 Jan 2022 22:54:18 +0800 Subject: [PATCH] [Pten] Add select kernel map method for infrt (#38972) * add select kernel map method * fix error --- paddle/pten/core/kernel_factory.cc | 9 +++++++++ paddle/pten/core/kernel_factory.h | 3 +++ paddle/pten/tests/core/CMakeLists.txt | 2 +- paddle/pten/tests/core/test_kernel_factory.cc | 12 +++++++++++- 4 files changed, 24 insertions(+), 2 deletions(-) diff --git a/paddle/pten/core/kernel_factory.cc b/paddle/pten/core/kernel_factory.cc index 799b8608597..f10b58506f7 100644 --- a/paddle/pten/core/kernel_factory.cc +++ b/paddle/pten/core/kernel_factory.cc @@ -50,6 +50,15 @@ Kernel KernelFactory::SelectKernel(const std::string& kernel_name, return kernel_iter->second; } +paddle::flat_hash_map +KernelFactory::SelectKernelMap(const std::string& kernel_name) const { + auto iter = kernels_.find(kernel_name); + if (iter == kernels_.end()) { + return paddle::flat_hash_map(); + } + return iter->second; +} + const Kernel& KernelFactory::SelectKernelOrThrowError( const std::string& kernel_name, const KernelKey& kernel_key) const { auto iter = kernels_.find(kernel_name); diff --git a/paddle/pten/core/kernel_factory.h b/paddle/pten/core/kernel_factory.h index e0585aea7f3..bd26d86a34a 100644 --- a/paddle/pten/core/kernel_factory.h +++ b/paddle/pten/core/kernel_factory.h @@ -232,6 +232,9 @@ class KernelFactory { Kernel SelectKernel(const std::string& kernel_name, const KernelKey& kernel_key) const; + paddle::flat_hash_map SelectKernelMap( + const std::string& kernel_name) const; + private: KernelFactory() = default; diff --git a/paddle/pten/tests/core/CMakeLists.txt b/paddle/pten/tests/core/CMakeLists.txt index 07554f02d99..2d4ee7f6d6a 100644 --- a/paddle/pten/tests/core/CMakeLists.txt +++ b/paddle/pten/tests/core/CMakeLists.txt @@ -2,4 +2,4 @@ cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor) cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) cc_test(test_type_info SRCS test_type_info.cc) cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils) -cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory) +cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel) diff --git a/paddle/pten/tests/core/test_kernel_factory.cc b/paddle/pten/tests/core/test_kernel_factory.cc index 3f271b2a8f0..5355921ddbe 100644 --- a/paddle/pten/tests/core/test_kernel_factory.cc +++ b/paddle/pten/tests/core/test_kernel_factory.cc @@ -16,9 +16,12 @@ limitations under the License. */ #include #include "paddle/pten/core/kernel_factory.h" +#include "paddle/pten/core/kernel_registry.h" #include "gtest/gtest.h" +PT_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT); + namespace pten { namespace tests { @@ -33,9 +36,16 @@ TEST(KernelKey, ConstructAndOStream) { std::ostringstream oss; oss << key; std::cout << oss.str(); - // EXPECT_EQ(oss.str(), "scale.host"); oss.flush(); } +TEST(KernelFactory, SelectedKernelMap) { + auto kernel_map = pten::KernelFactory::Instance().SelectKernelMap("scale"); + EXPECT_GT(kernel_map.size(), 1UL); + for (auto& iter : kernel_map) { + std::cout << iter.first << ": " << iter.second; + } +} + } // namespace tests } // namespace pten -- GitLab