From 0f3b1ad6391c9d1c7e3ef563b6b78b9cf26eae93 Mon Sep 17 00:00:00 2001 From: ronnywang Date: Fri, 2 Dec 2022 14:11:00 +0800 Subject: [PATCH] fix phi capi kernel registration macro error (#48616) * fix capi kernel registration macro error * update --- paddle/phi/capi/include/kernel_registry.h | 4 +++- paddle/phi/capi/include/kernel_utils.h | 16 +++++++++++----- paddle/phi/capi/lib/c_kernel_context.cc | 12 ++++++++++-- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/paddle/phi/capi/include/kernel_registry.h b/paddle/phi/capi/include/kernel_registry.h index 47ddc0bf5be..73318561dd9 100644 --- a/paddle/phi/capi/include/kernel_registry.h +++ b/paddle/phi/capi/include/kernel_registry.h @@ -167,6 +167,7 @@ inline std::vector PD_MultiInputAt( for (size_t i = 0; i < list.size; ++i) { ret.emplace_back(data[i]); } + PD_DeletePointerList(list); return ret; } @@ -182,13 +183,14 @@ inline std::vector PD_MultiOutputAt( for (size_t i = 0; i < list.size; ++i) { ret.emplace_back(data[i]); } + PD_DeletePointerList(list); return ret; } template inline std::vector PD_GetPointerVector(std::vector *vec) { std::vector ret; - for (auto &item : vec) { + for (auto &item : *vec) { ret.push_back(&item); } return ret; diff --git a/paddle/phi/capi/include/kernel_utils.h b/paddle/phi/capi/include/kernel_utils.h index 6c1d3f3c0ee..d92a9e22052 100644 --- a/paddle/phi/capi/include/kernel_utils.h +++ b/paddle/phi/capi/include/kernel_utils.h @@ -564,18 +564,24 @@ namespace capi { static_assert(out_idx == 0, \ "Kernel's Input should appear before Outputs."); \ auto arg = PD_MultiInputAt(ctx, in_idx); \ - auto arg_wrapper = PD_GetPointerVector(&arg); \ + std::vector tensor_ptr_vec; \ + for (auto &tensor : arg) { \ + tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \ + } \ CustomKernelCallHelper:: \ template Compute( \ - ctx, pargs..., arg_wrapper); \ + ctx, pargs..., tensor_ptr_vec); \ } \ template \ static void VariadicCompute(const std::tuple &ctx, \ PreviousArgs &...pargs) { \ auto &arg = std::get(ctx); \ - auto tensor = PD_TensorVector(reinterpret_cast( \ + auto tensor_vec = PD_TensorVector(reinterpret_cast( \ const_cast *>(&arg))); \ - auto tensor_ptr_vec = PD_GetPointerVector(&arg); \ + std::vector tensor_ptr_vec; \ + for (auto &tensor : tensor_vec) { \ + tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \ + } \ return CustomKernelCallHelper::template VariadicCompute( \ ctx, pargs..., tensor_ptr_vec); \ @@ -681,7 +687,7 @@ namespace capi { tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \ } \ CustomKernelCallHelper:: \ - template Compute( \ + template Compute( \ ctx, pargs..., tensor_ptr_vec); \ } \ template \ diff --git a/paddle/phi/capi/lib/c_kernel_context.cc b/paddle/phi/capi/lib/c_kernel_context.cc index d38a19038e3..e9fe2aada1f 100644 --- a/paddle/phi/capi/lib/c_kernel_context.cc +++ b/paddle/phi/capi/lib/c_kernel_context.cc @@ -60,7 +60,11 @@ PD_List PD_KernelContextMultiInputAt(PD_KernelContext* ctx, size_t index) { range.first, range.second); PD_List list; list.size = tensor_vec.size(); - list.data = tensor_vec.data(); + list.data = new void*[list.size]; + for (size_t i = 0; i < list.size; ++i) { + (reinterpret_cast(list.data))[i] = + reinterpret_cast(const_cast(tensor_vec[i])); + } return list; } @@ -78,7 +82,11 @@ PD_List PD_KernelContextMultiOutputAt(PD_KernelContext* ctx, size_t index) { range.first, range.second); PD_List list; list.size = tensor_vec.size(); - list.data = tensor_vec.data(); + list.data = new void*[list.size]; + for (size_t i = 0; i < list.size; ++i) { + (reinterpret_cast(list.data))[i] = + reinterpret_cast(tensor_vec[i]); + } return list; } -- GitLab