Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
60c4c9cd
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
60c4c9cd
编写于
3月 30, 2022
作者:
王
王明冬
提交者:
GitHub
3月 30, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Infrt] add infer shape cache for kernel. (#41104)
上级
532eba99
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
70 addition
and
69 deletion
+70
-69
paddle/infrt/host_context/kernel_registry.cc
paddle/infrt/host_context/kernel_registry.cc
+14
-14
paddle/infrt/host_context/kernel_registry.h
paddle/infrt/host_context/kernel_registry.h
+7
-4
paddle/infrt/host_context/mlir_to_runtime_translate.cc
paddle/infrt/host_context/mlir_to_runtime_translate.cc
+2
-4
paddle/infrt/kernel/phi/CMakeLists.txt
paddle/infrt/kernel/phi/CMakeLists.txt
+0
-1
paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h
paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h
+20
-14
paddle/infrt/kernel/phi/registry.cc
paddle/infrt/kernel/phi/registry.cc
+22
-27
paddle/infrt/kernel/tensor_kernels.cc
paddle/infrt/kernel/tensor_kernels.cc
+4
-4
tools/infrt/get_phi_kernel_info.py
tools/infrt/get_phi_kernel_info.py
+1
-1
未找到文件。
paddle/infrt/host_context/kernel_registry.cc
浏览文件 @
60c4c9cd
...
...
@@ -24,30 +24,30 @@ namespace host_context {
struct
KernelRegistry
::
Impl
{
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
Kernel
Implementation
,
std
::
vector
<
const
char
*>>>
std
::
pair
<
Kernel
Launcher
,
std
::
vector
<
const
char
*>>>
data
;
};
KernelRegistry
::
KernelRegistry
()
:
impl_
(
std
::
make_unique
<
Impl
>
())
{}
void
KernelRegistry
::
AddKernel
(
const
std
::
string
&
key
,
KernelImplementation
fn
)
{
CHECK
(
!
impl_
->
data
.
count
(
key
))
<<
"kernel ["
<<
key
<<
"] is registered twice"
;
impl_
->
data
.
emplace
(
key
,
std
::
make_pair
(
std
::
move
(
fn
),
std
::
vector
<
const
char
*>
{}));
}
const
std
::
vector
<
const
char
*>
&
KernelRegistry
::
GetAttrNameList
(
const
std
::
string
&
key
)
const
{
CHECK
(
impl_
->
data
.
count
(
key
));
return
impl_
->
data
[
key
].
second
;
}
void
KernelRegistry
::
AddKernelWithAttrs
(
const
std
::
string
&
key
,
KernelImplementation
fn
,
std
::
vector
<
const
char
*>
&&
attr_order
)
{
void
KernelRegistry
::
AddKernel
(
const
std
::
string
&
key
,
KernelImplementation
fn
,
const
std
::
vector
<
const
char
*>
&
attr_order
)
{
CHECK
(
!
impl_
->
data
.
count
(
key
))
<<
"kernel ["
<<
key
<<
"] is registered twice"
;
impl_
->
data
.
emplace
(
key
,
std
::
make_pair
([
fn
]()
{
return
fn
;
},
std
::
move
(
attr_order
)));
}
void
KernelRegistry
::
AddKernel
(
const
std
::
string
&
key
,
KernelLauncher
fn
,
const
std
::
vector
<
const
char
*>
&
attr_order
)
{
CHECK
(
!
impl_
->
data
.
count
(
key
))
<<
"kernel ["
<<
key
<<
"] is registered twice"
;
impl_
->
data
.
emplace
(
key
,
...
...
@@ -56,7 +56,7 @@ void KernelRegistry::AddKernelWithAttrs(
KernelImplementation
KernelRegistry
::
GetKernel
(
const
std
::
string
&
key
)
const
{
auto
it
=
impl_
->
data
.
find
(
key
);
return
it
!=
impl_
->
data
.
end
()
?
it
->
second
.
first
:
KernelImplementation
{};
return
it
!=
impl_
->
data
.
end
()
?
it
->
second
.
first
()
:
KernelImplementation
{};
}
std
::
vector
<
std
::
string
>
KernelRegistry
::
GetKernelList
()
const
{
...
...
paddle/infrt/host_context/kernel_registry.h
浏览文件 @
60c4c9cd
...
...
@@ -25,6 +25,7 @@ namespace host_context {
class
KernelFrame
;
using
KernelImplementation
=
std
::
function
<
void
(
KernelFrame
*
frame
)
>
;
using
KernelLauncher
=
std
::
function
<
KernelImplementation
()
>
;
/**
* Hold the kernels registered in the system.
...
...
@@ -33,10 +34,12 @@ class KernelRegistry {
public:
KernelRegistry
();
void
AddKernel
(
const
std
::
string
&
key
,
KernelImplementation
fn
);
void
AddKernelWithAttrs
(
const
std
::
string
&
key
,
KernelImplementation
fn
,
std
::
vector
<
const
char
*>
&&
attrs_order
);
void
AddKernel
(
const
std
::
string
&
key
,
KernelImplementation
fn
,
const
std
::
vector
<
const
char
*>
&
attrs_order
=
{});
void
AddKernel
(
const
std
::
string
&
key
,
KernelLauncher
fn
,
const
std
::
vector
<
const
char
*>
&
attrs_order
=
{});
KernelImplementation
GetKernel
(
const
std
::
string
&
key
)
const
;
const
std
::
vector
<
const
char
*>
&
GetAttrNameList
(
...
...
paddle/infrt/host_context/mlir_to_runtime_translate.cc
浏览文件 @
60c4c9cd
...
...
@@ -360,8 +360,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
if
(
attrs
.
size
())
{
if
(
attr_names
.
empty
())
{
LOG
(
WARNING
)
<<
"The kernel `"
<<
kernel_name
<<
"` has not been registered with "
"`KernelRegistry::AddKernelWithAttrs()`."
;
<<
"` has not been registered with attributes order "
;
}
else
{
CHECK_EQ
(
attr_names
.
size
(),
attrs
.
size
())
<<
"The number of kernel `"
<<
kernel_name
...
...
@@ -380,8 +379,7 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(
}
}
LOG
(
WARNING
)
<<
"The attribute `"
<<
attr
<<
"` of kernel `"
<<
kernel_name
<<
"` is not properly registered with "
"`KernelRegistry::AddKernelWithAttrs()`."
;
<<
"` is not properly register"
;
return
-
1
;
};
...
...
paddle/infrt/kernel/phi/CMakeLists.txt
浏览文件 @
60c4c9cd
...
...
@@ -29,7 +29,6 @@ add_custom_target(infrt_register_phi_kernel
cc_library
(
infrt_naive SRCS infershaped/infershaped_kernel_launcher.cc
infershaped/infershaped_kernel_launchers.cc
DEPS phi wrapped_infermeta
)
add_dependencies
(
infrt_naive infrt_register_phi_kernel
)
cc_test_tiny
(
test_infrt_infershape_launchers SRCS
infershaped/infershape_launchers_test.cc DEPS infrt
)
paddle/infrt/kernel/phi/infershaped/phi_kernel_launcher.h
浏览文件 @
60c4c9cd
...
...
@@ -17,6 +17,7 @@
#include <iostream>
#include "paddle/infrt/backends/host/phi_context.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launcher.h"
#include "paddle/infrt/kernel/phi/infershaped/infershaped_utils.h"
...
...
@@ -36,31 +37,36 @@ template <typename KernelFunc,
KernelFunc
kernel
,
typename
InferShapedFunc
,
InferShapedFunc
infershape
>
void
KernelLauncherFunc
(
host_context
::
KernelFrame
*
frame
)
{
::
infrt
::
host_context
::
KernelImplementation
KernelLauncherFunc
(
)
{
InferShapedKernelLauncher
launcher
(
FuncArgStatics
<
InferShapedFunc
>::
arg_size
);
static
const
uint16_t
num_input_tensors
{
InferShapeHelper
<
KernelFunc
>::
count
};
static
const
bool
turn_on_infer_shape_cache
{
true
};
return
[
=
](
host_context
::
KernelFrame
*
frame
)
mutable
{
#ifndef NDEBUG
LOG
(
INFO
)
<<
"Kernel.frame: "
<<
frame
->
DumpArgTypes
();
LOG
(
INFO
)
<<
"Kernel.frame: "
<<
frame
->
DumpArgTypes
();
#endif
// Build the infershape KernelFrame if needed.
// TODO(Superjomn) add unlikely here.
if
(
launcher
.
infershape_kernel_frame_builder
.
IsEmpty
())
{
launcher
.
CreateKernelFrameForInferShape
(
frame
);
// Build the infershape KernelFrame if needed.
// TODO(Superjomn) add unlikely here.
if
(
launcher
.
infershape_kernel_frame_builder
.
IsEmpty
())
{
launcher
.
CreateKernelFrameForInferShape
(
frame
);
#ifndef NDEBUG
LOG
(
INFO
)
<<
"infershape.frame: "
<<
launcher
.
infershape_kernel_frame_builder
.
DumpArgTypes
();
LOG
(
INFO
)
<<
"infershape.frame: "
<<
launcher
.
infershape_kernel_frame_builder
.
DumpArgTypes
();
#endif
}
if
(
turn_on_infer_shape_cache
)
{
if
(
launcher
.
IsShapeChanged
(
num_input_tensors
))
{
}
if
(
turn_on_infer_shape_cache
)
{
if
(
launcher
.
IsShapeChanged
(
num_input_tensors
))
{
::
infrt
::
host_context
::
KernelImpl
<
InferShapedFunc
,
infershape
>::
Invoke
(
&
launcher
.
infershape_kernel_frame_builder
);
launcher
.
BuildInferShapeCache
(
num_input_tensors
);
}
}
else
{
::
infrt
::
host_context
::
KernelImpl
<
InferShapedFunc
,
infershape
>::
Invoke
(
&
launcher
.
infershape_kernel_frame_builder
);
launcher
.
BuildInferShapeCache
(
num_input_tensors
);
}
}
::
infrt
::
host_context
::
KernelImpl
<
KernelFunc
,
kernel
>::
Invoke
(
frame
)
;
::
infrt
::
host_context
::
KernelImpl
<
KernelFunc
,
kernel
>::
Invoke
(
frame
);
}
;
}
}
// namespace kernel
...
...
paddle/infrt/kernel/phi/registry.cc
浏览文件 @
60c4c9cd
...
...
@@ -34,45 +34,40 @@ namespace kernel {
void
RegisterPhiKernels
(
host_context
::
KernelRegistry
*
registry
)
{
registry
->
AddKernel
(
"phi_dt.create_context.cpu"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
CreateCPUContext
));
registry
->
AddKernelWithAttrs
(
"phi_dt.create_dense_tensor.cpu"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
CreateDenseTensor
),
{
"dims"
,
"lod"
,
"layout"
,
"precision"
});
registry
->
AddKernel
(
"phi_dt.create_dense_tensor.cpu"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
CreateDenseTensor
),
{
"dims"
,
"lod"
,
"layout"
,
"precision"
});
registry
->
AddKernel
WithAttrs
(
registry
->
AddKernel
(
"phi_dt.create_inited_dense_tensor.cpu.f32"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
CreateInitedDenseTensorF32
),
{
"dims"
,
"lod"
,
"layout"
,
"value"
});
registry
->
AddKernelWithAttrs
(
"phi_dt.fill_dense_tensor.f32"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
FillDenseTensorF32
),
{
"value"
});
registry
->
AddKernel
(
"phi_dt.fill_dense_tensor.f32"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
FillDenseTensorF32
),
{
"value"
});
registry
->
AddKernel
(
"phi_dt.print_tensor"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
PrintDenseTensor
));
#ifdef INFRT_WITH_GPU
registry
->
AddKernel
(
"phi_dt.create_context.gpu"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
CreateGPUContext
));
registry
->
AddKernelWithAttrs
(
"phi_dt.create_dense_tensor.gpu"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
CreateGPUDenseTensor
),
{
"dims"
,
"lod"
,
"layout"
,
"precision"
});
registry
->
AddKernelWithAttrs
(
"phi_dt.memcpy.gpu"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
GpuMemCpy
),
{
"d2h"
});
registry
->
AddKernel
(
"phi_dt.create_dense_tensor.gpu"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
CreateGPUDenseTensor
),
{
"dims"
,
"lod"
,
"layout"
,
"precision"
});
registry
->
AddKernel
(
"phi_dt.memcpy.gpu"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
GpuMemCpy
),
{
"d2h"
});
#endif
registry
->
AddKernelWithAttrs
(
"phi_dt.load_params"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
LoadParams
),
{
"path"
});
registry
->
AddKernelWithAttrs
(
"phi_dt.load_combined_params"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
LoadCombinedParams
),
{
"model_path"
,
"params_path"
});
registry
->
AddKernelWithAttrs
(
"phi_dt.tensor_map_get_tensor"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
TensorMapGetTensor
),
{
"name"
});
registry
->
AddKernel
(
"phi_dt.load_params"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
LoadParams
),
{
"path"
});
registry
->
AddKernel
(
"phi_dt.load_combined_params"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
LoadCombinedParams
),
{
"model_path"
,
"params_path"
});
registry
->
AddKernel
(
"phi_dt.tensor_map_get_tensor"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
TensorMapGetTensor
),
{
"name"
});
registry
->
AddKernel
(
"phi_dt.tensor_map_get_size"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
TensorMapGetSize
));
}
...
...
paddle/infrt/kernel/tensor_kernels.cc
浏览文件 @
60c4c9cd
...
...
@@ -129,9 +129,9 @@ void NaiveMatmul(const DenseHostTensor &x,
/// ===== Kernel end ====
void
RegisterTensorKernels
(
host_context
::
KernelRegistry
*
registry
)
{
registry
->
AddKernel
WithAttrs
(
"dt.create_uninit_tensor.f32"
,
INFRT_KERNEL
(
CreateUninitTensor
<
float
>
),
{
"shape"
});
registry
->
AddKernel
(
"dt.create_uninit_tensor.f32"
,
INFRT_KERNEL
(
CreateUninitTensor
<
float
>
),
{
"shape"
});
registry
->
AddKernel
(
"dt.print_tensor"
,
INFRT_KERNEL
(
PrintTensor
));
registry
->
AddKernel
(
"dt.fill_tensor_with_constant.f32"
,
INFRT_KERNEL
(
FillTensorWithConstant
<
float
>
));
...
...
@@ -146,7 +146,7 @@ void RegisterTensorKernels(host_context::KernelRegistry *registry) {
// TensorList related methods.
#ifdef INFRT_WITH_PHI
registry
->
AddKernel
WithAttrs
(
registry
->
AddKernel
(
"dt.tensor_list_get_tensor"
,
INFRT_KERNEL
(
TensorListGetTensor
),
{
"id"
});
registry
->
AddKernel
(
"dt.tensor_list_get_size"
,
INFRT_KERNEL
(
TensorListGetSize
));
...
...
tools/infrt/get_phi_kernel_info.py
浏览文件 @
60c4c9cd
...
...
@@ -287,7 +287,7 @@ def gen_register_code_info(item: List[str], attr_data: Dict[str, List[str]]):
attr_names
=
', '
.
join
(
[
"
\"
"
+
a
+
"
\"
"
for
a
in
attr_data
[
ir_name
]])
res
+=
f
"""
registry->AddKernel
WithAttrs
("
{
ir_name
}
","""
registry->AddKernel("
{
ir_name
}
","""
res
+=
f
"""
&KernelLauncherFunc<decltype(
{
kernel_func
}
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录