diff --git a/paddle/infrt/host_context/kernel_registry.cc b/paddle/infrt/host_context/kernel_registry.cc index 4209b2a9648d8be0a9a3897c27c7a35113cba424..5693e973a3f9894cdabb9eb4bc837d840638213e 100644 --- a/paddle/infrt/host_context/kernel_registry.cc +++ b/paddle/infrt/host_context/kernel_registry.cc @@ -24,30 +24,30 @@ namespace host_context { struct KernelRegistry::Impl { std::unordered_map>> + std::pair>> data; }; KernelRegistry::KernelRegistry() : impl_(std::make_unique()) {} -void KernelRegistry::AddKernel(const std::string &key, - KernelImplementation fn) { - CHECK(!impl_->data.count(key)) << "kernel [" << key - << "] is registered twice"; - impl_->data.emplace( - key, std::make_pair(std::move(fn), std::vector{})); -} - const std::vector &KernelRegistry::GetAttrNameList( const std::string &key) const { CHECK(impl_->data.count(key)); return impl_->data[key].second; } -void KernelRegistry::AddKernelWithAttrs( - const std::string &key, - KernelImplementation fn, - std::vector &&attr_order) { +void KernelRegistry::AddKernel(const std::string &key, + KernelImplementation fn, + const std::vector &attr_order) { + CHECK(!impl_->data.count(key)) << "kernel [" << key + << "] is registered twice"; + impl_->data.emplace( + key, std::make_pair([fn]() { return fn; }, std::move(attr_order))); +} + +void KernelRegistry::AddKernel(const std::string &key, + KernelLauncher fn, + const std::vector &attr_order) { CHECK(!impl_->data.count(key)) << "kernel [" << key << "] is registered twice"; impl_->data.emplace(key, @@ -56,7 +56,7 @@ void KernelRegistry::AddKernelWithAttrs( KernelImplementation KernelRegistry::GetKernel(const std::string &key) const { auto it = impl_->data.find(key); - return it != impl_->data.end() ? it->second.first : KernelImplementation{}; + return it != impl_->data.end() ? it->second.first() : KernelImplementation{}; } std::vector KernelRegistry::GetKernelList() const { diff --git a/paddle/infrt/host_context/kernel_registry.h b/paddle/infrt/host_context/kernel_registry.h index a146b2b3c4c1e1090b5ac1843466b93a31b0bb0b..a9f2b407bd414f383e5a23990c648365c5a4c5fa 100644 --- a/paddle/infrt/host_context/kernel_registry.h +++ b/paddle/infrt/host_context/kernel_registry.h @@ -25,6 +25,7 @@ namespace host_context { class KernelFrame; using KernelImplementation = std::function; +using KernelLauncher = std::function; /** * Hold the kernels registered in the system. @@ -33,10 +34,12 @@ class KernelRegistry { public: KernelRegistry(); - void AddKernel(const std::string &key, KernelImplementation fn); - void AddKernelWithAttrs(const std::string &key, - KernelImplementation fn, - std::vector &&attrs_order); + void AddKernel(const std::string &key, + KernelImplementation fn, + const std::vector &attrs_order = {}); + void AddKernel(const std::string &key, + KernelLauncher fn, + const std::vector &attrs_order = {}); KernelImplementation GetKernel(const std::string &key) const; const std::vector &GetAttrNameList( diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate.cc b/paddle/infrt/host_context/mlir_to_runtime_translate.cc index 007730151e370da4f53da74b302c4ff43f4b2238..05bb28b7c56137f19d0c6a4159a07bfc77e053e1 100644 --- a/paddle/infrt/host_context/mlir_to_runtime_translate.cc +++ b/paddle/infrt/host_context/mlir_to_runtime_translate.cc @@ -360,8 +360,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp( if (attrs.size()) { if (attr_names.empty()) { LOG(WARNING) << "The kernel `" << kernel_name - << "` has not been registered with " - "`KernelRegistry::AddKernelWithAttrs()`."; + << "` has not been registered with attributes order "; } else { CHECK_EQ(attr_names.size(), attrs.size()) << "The number of kernel `" << kernel_name @@ -380,8 +379,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp( } } LOG(WARNING) << "The attribute `" << attr << "` of kernel `" << kernel_name - << "` is not properly registered with " - "`KernelRegistry::AddKernelWithAttrs()`."; + << "` is not properly register"; return -1; }; diff --git a/paddle/infrt/kernel/phi/CMakeLists.txt b/paddle/infrt/kernel/phi/CMakeLists.txt index 50f61c7ba6ad070e5de536867b8a19860ffeeb1f..22a59ab2faf8c475cf916f29ee2420f50f6b20b4 100644 --- a/paddle/infrt/kernel/phi/CMakeLists.txt +++ b/paddle/infrt/kernel/phi/CMakeLists.txt @@ -29,7 +29,6 @@ add_custom_target(infrt_register_phi_kernel cc_library(infrt_naive SRCS infershaped/infershaped_kernel_launcher.cc infershaped/infershaped_kernel_launchers.cc DEPS phi wrapped_infermeta) -add_dependencies(infrt_naive infrt_register_phi_kernel) cc_test_tiny(test_infrt_infershape_launchers SRCS infershaped/infershape_launchers_test.cc DEPS infrt) diff --git a/paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h b/paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h index 34ef4460fc634b15aad08b05ecbb53981c3360fa..d87027847202bc12fd6d55712961946cf18a29a7 100644 --- a/paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h +++ b/paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h @@ -17,6 +17,7 @@ #include #include "paddle/infrt/backends/host/phi_context.h" +#include "paddle/infrt/host_context/kernel_registry.h" #include "paddle/infrt/host_context/kernel_utils.h" #include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h" #include "paddle/infrt/kernel/phi/infershaped/infershaped_utils.h" @@ -36,31 +37,36 @@ template -void KernelLauncherFunc(host_context::KernelFrame* frame) { +::infrt::host_context::KernelImplementation KernelLauncherFunc() { InferShapedKernelLauncher launcher(FuncArgStatics::arg_size); static const uint16_t num_input_tensors{InferShapeHelper::count}; static const bool turn_on_infer_shape_cache{true}; + return [=](host_context::KernelFrame* frame) mutable { #ifndef NDEBUG - LOG(INFO) << "Kernel.frame: " << frame->DumpArgTypes(); + LOG(INFO) << "Kernel.frame: " << frame->DumpArgTypes(); #endif - // Build the infershape KernelFrame if needed. - // TODO(Superjomn) add unlikely here. - if (launcher.infershape_kernel_frame_builder.IsEmpty()) { - launcher.CreateKernelFrameForInferShape(frame); + // Build the infershape KernelFrame if needed. + // TODO(Superjomn) add unlikely here. + if (launcher.infershape_kernel_frame_builder.IsEmpty()) { + launcher.CreateKernelFrameForInferShape(frame); #ifndef NDEBUG - LOG(INFO) << "infershape.frame: " - << launcher.infershape_kernel_frame_builder.DumpArgTypes(); + LOG(INFO) << "infershape.frame: " + << launcher.infershape_kernel_frame_builder.DumpArgTypes(); #endif - } - if (turn_on_infer_shape_cache) { - if (launcher.IsShapeChanged(num_input_tensors)) { + } + if (turn_on_infer_shape_cache) { + if (launcher.IsShapeChanged(num_input_tensors)) { + ::infrt::host_context::KernelImpl::Invoke( + &launcher.infershape_kernel_frame_builder); + launcher.BuildInferShapeCache(num_input_tensors); + } + } else { ::infrt::host_context::KernelImpl::Invoke( &launcher.infershape_kernel_frame_builder); - launcher.BuildInferShapeCache(num_input_tensors); } - } - ::infrt::host_context::KernelImpl::Invoke(frame); + ::infrt::host_context::KernelImpl::Invoke(frame); + }; } } // namespace kernel diff --git a/paddle/infrt/kernel/phi/registry.cc b/paddle/infrt/kernel/phi/registry.cc index 0427a2c1e599854e3ad9cdcc2942b1ba27a5ab9e..047788112508078e1e07f82a251994c9218813a6 100644 --- a/paddle/infrt/kernel/phi/registry.cc +++ b/paddle/infrt/kernel/phi/registry.cc @@ -34,45 +34,40 @@ namespace kernel { void RegisterPhiKernels(host_context::KernelRegistry* registry) { registry->AddKernel("phi_dt.create_context.cpu", INFRT_KERNEL(infrt::kernel::phi::CreateCPUContext)); - registry->AddKernelWithAttrs( - "phi_dt.create_dense_tensor.cpu", - INFRT_KERNEL(infrt::kernel::phi::CreateDenseTensor), - {"dims", "lod", "layout", "precision"}); + registry->AddKernel("phi_dt.create_dense_tensor.cpu", + INFRT_KERNEL(infrt::kernel::phi::CreateDenseTensor), + {"dims", "lod", "layout", "precision"}); - registry->AddKernelWithAttrs( + registry->AddKernel( "phi_dt.create_inited_dense_tensor.cpu.f32", INFRT_KERNEL(infrt::kernel::phi::CreateInitedDenseTensorF32), {"dims", "lod", "layout", "value"}); - registry->AddKernelWithAttrs( - "phi_dt.fill_dense_tensor.f32", - INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32), - {"value"}); + registry->AddKernel("phi_dt.fill_dense_tensor.f32", + INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32), + {"value"}); registry->AddKernel("phi_dt.print_tensor", INFRT_KERNEL(infrt::kernel::phi::PrintDenseTensor)); #ifdef INFRT_WITH_GPU registry->AddKernel("phi_dt.create_context.gpu", INFRT_KERNEL(infrt::kernel::phi::CreateGPUContext)); - registry->AddKernelWithAttrs( - "phi_dt.create_dense_tensor.gpu", - INFRT_KERNEL(infrt::kernel::phi::CreateGPUDenseTensor), - {"dims", "lod", "layout", "precision"}); - registry->AddKernelWithAttrs("phi_dt.memcpy.gpu", - INFRT_KERNEL(infrt::kernel::phi::GpuMemCpy), - {"d2h"}); + registry->AddKernel("phi_dt.create_dense_tensor.gpu", + INFRT_KERNEL(infrt::kernel::phi::CreateGPUDenseTensor), + {"dims", "lod", "layout", "precision"}); + registry->AddKernel("phi_dt.memcpy.gpu", + INFRT_KERNEL(infrt::kernel::phi::GpuMemCpy), + {"d2h"}); #endif - registry->AddKernelWithAttrs("phi_dt.load_params", - INFRT_KERNEL(infrt::kernel::phi::LoadParams), - {"path"}); - registry->AddKernelWithAttrs( - "phi_dt.load_combined_params", - INFRT_KERNEL(infrt::kernel::phi::LoadCombinedParams), - {"model_path", "params_path"}); - registry->AddKernelWithAttrs( - "phi_dt.tensor_map_get_tensor", - INFRT_KERNEL(infrt::kernel::phi::TensorMapGetTensor), - {"name"}); + registry->AddKernel("phi_dt.load_params", + INFRT_KERNEL(infrt::kernel::phi::LoadParams), + {"path"}); + registry->AddKernel("phi_dt.load_combined_params", + INFRT_KERNEL(infrt::kernel::phi::LoadCombinedParams), + {"model_path", "params_path"}); + registry->AddKernel("phi_dt.tensor_map_get_tensor", + INFRT_KERNEL(infrt::kernel::phi::TensorMapGetTensor), + {"name"}); registry->AddKernel("phi_dt.tensor_map_get_size", INFRT_KERNEL(infrt::kernel::phi::TensorMapGetSize)); } diff --git a/paddle/infrt/kernel/tensor_kernels.cc b/paddle/infrt/kernel/tensor_kernels.cc index 407ae16c19c499a5feec269f39f5f907aedc84d4..65e137472b3d6225cff990afd2d97384d95adae7 100644 --- a/paddle/infrt/kernel/tensor_kernels.cc +++ b/paddle/infrt/kernel/tensor_kernels.cc @@ -129,9 +129,9 @@ void NaiveMatmul(const DenseHostTensor &x, /// ===== Kernel end ==== void RegisterTensorKernels(host_context::KernelRegistry *registry) { - registry->AddKernelWithAttrs("dt.create_uninit_tensor.f32", - INFRT_KERNEL(CreateUninitTensor), - {"shape"}); + registry->AddKernel("dt.create_uninit_tensor.f32", + INFRT_KERNEL(CreateUninitTensor), + {"shape"}); registry->AddKernel("dt.print_tensor", INFRT_KERNEL(PrintTensor)); registry->AddKernel("dt.fill_tensor_with_constant.f32", INFRT_KERNEL(FillTensorWithConstant)); @@ -146,7 +146,7 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) { // TensorList related methods. #ifdef INFRT_WITH_PHI - registry->AddKernelWithAttrs( + registry->AddKernel( "dt.tensor_list_get_tensor", INFRT_KERNEL(TensorListGetTensor), {"id"}); registry->AddKernel("dt.tensor_list_get_size", INFRT_KERNEL(TensorListGetSize)); diff --git a/tools/infrt/get_phi_kernel_info.py b/tools/infrt/get_phi_kernel_info.py index 3fb40706e230635c1e317a991ed2bcbf712c75ad..c4c02d67cf70b94e9ebb8cafd82b5800532e3726 100644 --- a/tools/infrt/get_phi_kernel_info.py +++ b/tools/infrt/get_phi_kernel_info.py @@ -287,7 +287,7 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]): attr_names = ', '.join( ["\"" + a + "\"" for a in attr_data[ir_name]]) res += f""" -registry->AddKernelWithAttrs("{ir_name}",""" +registry->AddKernel("{ir_name}",""" res += f""" &KernelLauncherFunc