未验证 提交 60c4c9cd 编写于 作者: 王明冬 提交者: GitHub

[Infrt] add infer shape cache for kernel. (#41104)

上级 532eba99
...@@ -24,30 +24,30 @@ namespace host_context { ...@@ -24,30 +24,30 @@ namespace host_context {
struct KernelRegistry::Impl { struct KernelRegistry::Impl {
std::unordered_map<std::string, std::unordered_map<std::string,
std::pair<KernelImplementation, std::vector<const char *>>> std::pair<KernelLauncher, std::vector<const char *>>>
data; data;
}; };
KernelRegistry::KernelRegistry() : impl_(std::make_unique<Impl>()) {} KernelRegistry::KernelRegistry() : impl_(std::make_unique<Impl>()) {}
void KernelRegistry::AddKernel(const std::string &key,
KernelImplementation fn) {
CHECK(!impl_->data.count(key)) << "kernel [" << key
<< "] is registered twice";
impl_->data.emplace(
key, std::make_pair(std::move(fn), std::vector<const char *>{}));
}
const std::vector<const char *> &KernelRegistry::GetAttrNameList( const std::vector<const char *> &KernelRegistry::GetAttrNameList(
const std::string &key) const { const std::string &key) const {
CHECK(impl_->data.count(key)); CHECK(impl_->data.count(key));
return impl_->data[key].second; return impl_->data[key].second;
} }
void KernelRegistry::AddKernelWithAttrs( void KernelRegistry::AddKernel(const std::string &key,
const std::string &key,
KernelImplementation fn, KernelImplementation fn,
std::vector<const char *> &&attr_order) { const std::vector<const char *> &attr_order) {
CHECK(!impl_->data.count(key)) << "kernel [" << key
<< "] is registered twice";
impl_->data.emplace(
key, std::make_pair([fn]() { return fn; }, std::move(attr_order)));
}
void KernelRegistry::AddKernel(const std::string &key,
KernelLauncher fn,
const std::vector<const char *> &attr_order) {
CHECK(!impl_->data.count(key)) << "kernel [" << key CHECK(!impl_->data.count(key)) << "kernel [" << key
<< "] is registered twice"; << "] is registered twice";
impl_->data.emplace(key, impl_->data.emplace(key,
...@@ -56,7 +56,7 @@ void KernelRegistry::AddKernelWithAttrs( ...@@ -56,7 +56,7 @@ void KernelRegistry::AddKernelWithAttrs(
KernelImplementation KernelRegistry::GetKernel(const std::string &key) const { KernelImplementation KernelRegistry::GetKernel(const std::string &key) const {
auto it = impl_->data.find(key); auto it = impl_->data.find(key);
return it != impl_->data.end() ? it->second.first : KernelImplementation{}; return it != impl_->data.end() ? it->second.first() : KernelImplementation{};
} }
std::vector<std::string> KernelRegistry::GetKernelList() const { std::vector<std::string> KernelRegistry::GetKernelList() const {
......
...@@ -25,6 +25,7 @@ namespace host_context { ...@@ -25,6 +25,7 @@ namespace host_context {
class KernelFrame; class KernelFrame;
using KernelImplementation = std::function<void(KernelFrame *frame)>; using KernelImplementation = std::function<void(KernelFrame *frame)>;
using KernelLauncher = std::function<KernelImplementation()>;
/** /**
* Hold the kernels registered in the system. * Hold the kernels registered in the system.
...@@ -33,10 +34,12 @@ class KernelRegistry { ...@@ -33,10 +34,12 @@ class KernelRegistry {
public: public:
KernelRegistry(); KernelRegistry();
void AddKernel(const std::string &key, KernelImplementation fn); void AddKernel(const std::string &key,
void AddKernelWithAttrs(const std::string &key,
KernelImplementation fn, KernelImplementation fn,
std::vector<const char *> &&attrs_order); const std::vector<const char *> &attrs_order = {});
void AddKernel(const std::string &key,
KernelLauncher fn,
const std::vector<const char *> &attrs_order = {});
KernelImplementation GetKernel(const std::string &key) const; KernelImplementation GetKernel(const std::string &key) const;
const std::vector<const char *> &GetAttrNameList( const std::vector<const char *> &GetAttrNameList(
......
...@@ -360,8 +360,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp( ...@@ -360,8 +360,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
if (attrs.size()) { if (attrs.size()) {
if (attr_names.empty()) { if (attr_names.empty()) {
LOG(WARNING) << "The kernel `" << kernel_name LOG(WARNING) << "The kernel `" << kernel_name
<< "` has not been registered with " << "` has not been registered with attributes order ";
"`KernelRegistry::AddKernelWithAttrs()`.";
} else { } else {
CHECK_EQ(attr_names.size(), attrs.size()) CHECK_EQ(attr_names.size(), attrs.size())
<< "The number of kernel `" << kernel_name << "The number of kernel `" << kernel_name
...@@ -380,8 +379,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp( ...@@ -380,8 +379,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
} }
} }
LOG(WARNING) << "The attribute `" << attr << "` of kernel `" << kernel_name LOG(WARNING) << "The attribute `" << attr << "` of kernel `" << kernel_name
<< "` is not properly registered with " << "` is not properly register";
"`KernelRegistry::AddKernelWithAttrs()`.";
return -1; return -1;
}; };
......
...@@ -29,7 +29,6 @@ add_custom_target(infrt_register_phi_kernel ...@@ -29,7 +29,6 @@ add_custom_target(infrt_register_phi_kernel
cc_library(infrt_naive SRCS infershaped/infershaped_kernel_launcher.cc cc_library(infrt_naive SRCS infershaped/infershaped_kernel_launcher.cc
infershaped/infershaped_kernel_launchers.cc infershaped/infershaped_kernel_launchers.cc
DEPS phi wrapped_infermeta) DEPS phi wrapped_infermeta)
add_dependencies(infrt_naive infrt_register_phi_kernel)
cc_test_tiny(test_infrt_infershape_launchers SRCS cc_test_tiny(test_infrt_infershape_launchers SRCS
infershaped/infershape_launchers_test.cc DEPS infrt) infershaped/infershape_launchers_test.cc DEPS infrt)
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <iostream> #include <iostream>
#include "paddle/infrt/backends/host/phi_context.h" #include "paddle/infrt/backends/host/phi_context.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h" #include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h" #include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_utils.h" #include "paddle/infrt/kernel/phi/infershaped/infershaped_utils.h"
...@@ -36,11 +37,12 @@ template <typename KernelFunc, ...@@ -36,11 +37,12 @@ template <typename KernelFunc,
KernelFunc kernel, KernelFunc kernel,
typename InferShapedFunc, typename InferShapedFunc,
InferShapedFunc infershape> InferShapedFunc infershape>
void KernelLauncherFunc(host_context::KernelFrame* frame) { ::infrt::host_context::KernelImplementation KernelLauncherFunc() {
InferShapedKernelLauncher launcher(FuncArgStatics<InferShapedFunc>::arg_size); InferShapedKernelLauncher launcher(FuncArgStatics<InferShapedFunc>::arg_size);
static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count}; static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count};
static const bool turn_on_infer_shape_cache{true}; static const bool turn_on_infer_shape_cache{true};
return [=](host_context::KernelFrame* frame) mutable {
#ifndef NDEBUG #ifndef NDEBUG
LOG(INFO) << "Kernel.frame: " << frame->DumpArgTypes(); LOG(INFO) << "Kernel.frame: " << frame->DumpArgTypes();
#endif #endif
...@@ -59,8 +61,12 @@ void KernelLauncherFunc(host_context::KernelFrame* frame) { ...@@ -59,8 +61,12 @@ void KernelLauncherFunc(host_context::KernelFrame* frame) {
&launcher.infershape_kernel_frame_builder); &launcher.infershape_kernel_frame_builder);
launcher.BuildInferShapeCache(num_input_tensors); launcher.BuildInferShapeCache(num_input_tensors);
} }
} else {
::infrt::host_context::KernelImpl<InferShapedFunc, infershape>::Invoke(
&launcher.infershape_kernel_frame_builder);
} }
::infrt::host_context::KernelImpl<KernelFunc, kernel>::Invoke(frame); ::infrt::host_context::KernelImpl<KernelFunc, kernel>::Invoke(frame);
};
} }
} // namespace kernel } // namespace kernel
......
...@@ -34,18 +34,16 @@ namespace kernel { ...@@ -34,18 +34,16 @@ namespace kernel {
void RegisterPhiKernels(host_context::KernelRegistry* registry) { void RegisterPhiKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("phi_dt.create_context.cpu", registry->AddKernel("phi_dt.create_context.cpu",
INFRT_KERNEL(infrt::kernel::phi::CreateCPUContext)); INFRT_KERNEL(infrt::kernel::phi::CreateCPUContext));
registry->AddKernelWithAttrs( registry->AddKernel("phi_dt.create_dense_tensor.cpu",
"phi_dt.create_dense_tensor.cpu",
INFRT_KERNEL(infrt::kernel::phi::CreateDenseTensor), INFRT_KERNEL(infrt::kernel::phi::CreateDenseTensor),
{"dims", "lod", "layout", "precision"}); {"dims", "lod", "layout", "precision"});
registry->AddKernelWithAttrs( registry->AddKernel(
"phi_dt.create_inited_dense_tensor.cpu.f32", "phi_dt.create_inited_dense_tensor.cpu.f32",
INFRT_KERNEL(infrt::kernel::phi::CreateInitedDenseTensorF32), INFRT_KERNEL(infrt::kernel::phi::CreateInitedDenseTensorF32),
{"dims", "lod", "layout", "value"}); {"dims", "lod", "layout", "value"});
registry->AddKernelWithAttrs( registry->AddKernel("phi_dt.fill_dense_tensor.f32",
"phi_dt.fill_dense_tensor.f32",
INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32), INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32),
{"value"}); {"value"});
registry->AddKernel("phi_dt.print_tensor", registry->AddKernel("phi_dt.print_tensor",
...@@ -54,23 +52,20 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) { ...@@ -54,23 +52,20 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
#ifdef INFRT_WITH_GPU #ifdef INFRT_WITH_GPU
registry->AddKernel("phi_dt.create_context.gpu", registry->AddKernel("phi_dt.create_context.gpu",
INFRT_KERNEL(infrt::kernel::phi::CreateGPUContext)); INFRT_KERNEL(infrt::kernel::phi::CreateGPUContext));
registry->AddKernelWithAttrs( registry->AddKernel("phi_dt.create_dense_tensor.gpu",
"phi_dt.create_dense_tensor.gpu",
INFRT_KERNEL(infrt::kernel::phi::CreateGPUDenseTensor), INFRT_KERNEL(infrt::kernel::phi::CreateGPUDenseTensor),
{"dims", "lod", "layout", "precision"}); {"dims", "lod", "layout", "precision"});
registry->AddKernelWithAttrs("phi_dt.memcpy.gpu", registry->AddKernel("phi_dt.memcpy.gpu",
INFRT_KERNEL(infrt::kernel::phi::GpuMemCpy), INFRT_KERNEL(infrt::kernel::phi::GpuMemCpy),
{"d2h"}); {"d2h"});
#endif #endif
registry->AddKernelWithAttrs("phi_dt.load_params", registry->AddKernel("phi_dt.load_params",
INFRT_KERNEL(infrt::kernel::phi::LoadParams), INFRT_KERNEL(infrt::kernel::phi::LoadParams),
{"path"}); {"path"});
registry->AddKernelWithAttrs( registry->AddKernel("phi_dt.load_combined_params",
"phi_dt.load_combined_params",
INFRT_KERNEL(infrt::kernel::phi::LoadCombinedParams), INFRT_KERNEL(infrt::kernel::phi::LoadCombinedParams),
{"model_path", "params_path"}); {"model_path", "params_path"});
registry->AddKernelWithAttrs( registry->AddKernel("phi_dt.tensor_map_get_tensor",
"phi_dt.tensor_map_get_tensor",
INFRT_KERNEL(infrt::kernel::phi::TensorMapGetTensor), INFRT_KERNEL(infrt::kernel::phi::TensorMapGetTensor),
{"name"}); {"name"});
registry->AddKernel("phi_dt.tensor_map_get_size", registry->AddKernel("phi_dt.tensor_map_get_size",
......
...@@ -129,7 +129,7 @@ void NaiveMatmul(const DenseHostTensor &x, ...@@ -129,7 +129,7 @@ void NaiveMatmul(const DenseHostTensor &x,
/// ===== Kernel end ==== /// ===== Kernel end ====
void RegisterTensorKernels(host_context::KernelRegistry *registry) { void RegisterTensorKernels(host_context::KernelRegistry *registry) {
registry->AddKernelWithAttrs("dt.create_uninit_tensor.f32", registry->AddKernel("dt.create_uninit_tensor.f32",
INFRT_KERNEL(CreateUninitTensor<float>), INFRT_KERNEL(CreateUninitTensor<float>),
{"shape"}); {"shape"});
registry->AddKernel("dt.print_tensor", INFRT_KERNEL(PrintTensor)); registry->AddKernel("dt.print_tensor", INFRT_KERNEL(PrintTensor));
...@@ -146,7 +146,7 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) { ...@@ -146,7 +146,7 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
// TensorList related methods. // TensorList related methods.
#ifdef INFRT_WITH_PHI #ifdef INFRT_WITH_PHI
registry->AddKernelWithAttrs( registry->AddKernel(
"dt.tensor_list_get_tensor", INFRT_KERNEL(TensorListGetTensor), {"id"}); "dt.tensor_list_get_tensor", INFRT_KERNEL(TensorListGetTensor), {"id"});
registry->AddKernel("dt.tensor_list_get_size", registry->AddKernel("dt.tensor_list_get_size",
INFRT_KERNEL(TensorListGetSize)); INFRT_KERNEL(TensorListGetSize));
......
...@@ -287,7 +287,7 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]): ...@@ -287,7 +287,7 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]):
attr_names = ', '.join( attr_names = ', '.join(
["\"" + a + "\"" for a in attr_data[ir_name]]) ["\"" + a + "\"" for a in attr_data[ir_name]])
res += f""" res += f"""
registry->AddKernelWithAttrs("{ir_name}",""" registry->AddKernel("{ir_name}","""
res += f""" res += f"""
&KernelLauncherFunc<decltype({kernel_func}), &KernelLauncherFunc<decltype({kernel_func}),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册