diff --git a/paddle/phi/capi/include/kernel_registry.h b/paddle/phi/capi/include/kernel_registry.h index 47ddc0bf5be7ec532c4cbe3082913ec49172b7c6..73318561dd99425c6c904d69e61c7dd2aeecf1e2 100644 --- a/paddle/phi/capi/include/kernel_registry.h +++ b/paddle/phi/capi/include/kernel_registry.h @@ -167,6 +167,7 @@ inline std::vector PD_MultiInputAt( for (size_t i = 0; i < list.size; ++i) { ret.emplace_back(data[i]); } + PD_DeletePointerList(list); return ret; } @@ -182,13 +183,14 @@ inline std::vector PD_MultiOutputAt( for (size_t i = 0; i < list.size; ++i) { ret.emplace_back(data[i]); } + PD_DeletePointerList(list); return ret; } template inline std::vector PD_GetPointerVector(std::vector *vec) { std::vector ret; - for (auto &item : vec) { + for (auto &item : *vec) { ret.push_back(&item); } return ret; diff --git a/paddle/phi/capi/include/kernel_utils.h b/paddle/phi/capi/include/kernel_utils.h index 6c1d3f3c0ee758a4f7b97f989127d0b4c2db3a08..d92a9e22052187e64237f321ffde26913074887f 100644 --- a/paddle/phi/capi/include/kernel_utils.h +++ b/paddle/phi/capi/include/kernel_utils.h @@ -564,18 +564,24 @@ namespace capi { static_assert(out_idx == 0, \ "Kernel's Input should appear before Outputs."); \ auto arg = PD_MultiInputAt(ctx, in_idx); \ - auto arg_wrapper = PD_GetPointerVector(&arg); \ + std::vector tensor_ptr_vec; \ + for (auto &tensor : arg) { \ + tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \ + } \ CustomKernelCallHelper:: \ template Compute( \ - ctx, pargs..., arg_wrapper); \ + ctx, pargs..., tensor_ptr_vec); \ } \ template \ static void VariadicCompute(const std::tuple &ctx, \ PreviousArgs &...pargs) { \ auto &arg = std::get(ctx); \ - auto tensor = PD_TensorVector(reinterpret_cast( \ + auto tensor_vec = PD_TensorVector(reinterpret_cast( \ const_cast *>(&arg))); \ - auto tensor_ptr_vec = PD_GetPointerVector(&arg); \ + std::vector tensor_ptr_vec; \ + for (auto &tensor : tensor_vec) { \ + tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \ + } \ return CustomKernelCallHelper::template VariadicCompute( \ ctx, pargs..., tensor_ptr_vec); \ @@ -681,7 +687,7 @@ namespace capi { tensor_ptr_vec.push_back(tensor.raw_data() ? &tensor : nullptr); \ } \ CustomKernelCallHelper:: \ - template Compute( \ + template Compute( \ ctx, pargs..., tensor_ptr_vec); \ } \ template \ diff --git a/paddle/phi/capi/lib/c_kernel_context.cc b/paddle/phi/capi/lib/c_kernel_context.cc index d38a19038e31446277f3b092ed4b3066f776bcca..e9fe2aada1f35f13481e0a080cc9bdb3c27b356b 100644 --- a/paddle/phi/capi/lib/c_kernel_context.cc +++ b/paddle/phi/capi/lib/c_kernel_context.cc @@ -60,7 +60,11 @@ PD_List PD_KernelContextMultiInputAt(PD_KernelContext* ctx, size_t index) { range.first, range.second); PD_List list; list.size = tensor_vec.size(); - list.data = tensor_vec.data(); + list.data = new void*[list.size]; + for (size_t i = 0; i < list.size; ++i) { + (reinterpret_cast(list.data))[i] = + reinterpret_cast(const_cast(tensor_vec[i])); + } return list; } @@ -78,7 +82,11 @@ PD_List PD_KernelContextMultiOutputAt(PD_KernelContext* ctx, size_t index) { range.first, range.second); PD_List list; list.size = tensor_vec.size(); - list.data = tensor_vec.data(); + list.data = new void*[list.size]; + for (size_t i = 0; i < list.size; ++i) { + (reinterpret_cast(list.data))[i] = + reinterpret_cast(tensor_vec[i]); + } return list; }