// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_def.h"

// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/enforce.h"

namespace pten {

// Machinery that adapts an ordinary kernel function to a uniform calling
// convention. KernelImpl::Compute walks the kernel's parameter list through a
// chain of KernelCallHelper specializations. Each specialization fetches one
// argument (device context, input, attribute or output) from a KernelContext,
// appends it to the arguments gathered so far, and hands off to the next
// helper. The final "End case" helper calls the kernel.
//
// NOTE(review): in this copy of the file the angle-bracket template
// parameter/argument lists appear to be missing throughout. Examples are the
// bare `template` keywords, `std::pair range`, `reinterpret_cast(...)`, and two
// identical `const std::vector&` attribute helpers. The text below will not
// compile as shown; restore those lists from version control before building.

// PT_KERNEL(...) names the static `Compute` entry point of the KernelImpl
// adapter. That entry point takes a KernelContext* and unpacks the kernel's
// arguments from it. The KernelImpl arguments are presumably derived from the
// kernel passed as __VA_ARGS__ (not visible here).
#define PT_KERNEL(...) \
  ::pten::KernelImpl::Compute

// PT_VARIADIC_KERNEL(...) takes the address of KernelImpl's `VariadicCompute`
// entry point and reinterpret_casts it. VariadicCompute receives a
// DeviceContext plus the kernel's own arguments directly. The cast target type
// is not visible in this copy; presumably a type-erased pointer. TODO confirm.
#define PT_VARIADIC_KERNEL(...) \
  reinterpret_cast(&::pten::KernelImpl::VariadicCompute)

// Declares the KernelCallHelper specialization for a kernel parameter of
// device-context type `dev_ctx`. It reads the context via
// ctx->GetDeviceContext(), appends it after `pargs`, and calls the next
// helper's Compute. The static_asserts enforce that the device context comes
// before every input, attribute and output. in_idx/attr_idx/out_idx are
// presumably template parameters of Compute; their declaration list is missing
// here.
#define PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(dev_ctx)             \
  template                                                                     \
  struct KernelCallHelper {                                                    \
    template                                                                   \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {          \
      static_assert(in_idx == 0,                                               \
                    "Kernel's DeviceContext should appear before Inputs.");    \
      static_assert(                                                           \
          attr_idx == 0,                                                       \
          "Kernel's DeviceContext should appear before Attributes.");          \
      static_assert(out_idx == 0,                                              \
                    "Kernel's DeviceContext should appear before Outputs.");   \
      const dev_ctx& arg = ctx->GetDeviceContext();                            \
      KernelCallHelper::                                                       \
          template Compute(                                                    \
              ctx, pargs..., arg);                                             \
    }                                                                          \
  }

// Declares the helper for a single input tensor of type `tensor_type`.
// It looks up the input range for position in_idx, binds a const reference to
// the first input in that range, and continues the chain. Inputs must precede
// attributes and outputs (enforced by the static_asserts).
#define PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(tensor_type)                  \
  template                                                                     \
  struct KernelCallHelper {                                                    \
    template                                                                   \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {          \
      static_assert(attr_idx == 0,                                             \
                    "Kernel's Input should appear before Attributes.");        \
      static_assert(out_idx == 0,                                              \
                    "Kernel's Input should appear before Outputs.");           \
      const std::pair range = ctx->InputRangeAt(in_idx);                       \
      const tensor_type& arg = ctx->InputAt(range.first);                      \
      KernelCallHelper::                                                       \
          template Compute(                                                    \
              ctx, pargs..., arg);                                             \
    }                                                                          \
  }

// Declares the helper for a variable-length list of input tensors.
// Unlike the single-input helper, it collects every input in
// [range.first, range.second) into a local std::vector via
// ctx->MoveInputsBetween, then passes that vector on as the argument.
// NOTE(review): if MoveInputsBetween returns by value, the surrounding
// std::move is redundant and only prevents copy elision.
#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type)            \
  template                                                                     \
  struct KernelCallHelper&, Tail...> {                                         \
    template                                                                   \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {          \
      static_assert(attr_idx == 0,                                             \
                    "Kernel's Input should appear before Attributes.");        \
      static_assert(out_idx == 0,                                              \
                    "Kernel's Input should appear before Outputs.");           \
      const std::pair range = ctx->InputRangeAt(in_idx);                       \
      std::vector arg = std::move(                                             \
          ctx->MoveInputsBetween(range.first, range.second));                  \
      KernelCallHelper::                                                       \
          template Compute(                                                    \
              ctx, pargs..., arg);                                             \
    }                                                                          \
  }

// Declares the helper for an attribute parameter of type `attr_type`.
// The value is read with ctx->AttrAt(attr_idx); attributes are addressed by
// their own counter, not by an input/output range. Attributes must precede
// outputs.
#define PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(attr_type)                \
  template                                                                     \
  struct KernelCallHelper {                                                    \
    template                                                                   \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {          \
      static_assert(out_idx == 0,                                              \
                    "Kernel's Attributes should appear before Outputs.");      \
      attr_type arg = ctx->AttrAt(attr_idx);                                   \
      KernelCallHelper::                                                       \
          template Compute(                                                    \
              ctx, pargs..., arg);                                             \
    }                                                                          \
  }

// Declares the helper for a single output tensor, passed to the kernel as a
// `tensor_type*`. It takes the first entry of the output range for position
// out_idx via ctx->MutableOutputAt. Outputs have no ordering static_assert
// because they are expected last.
#define PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type)                 \
  template                                                                     \
  struct KernelCallHelper {                                                    \
    template                                                                   \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {          \
      const std::pair range = ctx->OutputRangeAt(out_idx);                     \
      tensor_type* arg = ctx->MutableOutputAt(range.first);                    \
      KernelCallHelper::                                                       \
          template Compute(                                                    \
              ctx, pargs..., arg);                                             \
    }                                                                          \
  }

// Declares the helper for a variable-length list of output tensors.
// It collects every output in [range.first, range.second) into a local
// std::vector via ctx->MutableOutputBetween.
// NOTE(review): same possibly-redundant std::move as the MULTI_INPUT helper.
#define PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(tensor_type)           \
  template                                                                     \
  struct KernelCallHelper, Tail...> {                                          \
    template                                                                   \
    static void Compute(KernelContext* ctx, PreviousArgs&... pargs) {          \
      const std::pair range = ctx->OutputRangeAt(out_idx);                     \
      std::vector arg = std::move(                                             \
          ctx->MutableOutputBetween(range.first, range.second));               \
      KernelCallHelper::                                                       \
          template Compute(                                                    \
              ctx, pargs..., arg);                                             \
    }                                                                          \
  }

// Empty tag type with no members. Presumably appended to the parameter list
// as a terminator, so the KernelCallHelper recursion ends on the "End case"
// specialization below (exact template arguments not visible in this copy).
template struct TypeTag {};

// Primary template: declared only. Just the specialization below is defined.
template struct KernelImpl;

// Adapter around a kernel function `kernel_fn`, providing two entry points.
// NOTE(review): the partial-specialization argument list is missing here.
// Judging from the members, it is presumably specialized on a kernel function
// pointer whose parameters are (DevCtx, Args...). Confirm against upstream.
template struct KernelImpl {
  // Context-driven entry point. It starts the KernelCallHelper chain with all
  // four position counters at 0, presumably (dev_ctx_idx, in_idx, attr_idx,
  // out_idx). The chain then pulls each argument out of `ctx` and finally
  // invokes kernel_fn.
  static void Compute(KernelContext* ctx) {
    KernelCallHelper>::template Compute<0, 0, 0, 0>(ctx);
  }

  // Direct-call entry point. It static_casts the generic DeviceContext
  // (presumably to the kernel's own DevCtx type) and forwards the remaining
  // arguments to kernel_fn unchanged.
  static void VariadicCompute(const DeviceContext& dev_ctx, Args... args) {
    return kernel_fn(static_cast(dev_ctx), std::forward(args)...);
  }

 private:
  // Primary helper template, declared without a body. Every kernel parameter
  // type must match one of the specializations generated below; otherwise no
  // helper definition exists for it.
  template struct KernelCallHelper;

  /* DeviceContext Helpers */

  // One specialization per device context compiled into this build.
  PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(CPUContext);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(GPUContext);
#endif
#ifdef PADDLE_WITH_XPU
  PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(XPUContext);
#endif

  /* Input Helpers */

  PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
  PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
  // TODO(chenweihang): adapt SelectedRows
  // PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRowsTensor);

  /* Attribute Helpers */

  // Each attribute type a kernel may declare must be listed here.
  // NOTE(review): the two `const std::vector&` entries below are identical as
  // shown. Their element types were lost in this copy and must be restored
  // from upstream, since duplicate specializations would not compile.
  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(bool);
  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(float);
  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(double);
  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int);
  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t);
  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16);
  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType);
  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&);
  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
  PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&);

  /* Output Helpers */

  PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor);
  PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor);
  // TODO(chenweihang): adapt SelectedRows
  // PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRowsTensor);

  /* End case */
  // Reached once every kernel parameter has been consumed. It checks at
  // compile time that a device context and at least one output were seen.
  // It then calls kernel_fn with the collected device context and arguments;
  // `ctx` itself is no longer needed at this point.
  template struct KernelCallHelper> {
    template
    static void Compute(KernelContext* ctx, DevCtx dev_ctx, Args&... args) {
      static_assert(dev_ctx_idx > 0,
                    "Kernel should pass DeviceContext as argument.");
      static_assert(out_idx > 0, "Kernel should have output argument.");
      // TODO(chenweihang): check dev_ctx, in, attr, out number
      return kernel_fn(dev_ctx, args...);
    }
  };
};

} // namespace pten