// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/capi/include/c_kernel_context.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/capi/include/common.h" #include "paddle/phi/capi/include/type_utils.h" #include "paddle/phi/core/kernel_context.h" PD_DeviceContext* PD_KernelContextGetDeviceContext(PD_KernelContext* ctx) { auto kernel_context = reinterpret_cast(ctx); auto dev_ctx_type = kernel_context->GetDeviceContext() .GetPlace() .GetType(); if (dev_ctx_type == phi::AllocationType::CUSTOM) { return reinterpret_cast(const_cast( &kernel_context->GetDeviceContext())); } else if (dev_ctx_type == phi::AllocationType::CPU) { return reinterpret_cast(const_cast( &kernel_context->GetDeviceContext())); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) } else if (dev_ctx_type == phi::AllocationType::GPU) { return reinterpret_cast(const_cast( &kernel_context->GetDeviceContext())); #endif #ifdef PADDLE_WITH_XPU } else if (dev_ctx_type == phi::AllocationType::XPU) { return reinterpret_cast(const_cast( &kernel_context->GetDeviceContext())); #endif } else { PADDLE_THROW(phi::errors::Unavailable( "Only support Custom/CPU/GPU/XPU DeviceContext")); } } PD_Tensor* PD_KernelContextInputAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); const std::pair& range = kernel_context->InputRangeAt(index); return reinterpret_cast(const_cast( &kernel_context->InputAt(range.first))); } PD_List PD_KernelContextMultiInputAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); const std::pair& range = kernel_context->InputRangeAt(index); auto tensor_vec = kernel_context->InputsBetween( range.first, range.second); PD_List list; list.size = tensor_vec.size(); list.data = tensor_vec.data(); return list; } PD_Tensor* PD_KernelContextOutputAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); const std::pair& range = kernel_context->OutputRangeAt(index); return reinterpret_cast( kernel_context->MutableOutputAt(range.first)); } PD_List PD_KernelContextMultiOutputAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); const std::pair& range = kernel_context->OutputRangeAt(index); auto tensor_vec = kernel_context->MutableOutputBetween( range.first, range.second); PD_List list; list.size = tensor_vec.size(); list.data = tensor_vec.data(); return list; } bool PD_KernelContextBoolAttrAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); return kernel_context->AttrAt(index); } int32_t PD_KernelContextInt32AttrAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); return kernel_context->AttrAt(index); } int64_t PD_KernelContextInt64AttrAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); return kernel_context->AttrAt(index); } float PD_KernelContextFloatAttrAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); return kernel_context->AttrAt(index); } double PD_KernelContextDoubleAttrAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); return kernel_context->AttrAt(index); } PD_Scalar* PD_KernelContextScalarAttrAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); return reinterpret_cast( const_cast(&kernel_context->AttrAt(index))); } PD_IntArray* PD_KernelContextIntArrayAttrAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); return reinterpret_cast(const_cast( &kernel_context->AttrAt(index))); } PD_List PD_KernelContextListBoolAttrAt(PD_KernelContext* ctx, size_t index) { PD_List list; auto kernel_context = reinterpret_cast(ctx); const auto& cc_list = kernel_context->AttrAt>(index); list.size = cc_list.size(); auto data = reinterpret_cast(new uint8_t[cc_list.size()]); for (size_t i = 0; i < cc_list.size(); ++i) { data[i] = static_cast(cc_list[i]); } list.data = data; return list; } PD_List PD_KernelContextListInt32AttrAt(PD_KernelContext* ctx, size_t index) { PD_List list; auto kernel_context = reinterpret_cast(ctx); const auto& cc_list = kernel_context->AttrAt>(index); list.size = cc_list.size(); list.data = const_cast(cc_list.data()); return list; } PD_List PD_KernelContextListInt64AttrAt(PD_KernelContext* ctx, size_t index) { PD_List list; auto kernel_context = reinterpret_cast(ctx); const auto& cc_list = kernel_context->AttrAt>(index); list.size = cc_list.size(); list.data = const_cast(cc_list.data()); return list; } PD_List PD_KernelContextListFloatAttrAt(PD_KernelContext* ctx, size_t index) { PD_List list; auto kernel_context = reinterpret_cast(ctx); const auto& cc_list = kernel_context->AttrAt>(index); list.size = cc_list.size(); list.data = const_cast(cc_list.data()); return list; } PD_List PD_KernelContextListDoubleAttrAt(PD_KernelContext* ctx, size_t index) { PD_List list; auto kernel_context = reinterpret_cast(ctx); const auto& cc_list = kernel_context->AttrAt>(index); list.size = cc_list.size(); list.data = const_cast(cc_list.data()); return list; } char* PD_KernelContextStringAttrAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); return const_cast(kernel_context->AttrAt(index).data()); } PD_List PD_KernelContextListStringAttrAt(PD_KernelContext* ctx, size_t index) { PD_List list; auto kernel_context = reinterpret_cast(ctx); const auto& cc_list = kernel_context->AttrAt>(index); list.size = cc_list.size(); auto data = new char*[list.size]; for (size_t i = 0; i < list.size; ++i) { data[i] = const_cast(cc_list[i].data()); } list.data = reinterpret_cast(data); return list; } PD_List PD_KernelContextListScalarAttrAt(PD_KernelContext* ctx, size_t index) { PD_List list; auto kernel_context = reinterpret_cast(ctx); const auto& cc_list = kernel_context->AttrAt>(index); list.size = cc_list.size(); auto data = new PD_Scalar*[list.size]; for (size_t i = 0; i < list.size; ++i) { data[i] = const_cast(reinterpret_cast(&cc_list[i])); } list.data = data; return list; } PD_Place* PD_KernelContextPlaceAttrAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); return reinterpret_cast( const_cast(&kernel_context->AttrAt(index))); } PD_DataType PD_KernelContextDataTypeAttrAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); return phi::capi::ToPDDataType(kernel_context->AttrAt(index)); } PD_DataLayout PD_KernelContextDataLayoutAttrAt(PD_KernelContext* ctx, size_t index) { auto kernel_context = reinterpret_cast(ctx); return phi::capi::ToPDDataLayout( kernel_context->AttrAt(index)); } // eager const char* PD_StringAttr(void* attr) { auto* str = reinterpret_cast(attr); return str->c_str(); } PD_DataType PD_DatatTypeAttr(void* attr) { auto* dtype = reinterpret_cast(attr); return phi::capi::ToPDDataType(*dtype); } PD_DataLayout PD_DatatLayoutAttr(void* attr) { auto* layout = reinterpret_cast(attr); return phi::capi::ToPDDataLayout(*layout); } PD_List PD_ListInt32Attr(void* attr) { PD_List list; const auto& cc_list = *reinterpret_cast*>(attr); list.size = cc_list.size(); list.data = const_cast(cc_list.data()); return list; } PD_List PD_ListInt64Attr(void* attr) { PD_List list; const auto& cc_list = *reinterpret_cast*>(attr); list.size = cc_list.size(); list.data = const_cast(cc_list.data()); return list; } PD_List PD_ListFloatAttr(void* attr) { PD_List list; const auto& cc_list = *reinterpret_cast*>(attr); list.size = cc_list.size(); list.data = const_cast(cc_list.data()); return list; } PD_List PD_ListDoubleAttr(void* attr) { PD_List list; const auto& cc_list = *reinterpret_cast*>(attr); list.size = cc_list.size(); list.data = const_cast(cc_list.data()); return list; } PD_List PD_ListScalarAttr(void* attr) { PD_List list; const auto& cc_list = *reinterpret_cast*>(attr); list.size = cc_list.size(); auto data = new PD_Scalar*[list.size]; for (size_t i = 0; i < list.size; ++i) { data[i] = const_cast(reinterpret_cast(&cc_list[i])); } list.data = data; return list; } PD_List PD_ListStringAttr(void* attr) { PD_List list; const auto& cc_list = *reinterpret_cast*>(attr); list.size = cc_list.size(); auto data = new char*[list.size]; for (size_t i = 0; i < list.size; ++i) { data[i] = const_cast(cc_list[i].data()); } list.data = reinterpret_cast(data); return list; } PD_List PD_ListBoolAttr(void* attr) { PD_List list; const auto& cc_list = *reinterpret_cast*>(attr); list.size = cc_list.size(); auto data = reinterpret_cast(new uint8_t[cc_list.size()]); for (size_t i = 0; i < cc_list.size(); ++i) { data[i] = static_cast(cc_list[i]); } list.data = data; return list; } PD_REGISTER_CAPI(kernel_context);