未验证 提交 0f3b1ad6 编写于 作者: R ronnywang 提交者: GitHub

fix phi capi kernel registration macro error (#48616)

* fix capi kernel registration macro error

* update
上级 a7c43ffa
...@@ -167,6 +167,7 @@ inline std::vector<phi::capi::DenseTensor> PD_MultiInputAt( ...@@ -167,6 +167,7 @@ inline std::vector<phi::capi::DenseTensor> PD_MultiInputAt(
for (size_t i = 0; i < list.size; ++i) { for (size_t i = 0; i < list.size; ++i) {
ret.emplace_back(data[i]); ret.emplace_back(data[i]);
} }
PD_DeletePointerList(list);
return ret; return ret;
} }
...@@ -182,13 +183,14 @@ inline std::vector<phi::capi::DenseTensor> PD_MultiOutputAt( ...@@ -182,13 +183,14 @@ inline std::vector<phi::capi::DenseTensor> PD_MultiOutputAt(
for (size_t i = 0; i < list.size; ++i) { for (size_t i = 0; i < list.size; ++i) {
ret.emplace_back(data[i]); ret.emplace_back(data[i]);
} }
PD_DeletePointerList(list);
return ret; return ret;
} }
template <typename T> template <typename T>
inline std::vector<T *> PD_GetPointerVector(std::vector<T> *vec) { inline std::vector<T *> PD_GetPointerVector(std::vector<T> *vec) {
std::vector<T *> ret; std::vector<T *> ret;
for (auto &item : vec) { for (auto &item : *vec) {
ret.push_back(&item); ret.push_back(&item);
} }
return ret; return ret;
......
...@@ -564,18 +564,24 @@ namespace capi { ...@@ -564,18 +564,24 @@ namespace capi {
static_assert(out_idx == 0, \ static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \ "Kernel's Input should appear before Outputs."); \
auto arg = PD_MultiInputAt(ctx, in_idx); \ auto arg = PD_MultiInputAt(ctx, in_idx); \
auto arg_wrapper = PD_GetPointerVector(&arg); \ std::vector<const tensor_type *> tensor_ptr_vec; \
for (auto &tensor : arg) { \
tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \
} \
CustomKernelCallHelper<Tail...>:: \ CustomKernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \ template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
ctx, pargs..., arg_wrapper); \ ctx, pargs..., tensor_ptr_vec); \
} \ } \
template <int idx, typename... PreviousArgs> \ template <int idx, typename... PreviousArgs> \
static void VariadicCompute(const std::tuple<DevCtx, Args &...> &ctx, \ static void VariadicCompute(const std::tuple<DevCtx, Args &...> &ctx, \
PreviousArgs &...pargs) { \ PreviousArgs &...pargs) { \
auto &arg = std::get<idx>(ctx); \ auto &arg = std::get<idx>(ctx); \
auto tensor = PD_TensorVector(reinterpret_cast<PD_Tensor *>( \ auto tensor_vec = PD_TensorVector(reinterpret_cast<PD_Tensor *>( \
const_cast<std::vector<const tensor_type *> *>(&arg))); \ const_cast<std::vector<const tensor_type *> *>(&arg))); \
auto tensor_ptr_vec = PD_GetPointerVector(&arg); \ std::vector<const tensor_type *> tensor_ptr_vec; \
for (auto &tensor : tensor_vec) { \
tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \
} \
return CustomKernelCallHelper<Tail...>::template VariadicCompute<idx + \ return CustomKernelCallHelper<Tail...>::template VariadicCompute<idx + \
1>( \ 1>( \
ctx, pargs..., tensor_ptr_vec); \ ctx, pargs..., tensor_ptr_vec); \
...@@ -681,7 +687,7 @@ namespace capi { ...@@ -681,7 +687,7 @@ namespace capi {
tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \ tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \
} \ } \
CustomKernelCallHelper<Tail...>:: \ CustomKernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \ template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>( \
ctx, pargs..., tensor_ptr_vec); \ ctx, pargs..., tensor_ptr_vec); \
} \ } \
template <int idx, typename... PreviousArgs> \ template <int idx, typename... PreviousArgs> \
......
...@@ -60,7 +60,11 @@ PD_List PD_KernelContextMultiInputAt(PD_KernelContext* ctx, size_t index) { ...@@ -60,7 +60,11 @@ PD_List PD_KernelContextMultiInputAt(PD_KernelContext* ctx, size_t index) {
range.first, range.second); range.first, range.second);
PD_List list; PD_List list;
list.size = tensor_vec.size(); list.size = tensor_vec.size();
list.data = tensor_vec.data(); list.data = new void*[list.size];
for (size_t i = 0; i < list.size; ++i) {
(reinterpret_cast<void**>(list.data))[i] =
reinterpret_cast<void*>(const_cast<phi::DenseTensor*>(tensor_vec[i]));
}
return list; return list;
} }
...@@ -78,7 +82,11 @@ PD_List PD_KernelContextMultiOutputAt(PD_KernelContext* ctx, size_t index) { ...@@ -78,7 +82,11 @@ PD_List PD_KernelContextMultiOutputAt(PD_KernelContext* ctx, size_t index) {
range.first, range.second); range.first, range.second);
PD_List list; PD_List list;
list.size = tensor_vec.size(); list.size = tensor_vec.size();
list.data = tensor_vec.data(); list.data = new void*[list.size];
for (size_t i = 0; i < list.size; ++i) {
(reinterpret_cast<void**>(list.data))[i] =
reinterpret_cast<void*>(tensor_vec[i]);
}
return list; return list;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册