// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/custom/custom_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/onednn/onednn_context.h" #ifdef PADDLE_WITH_XPU #include "paddle/phi/backends/xpu/xpu_context.h" #endif #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/string_tensor.h" #include "paddle/phi/core/tensor_array.h" #include "paddle/phi/core/type_defs.h" namespace phi { // PD_KERNEL has been used by custom op api #define PHI_KERNEL(...) \ ::phi::KernelImpl::Compute #define PHI_VARIADIC_KERNEL(...) \ reinterpret_cast(&::phi::KernelImpl::VariadicCompute) #define PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(dev_ctx) \ template \ struct KernelCallHelper { \ template \ static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ static_assert(in_idx == 0, \ "Kernel's DeviceContext should appear before Inputs."); \ static_assert( \ attr_idx == 0, \ "Kernel's DeviceContext should appear before Attributes."); \ static_assert(out_idx == 0, \ "Kernel's DeviceContext should appear before Outputs."); \ const dev_ctx& arg = ctx->GetDeviceContext(); \ KernelCallHelper:: \ template Compute( \ ctx, pargs..., arg); \ } \ } #define PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(tensor_type) \ template \ struct KernelCallHelper { \ template \ static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ static_assert(attr_idx == 0, \ "Kernel's Input should appear before Attributes."); \ static_assert(out_idx == 0, \ "Kernel's Input should appear before Outputs."); \ const std::pair& range = ctx->InputRangeAt(in_idx); \ const tensor_type& arg = ctx->InputAt(range.first); \ KernelCallHelper:: \ template Compute( \ ctx, pargs..., arg); \ } \ } #define PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(tensor_type) \ template \ struct KernelCallHelper&, Tail...> { \ template \ static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ static_assert(attr_idx == 0, \ "Kernel's Input should appear before Attributes."); \ static_assert(out_idx == 0, \ "Kernel's Input should appear before Outputs."); \ const std::pair& range = ctx->InputRangeAt(in_idx); \ auto arg = ctx->OptionalInputAt(range.first); \ KernelCallHelper:: \ template Compute( \ ctx, pargs..., arg); \ } \ } #define PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type) \ template \ struct KernelCallHelper&, Tail...> { \ template \ static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ static_assert(attr_idx == 0, \ "Kernel's Input should appear before Attributes."); \ static_assert(out_idx == 0, \ "Kernel's Input should appear before Outputs."); \ const std::pair& range = ctx->InputRangeAt(in_idx); \ std::vector arg = std::move( \ ctx->InputsBetween(range.first, range.second)); \ KernelCallHelper:: \ template Compute( \ ctx, pargs..., arg); \ } \ } #define PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_MULTI_INPUT(tensor_type) \ template \ struct KernelCallHelper< \ const paddle::optional>&, \ Tail...> { \ template \ static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ static_assert(attr_idx == 0, \ "Kernel's Input should appear before Attributes."); \ static_assert(out_idx == 0, \ "Kernel's Input should appear before Outputs."); \ const std::pair& range = ctx->InputRangeAt(in_idx); \ paddle::optional> arg = \ ctx->OptionalInputsBetween(range.first, range.second); \ KernelCallHelper:: \ template Compute( \ ctx, pargs..., arg); \ } \ } #define PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(attr_type) \ template \ struct KernelCallHelper { \ template \ static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ static_assert(out_idx == 0, \ "Kernel's Attributes should appear before Outputs."); \ attr_type arg = ctx->AttrAt(attr_idx); \ KernelCallHelper:: \ template Compute( \ ctx, pargs..., arg); \ } \ } #define PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(attr_type) \ template \ struct KernelCallHelper { \ template \ static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ static_assert(out_idx == 0, \ "Kernel's Attributes should appear before Outputs."); \ const attr_type& arg = ctx->AttrAt(attr_idx); \ KernelCallHelper:: \ template Compute( \ ctx, pargs..., arg); \ } \ } #define PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type) \ template \ struct KernelCallHelper { \ template \ static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ const std::pair& range = ctx->OutputRangeAt(out_idx); \ tensor_type* arg = ctx->MutableOutputAt(range.first); \ KernelCallHelper:: \ template Compute( \ ctx, pargs..., arg); \ } \ } #define PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(tensor_type) \ template \ struct KernelCallHelper, Tail...> { \ template \ static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { \ const std::pair& range = ctx->OutputRangeAt(out_idx); \ std::vector arg = std::move( \ ctx->MutableOutputBetween(range.first, range.second)); \ KernelCallHelper:: \ template Compute( \ ctx, pargs..., arg); \ } \ } template struct TypeTag {}; template struct KernelImpl; template struct KernelImpl { static void Compute(KernelContext* ctx) { KernelCallHelper>:: template Compute<0, 0, 0, 0>(ctx); } static void VariadicCompute(const DeviceContext& dev_ctx, Args... args) { return kernel_fn(static_cast(dev_ctx), std::forward(args)...); } private: template struct KernelCallHelper; /* DeviceContext Helpers */ PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(CPUContext); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(GPUContext); #endif #ifdef PADDLE_WITH_XPU PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(XPUContext); #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(CustomContext); #endif #ifdef PADDLE_WITH_MKLDNN PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(OneDNNContext); #endif /* Input Helpers */ PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows); PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_MULTI_INPUT(DenseTensor); PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SparseCooTensor); PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCsrTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCsrTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(SparseCsrTensor); PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(StringTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(StringTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(StringTensor); PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(TensorArray); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(TensorArray); /* Attribute Helpers */ PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(bool); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(float); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(double); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(phi::dtype::float16); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataLayout); PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(Place); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::string); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(Scalar); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(IntArray); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF( std::vector); PD_SPECIALIZE_KernelCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); /* Output Helpers */ PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRows); PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCooTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCooTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCsrTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCsrTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(StringTensor); PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(StringTensor); PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(TensorArray); template struct KernelCallHelper { template static void Compute(KernelContext* ctx, PreviousArgs&... pargs) { const auto& runtime_attrs = ctx->GetRuntimeAttrs(); KernelCallHelper:: template Compute( ctx, pargs..., runtime_attrs); } }; /* End case */ template struct KernelCallHelper> { template static void Compute(KernelContext* ctx, DevCtx dev_ctx, Args&... args) { static_assert(dev_ctx_idx > 0, "Kernel should pass DeviceContext as argument."); return kernel_fn(dev_ctx, args...); } }; }; } // namespace phi