Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0f3b1ad6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0f3b1ad6
编写于
12月 02, 2022
作者:
R
ronnywang
提交者:
GitHub
12月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix phi capi kernel registration macro error (#48616)
* fix capi kernel registration macro error * update
上级
a7c43ffa
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
24 addition
and
8 deletion
+24
-8
paddle/phi/capi/include/kernel_registry.h
paddle/phi/capi/include/kernel_registry.h
+3
-1
paddle/phi/capi/include/kernel_utils.h
paddle/phi/capi/include/kernel_utils.h
+11
-5
paddle/phi/capi/lib/c_kernel_context.cc
paddle/phi/capi/lib/c_kernel_context.cc
+10
-2
未找到文件。
paddle/phi/capi/include/kernel_registry.h
浏览文件 @
0f3b1ad6
...
...
@@ -167,6 +167,7 @@ inline std::vector<phi::capi::DenseTensor> PD_MultiInputAt(
for
(
size_t
i
=
0
;
i
<
list
.
size
;
++
i
)
{
ret
.
emplace_back
(
data
[
i
]);
}
PD_DeletePointerList
(
list
);
return
ret
;
}
...
...
@@ -182,13 +183,14 @@ inline std::vector<phi::capi::DenseTensor> PD_MultiOutputAt(
for
(
size_t
i
=
0
;
i
<
list
.
size
;
++
i
)
{
ret
.
emplace_back
(
data
[
i
]);
}
PD_DeletePointerList
(
list
);
return
ret
;
}
template
<
typename
T
>
inline
std
::
vector
<
T
*>
PD_GetPointerVector
(
std
::
vector
<
T
>
*
vec
)
{
std
::
vector
<
T
*>
ret
;
for
(
auto
&
item
:
vec
)
{
for
(
auto
&
item
:
*
vec
)
{
ret
.
push_back
(
&
item
);
}
return
ret
;
...
...
paddle/phi/capi/include/kernel_utils.h
浏览文件 @
0f3b1ad6
...
...
@@ -564,18 +564,24 @@ namespace capi {
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
auto arg = PD_MultiInputAt(ctx, in_idx); \
auto arg_wrapper = PD_GetPointerVector(&arg); \
std::vector<const tensor_type *> tensor_ptr_vec; \
for (auto &tensor : arg) { \
tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \
} \
CustomKernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
ctx, pargs...,
arg_wrapper);
\
ctx, pargs...,
tensor_ptr_vec);
\
} \
template <int idx, typename... PreviousArgs> \
static void VariadicCompute(const std::tuple<DevCtx, Args &...> &ctx, \
PreviousArgs &...pargs) { \
auto &arg = std::get<idx>(ctx); \
auto tensor
= PD_TensorVector(reinterpret_cast<PD_Tensor *>(
\
auto tensor
_vec = PD_TensorVector(reinterpret_cast<PD_Tensor *>(
\
const_cast<std::vector<const tensor_type *> *>(&arg))); \
auto tensor_ptr_vec = PD_GetPointerVector(&arg); \
std::vector<const tensor_type *> tensor_ptr_vec; \
for (auto &tensor : tensor_vec) { \
tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \
} \
return CustomKernelCallHelper<Tail...>::template VariadicCompute<idx + \
1>( \
ctx, pargs..., tensor_ptr_vec); \
...
...
@@ -681,7 +687,7 @@ namespace capi {
tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \
} \
CustomKernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx
+ 1, attr_idx, out_idx
>( \
template Compute<dev_ctx_idx, in_idx
, attr_idx, out_idx + 1
>( \
ctx, pargs..., tensor_ptr_vec); \
} \
template <int idx, typename... PreviousArgs> \
...
...
paddle/phi/capi/lib/c_kernel_context.cc
浏览文件 @
0f3b1ad6
...
...
@@ -60,7 +60,11 @@ PD_List PD_KernelContextMultiInputAt(PD_KernelContext* ctx, size_t index) {
range
.
first
,
range
.
second
);
PD_List
list
;
list
.
size
=
tensor_vec
.
size
();
list
.
data
=
tensor_vec
.
data
();
list
.
data
=
new
void
*
[
list
.
size
];
for
(
size_t
i
=
0
;
i
<
list
.
size
;
++
i
)
{
(
reinterpret_cast
<
void
**>
(
list
.
data
))[
i
]
=
reinterpret_cast
<
void
*>
(
const_cast
<
phi
::
DenseTensor
*>
(
tensor_vec
[
i
]));
}
return
list
;
}
...
...
@@ -78,7 +82,11 @@ PD_List PD_KernelContextMultiOutputAt(PD_KernelContext* ctx, size_t index) {
range
.
first
,
range
.
second
);
PD_List
list
;
list
.
size
=
tensor_vec
.
size
();
list
.
data
=
tensor_vec
.
data
();
list
.
data
=
new
void
*
[
list
.
size
];
for
(
size_t
i
=
0
;
i
<
list
.
size
;
++
i
)
{
(
reinterpret_cast
<
void
**>
(
list
.
data
))[
i
]
=
reinterpret_cast
<
void
*>
(
tensor_vec
[
i
]);
}
return
list
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录