未验证 提交 0f3b1ad6 编写于 作者: R ronnywang 提交者: GitHub

fix phi capi kernel registration macro error (#48616)

* fix capi kernel registration macro error

* update
上级 a7c43ffa
......@@ -167,6 +167,7 @@ inline std::vector<phi::capi::DenseTensor> PD_MultiInputAt(
for (size_t i = 0; i < list.size; ++i) {
ret.emplace_back(data[i]);
}
PD_DeletePointerList(list);
return ret;
}
......@@ -182,13 +183,14 @@ inline std::vector<phi::capi::DenseTensor> PD_MultiOutputAt(
for (size_t i = 0; i < list.size; ++i) {
ret.emplace_back(data[i]);
}
PD_DeletePointerList(list);
return ret;
}
template <typename T>
inline std::vector<T *> PD_GetPointerVector(std::vector<T> *vec) {
std::vector<T *> ret;
for (auto &item : vec) {
for (auto &item : *vec) {
ret.push_back(&item);
}
return ret;
......
......@@ -564,18 +564,24 @@ namespace capi {
static_assert(out_idx == 0, \
"Kernel's Input should appear before Outputs."); \
auto arg = PD_MultiInputAt(ctx, in_idx); \
auto arg_wrapper = PD_GetPointerVector(&arg); \
std::vector<const tensor_type *> tensor_ptr_vec; \
for (auto &tensor : arg) { \
tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \
} \
CustomKernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
ctx, pargs..., arg_wrapper); \
ctx, pargs..., tensor_ptr_vec); \
} \
template <int idx, typename... PreviousArgs> \
static void VariadicCompute(const std::tuple<DevCtx, Args &...> &ctx, \
PreviousArgs &...pargs) { \
auto &arg = std::get<idx>(ctx); \
auto tensor = PD_TensorVector(reinterpret_cast<PD_Tensor *>( \
auto tensor_vec = PD_TensorVector(reinterpret_cast<PD_Tensor *>( \
const_cast<std::vector<const tensor_type *> *>(&arg))); \
auto tensor_ptr_vec = PD_GetPointerVector(&arg); \
std::vector<const tensor_type *> tensor_ptr_vec; \
for (auto &tensor : tensor_vec) { \
tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \
} \
return CustomKernelCallHelper<Tail...>::template VariadicCompute<idx + \
1>( \
ctx, pargs..., tensor_ptr_vec); \
......@@ -681,7 +687,7 @@ namespace capi {
tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \
} \
CustomKernelCallHelper<Tail...>:: \
template Compute<dev_ctx_idx, in_idx + 1, attr_idx, out_idx>( \
template Compute<dev_ctx_idx, in_idx, attr_idx, out_idx + 1>( \
ctx, pargs..., tensor_ptr_vec); \
} \
template <int idx, typename... PreviousArgs> \
......
......@@ -60,7 +60,11 @@ PD_List PD_KernelContextMultiInputAt(PD_KernelContext* ctx, size_t index) {
range.first, range.second);
PD_List list;
list.size = tensor_vec.size();
list.data = tensor_vec.data();
list.data = new void*[list.size];
for (size_t i = 0; i < list.size; ++i) {
(reinterpret_cast<void**>(list.data))[i] =
reinterpret_cast<void*>(const_cast<phi::DenseTensor*>(tensor_vec[i]));
}
return list;
}
......@@ -78,7 +82,11 @@ PD_List PD_KernelContextMultiOutputAt(PD_KernelContext* ctx, size_t index) {
range.first, range.second);
PD_List list;
list.size = tensor_vec.size();
list.data = tensor_vec.data();
list.data = new void*[list.size];
for (size_t i = 0; i < list.size; ++i) {
(reinterpret_cast<void**>(list.data))[i] =
reinterpret_cast<void*>(tensor_vec[i]);
}
return list;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册