未验证 提交 192184e8 编写于 作者: C Chen Weihang 提交者: GitHub

[Pten] Add select kernel map method for infrt (#38972)

* add select kernel map method

* fix error
上级 5c358674
......@@ -50,6 +50,15 @@ Kernel KernelFactory::SelectKernel(const std::string& kernel_name,
return kernel_iter->second;
}
paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>
KernelFactory::SelectKernelMap(const std::string& kernel_name) const {
auto iter = kernels_.find(kernel_name);
if (iter == kernels_.end()) {
return paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>();
}
return iter->second;
}
const Kernel& KernelFactory::SelectKernelOrThrowError(
const std::string& kernel_name, const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name);
......
......@@ -232,6 +232,9 @@ class KernelFactory {
Kernel SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const;
paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash> SelectKernelMap(
const std::string& kernel_name) const;
private:
KernelFactory() = default;
......
......@@ -2,4 +2,4 @@ cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor)
cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc)
cc_test(test_type_info SRCS test_type_info.cc)
cc_test(test_convert_utils SRCS test_convert_utils.cc DEPS convert_utils)
cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory)
cc_test(test_kernel_factory SRCS test_kernel_factory.cc DEPS kernel_factory scale_kernel)
......@@ -16,9 +16,12 @@ limitations under the License. */
#include <sstream>
#include "paddle/pten/core/kernel_factory.h"
#include "paddle/pten/core/kernel_registry.h"
#include "gtest/gtest.h"
PT_DECLARE_KERNEL(scale, CPU, ALL_LAYOUT);
namespace pten {
namespace tests {
......@@ -33,9 +36,16 @@ TEST(KernelKey, ConstructAndOStream) {
std::ostringstream oss;
oss << key;
std::cout << oss.str();
// EXPECT_EQ(oss.str(), "scale.host");
oss.flush();
}
TEST(KernelFactory, SelectedKernelMap) {
auto kernel_map = pten::KernelFactory::Instance().SelectKernelMap("scale");
EXPECT_GT(kernel_map.size(), 1UL);
for (auto& iter : kernel_map) {
std::cout << iter.first << ": " << iter.second;
}
}
} // namespace tests
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册