Unverified · Commit 60c4c9cd · authored by 王明冬 · committed by GitHub

[Infrt] add infer shape cache for kernel. (#41104)

Parent: 532eba99
......@@ -24,30 +24,30 @@ namespace host_context {
// Private state of KernelRegistry (pimpl idiom): maps a kernel name to a
// launcher (a factory producing the KernelImplementation on demand) paired
// with the ordered list of attribute names the kernel expects.
// NOTE(review): the scraped diff showed both the old (KernelImplementation)
// and new (KernelLauncher) value types; the KernelLauncher form is the one
// consistent with GetKernel() invoking `second.first()`.
struct KernelRegistry::Impl {
  std::unordered_map<std::string,
                     std::pair<KernelLauncher, std::vector<const char *>>>
      data;
};
// Default-constructs the registry with an empty kernel table.
KernelRegistry::KernelRegistry() : impl_(std::make_unique<Impl>()) {}
// NOTE(review): this two-argument overload stores `fn` directly as the pair's
// first element, which only type-checks against the old
// pair<KernelImplementation, ...> map value — it appears to be the
// pre-refactor (removed) version retained by the diff scrape; confirm against
// the final file before relying on it.
// Registers `key` -> `fn` with an empty attribute-name list; aborts (CHECK)
// when `key` is already present.
void KernelRegistry::AddKernel(const std::string &key,
KernelImplementation fn) {
CHECK(!impl_->data.count(key)) << "kernel [" << key
<< "] is registered twice";
impl_->data.emplace(
key, std::make_pair(std::move(fn), std::vector<const char *>{}));
}
// Returns the registered attribute-name order for kernel `key`.
// Aborts (CHECK) when `key` is not registered.
const std::vector<const char *> &KernelRegistry::GetAttrNameList(
    const std::string &key) const {
  // Single lookup via find() instead of count() + operator[]: avoids hashing
  // the key twice, and avoids operator[], which would default-insert an entry
  // on a miss.
  auto it = impl_->data.find(key);
  CHECK(it != impl_->data.end()) << "no kernel registered under [" << key << "]";
  return it->second.second;
}
void KernelRegistry::AddKernelWithAttrs(
const std::string &key,
void KernelRegistry::AddKernel(const std::string &key,
KernelImplementation fn,
std::vector<const char *> &&attr_order) {
const std::vector<const char *> &attr_order) {
CHECK(!impl_->data.count(key)) << "kernel [" << key
<< "] is registered twice";
impl_->data.emplace(
key, std::make_pair([fn]() { return fn; }, std::move(attr_order)));
}
void KernelRegistry::AddKernel(const std::string &key,
KernelLauncher fn,
const std::vector<const char *> &attr_order) {
CHECK(!impl_->data.count(key)) << "kernel [" << key
<< "] is registered twice";
impl_->data.emplace(key,
......@@ -56,7 +56,7 @@ void KernelRegistry::AddKernelWithAttrs(
// Looks up `key` and materializes its implementation by invoking the stored
// KernelLauncher. Returns an empty KernelImplementation when `key` is
// unknown (callers can test it in boolean context).
// NOTE(review): the scraped diff retained both the old direct-return line and
// the new launcher-invoking one, leaving an unreachable second return; only
// the launcher form type-checks against pair<KernelLauncher, ...> storage.
KernelImplementation KernelRegistry::GetKernel(const std::string &key) const {
  auto it = impl_->data.find(key);
  return it != impl_->data.end() ? it->second.first() : KernelImplementation{};
}
std::vector<std::string> KernelRegistry::GetKernelList() const {
......
......@@ -25,6 +25,7 @@ namespace host_context {
class KernelFrame;
using KernelImplementation = std::function<void(KernelFrame *frame)>;
using KernelLauncher = std::function<KernelImplementation()>;
/**
* Hold the kernels registered in the system.
......@@ -33,10 +34,12 @@ class KernelRegistry {
public:
KernelRegistry();
void AddKernel(const std::string &key, KernelImplementation fn);
void AddKernelWithAttrs(const std::string &key,
void AddKernel(const std::string &key,
KernelImplementation fn,
std::vector<const char *> &&attrs_order);
const std::vector<const char *> &attrs_order = {});
void AddKernel(const std::string &key,
KernelLauncher fn,
const std::vector<const char *> &attrs_order = {});
KernelImplementation GetKernel(const std::string &key) const;
const std::vector<const char *> &GetAttrNameList(
......
......@@ -360,8 +360,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
if (attrs.size()) {
if (attr_names.empty()) {
LOG(WARNING) << "The kernel `" << kernel_name
<< "` has not been registered with "
"`KernelRegistry::AddKernelWithAttrs()`.";
<< "` has not been registered with attributes order ";
} else {
CHECK_EQ(attr_names.size(), attrs.size())
<< "The number of kernel `" << kernel_name
......@@ -380,8 +379,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
}
}
LOG(WARNING) << "The attribute `" << attr << "` of kernel `" << kernel_name
<< "` is not properly registered with "
"`KernelRegistry::AddKernelWithAttrs()`.";
<< "` is not properly register";
return -1;
};
......
......@@ -29,7 +29,6 @@ add_custom_target(infrt_register_phi_kernel
cc_library(infrt_naive SRCS infershaped/infershaped_kernel_launcher.cc
infershaped/infershaped_kernel_launchers.cc
DEPS phi wrapped_infermeta)
add_dependencies(infrt_naive infrt_register_phi_kernel)
cc_test_tiny(test_infrt_infershape_launchers SRCS
infershaped/infershape_launchers_test.cc DEPS infrt)
......@@ -17,6 +17,7 @@
#include <iostream>
#include "paddle/infrt/backends/host/phi_context.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_utils.h"
......@@ -36,11 +37,12 @@ template <typename KernelFunc,
KernelFunc kernel,
typename InferShapedFunc,
InferShapedFunc infershape>
void KernelLauncherFunc(host_context::KernelFrame* frame) {
::infrt::host_context::KernelImplementation KernelLauncherFunc() {
InferShapedKernelLauncher launcher(FuncArgStatics<InferShapedFunc>::arg_size);
static const uint16_t num_input_tensors{InferShapeHelper<KernelFunc>::count};
static const bool turn_on_infer_shape_cache{true};
return [=](host_context::KernelFrame* frame) mutable {
#ifndef NDEBUG
LOG(INFO) << "Kernel.frame: " << frame->DumpArgTypes();
#endif
......@@ -59,8 +61,12 @@ void KernelLauncherFunc(host_context::KernelFrame* frame) {
&launcher.infershape_kernel_frame_builder);
launcher.BuildInferShapeCache(num_input_tensors);
}
} else {
::infrt::host_context::KernelImpl<InferShapedFunc, infershape>::Invoke(
&launcher.infershape_kernel_frame_builder);
}
::infrt::host_context::KernelImpl<KernelFunc, kernel>::Invoke(frame);
};
}
} // namespace kernel
......
......@@ -34,18 +34,16 @@ namespace kernel {
void RegisterPhiKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("phi_dt.create_context.cpu",
INFRT_KERNEL(infrt::kernel::phi::CreateCPUContext));
registry->AddKernelWithAttrs(
"phi_dt.create_dense_tensor.cpu",
registry->AddKernel("phi_dt.create_dense_tensor.cpu",
INFRT_KERNEL(infrt::kernel::phi::CreateDenseTensor),
{"dims", "lod", "layout", "precision"});
registry->AddKernelWithAttrs(
registry->AddKernel(
"phi_dt.create_inited_dense_tensor.cpu.f32",
INFRT_KERNEL(infrt::kernel::phi::CreateInitedDenseTensorF32),
{"dims", "lod", "layout", "value"});
registry->AddKernelWithAttrs(
"phi_dt.fill_dense_tensor.f32",
registry->AddKernel("phi_dt.fill_dense_tensor.f32",
INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32),
{"value"});
registry->AddKernel("phi_dt.print_tensor",
......@@ -54,23 +52,20 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) {
#ifdef INFRT_WITH_GPU
registry->AddKernel("phi_dt.create_context.gpu",
INFRT_KERNEL(infrt::kernel::phi::CreateGPUContext));
registry->AddKernelWithAttrs(
"phi_dt.create_dense_tensor.gpu",
registry->AddKernel("phi_dt.create_dense_tensor.gpu",
INFRT_KERNEL(infrt::kernel::phi::CreateGPUDenseTensor),
{"dims", "lod", "layout", "precision"});
registry->AddKernelWithAttrs("phi_dt.memcpy.gpu",
registry->AddKernel("phi_dt.memcpy.gpu",
INFRT_KERNEL(infrt::kernel::phi::GpuMemCpy),
{"d2h"});
#endif
registry->AddKernelWithAttrs("phi_dt.load_params",
registry->AddKernel("phi_dt.load_params",
INFRT_KERNEL(infrt::kernel::phi::LoadParams),
{"path"});
registry->AddKernelWithAttrs(
"phi_dt.load_combined_params",
registry->AddKernel("phi_dt.load_combined_params",
INFRT_KERNEL(infrt::kernel::phi::LoadCombinedParams),
{"model_path", "params_path"});
registry->AddKernelWithAttrs(
"phi_dt.tensor_map_get_tensor",
registry->AddKernel("phi_dt.tensor_map_get_tensor",
INFRT_KERNEL(infrt::kernel::phi::TensorMapGetTensor),
{"name"});
registry->AddKernel("phi_dt.tensor_map_get_size",
......
......@@ -129,7 +129,7 @@ void NaiveMatmul(const DenseHostTensor &x,
/// ===== Kernel end ====
void RegisterTensorKernels(host_context::KernelRegistry *registry) {
registry->AddKernelWithAttrs("dt.create_uninit_tensor.f32",
registry->AddKernel("dt.create_uninit_tensor.f32",
INFRT_KERNEL(CreateUninitTensor<float>),
{"shape"});
registry->AddKernel("dt.print_tensor", INFRT_KERNEL(PrintTensor));
......@@ -146,7 +146,7 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
// TensorList related methods.
#ifdef INFRT_WITH_PHI
registry->AddKernelWithAttrs(
registry->AddKernel(
"dt.tensor_list_get_tensor", INFRT_KERNEL(TensorListGetTensor), {"id"});
registry->AddKernel("dt.tensor_list_get_size",
INFRT_KERNEL(TensorListGetSize));
......
......@@ -287,7 +287,7 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]):
attr_names = ', '.join(
["\"" + a + "\"" for a in attr_data[ir_name]])
res += f"""
registry->AddKernelWithAttrs("{ir_name}","""
registry->AddKernel("{ir_name}","""
res += f"""
&KernelLauncherFunc<decltype({kernel_func}),
......
Markdown is supported.
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.