From edc3ba13010497a04e6859a804ce535faf5e5945 Mon Sep 17 00:00:00 2001 From: Aganlengzi Date: Tue, 22 Feb 2022 21:45:33 +0800 Subject: [PATCH] [custom kernel]Delete useless and upgrade (#39791) * [custom kernel]Delete useless * change RegType enum names * mod notes * merge * update --- .../fluid/framework/op_kernel_info_helper.h | 71 - paddle/phi/api/ext/op_kernel_info.h | 1257 ----------------- paddle/phi/api/lib/op_kernel_info.cc | 108 -- paddle/phi/core/kernel_registry.h | 16 +- 4 files changed, 8 insertions(+), 1444 deletions(-) delete mode 100644 paddle/fluid/framework/op_kernel_info_helper.h delete mode 100644 paddle/phi/api/ext/op_kernel_info.h delete mode 100644 paddle/phi/api/lib/op_kernel_info.cc diff --git a/paddle/fluid/framework/op_kernel_info_helper.h b/paddle/fluid/framework/op_kernel_info_helper.h deleted file mode 100644 index d62711bb882..00000000000 --- a/paddle/fluid/framework/op_kernel_info_helper.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/phi/api/ext/op_kernel_info.h" -#include "paddle/phi/core/kernel_factory.h" - -namespace paddle { -namespace framework { - -class OpKernelInfoHelper { - public: - static const std::string& GetOpName(const paddle::OpKernelInfo& info) { - return info.op_name_; - } - - static const phi::Backend& GetBackend(const paddle::OpKernelInfo& info) { - return info.backend_; - } - - static const phi::DataLayout& GetDataLayout( - const paddle::OpKernelInfo& info) { - return info.layout_; - } - - static const phi::DataType& GetDataType(const paddle::OpKernelInfo& info) { - return info.dtype_; - } - - static phi::KernelKey GetKernelKey(const paddle::OpKernelInfo& info) { - return phi::KernelKey(info.backend_, info.layout_, info.dtype_); - } - - static const CustomKernelFunc& GetKernelFn(const paddle::OpKernelInfo& info) { - return info.kernel_fn_; - } - - static void* GetVariadicKernelFn(const paddle::OpKernelInfo& info) { - return info.variadic_kernel_fn_; - } - - static const paddle::SmallVector& GetInputDefs( - const paddle::OpKernelInfo& info) { - return info.input_defs_; - } - - static const paddle::SmallVector& GetOutputDefs( - const paddle::OpKernelInfo& info) { - return info.output_defs_; - } - - static const paddle::SmallVector& GetAttributeDefs( - const paddle::OpKernelInfo& info) { - return info.attribute_defs_; - } -}; - -} // namespace framework -} // namespace paddle diff --git a/paddle/phi/api/ext/op_kernel_info.h b/paddle/phi/api/ext/op_kernel_info.h deleted file mode 100644 index b3adbe9d18b..00000000000 --- a/paddle/phi/api/ext/op_kernel_info.h +++ /dev/null @@ -1,1257 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "paddle/phi/api/ext/dll_decl.h" -#include "paddle/phi/api/ext/exception.h" -#include "paddle/phi/api/ext/op_meta_info.h" -#include "paddle/phi/api/include/tensor.h" -#include "paddle/phi/common/scalar.h" -#include "paddle/phi/common/scalar_array.h" -#include "paddle/utils/any.h" -#include "paddle/utils/small_vector.h" - -#include "paddle/phi/common/data_type.h" - -/** - * Custom Kernel Info Define. - * - * Used to maintain custom kernel core information before registering. - * Pten is working on exposing headers, custom kernel depends on them, and - * we prefer outer users following pten-kernel-function-style and registering - * macro. So, we have to re-implement some structs or class and functions to - * make sure users' custom kernel functions can be registered to pten. - * - * TODO(Aganlengzi): We should upgrade following pten. - */ - -namespace paddle { -namespace framework { -class PADDLE_API OpKernelInfoHelper; -} // namespace framework - -// TODO(Aganlengzi): Simple DeviceContext temporarily for stream getting -// before phi::DeviceContext is exposed. -class DeviceContext { - public: - DeviceContext() { stream_ = nullptr; } - void set_stream(void* stream) { stream_ = stream; } - void* stream() const { return stream_; } - - private: - void* stream_; -}; -class CPUContext : public DeviceContext {}; - -// TODO(Aganlengzi): Use paddle::Tensor before DenseTensor is exposed -using Tensor = paddle::experimental::Tensor; -using Scalar = phi::Scalar; -using ScalarArray = phi::ScalarArray; - -// Record custom kernel core information -// We can not use phi::KernelFn directly, so users' custom kernel function -// is signatured to `CustomKernelFunc', notice that the first parameter is -// fixed to `const DeviceContext&'. -using CustomKernelFunc = - void (*)(const DeviceContext& dev_ctx, - const std::vector& inputs, - const std::vector>& vec_inputs, - const std::vector& attrs, - std::vector* outputs, - std::vector>* vec_outputs); - -////////////////////// Kernel Function (PD_PT_KERNEL) //////////////////////// -#define PD_SPECIALIZE_KernelCallHelper_FOR_DEV_CONTEXT(device_ctx) \ - template \ - struct CustomComputeCallHelper { \ - template \ - static void Compute(const DeviceContext& dev_ctx, \ - const std::vector& inputs, \ - const std::vector>& vec_inputs, \ - const std::vector& attrs, \ - std::vector* outputs, \ - std::vector>* vec_outputs, \ - PreviousArgs... pargs) { \ - static_assert(in_idx == 0, \ - "Kernel's DeviceContext should appear before Inputs."); \ - static_assert(vec_in_idx == 0, \ - "Kernel's DeviceContext should appear before Inputs."); \ - static_assert( \ - attr_idx == 0, \ - "Kernel's DeviceContext should appear before Attributes."); \ - static_assert(out_idx == 0, \ - "Kernel's DeviceContext should appear before Outputs."); \ - static_assert(vec_out_idx == 0, \ - "Kernel's DeviceContext should appear before Outputs."); \ - const device_ctx& arg = static_cast(dev_ctx); \ - CustomComputeCallHelper::template Compute( \ - dev_ctx, \ - inputs, \ - vec_inputs, \ - attrs, \ - outputs, \ - vec_outputs, \ - pargs..., \ - arg); \ - } \ - } - -#define PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(tensor_type) \ - template \ - struct CustomComputeCallHelper { \ - template \ - static void Compute(const DeviceContext& dev_ctx, \ - const std::vector& inputs, \ - const std::vector>& vec_inputs, \ - const std::vector& attrs, \ - std::vector* outputs, \ - std::vector>* vec_outputs, \ - PreviousArgs... pargs) { \ - static_assert(attr_idx == 0, \ - "Kernel's Input should appear before Attributes."); \ - static_assert(out_idx == 0, \ - "Kernel's Input should appear before Outputs."); \ - static_assert(vec_out_idx == 0, \ - "Kernel's Input should appear before Outputs."); \ - const Tensor& arg = inputs[in_idx]; \ - CustomComputeCallHelper::template Compute( \ - dev_ctx, \ - inputs, \ - vec_inputs, \ - attrs, \ - outputs, \ - vec_outputs, \ - pargs..., \ - arg); \ - } \ - } - -#define PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(tensor_type) \ - template \ - struct CustomComputeCallHelper&, Tail...> { \ - template \ - static void Compute(const DeviceContext& dev_ctx, \ - const std::vector& inputs, \ - const std::vector>& vec_inputs, \ - const std::vector& attrs, \ - std::vector* outputs, \ - std::vector>* vec_outputs, \ - PreviousArgs... pargs) { \ - static_assert(attr_idx == 0, \ - "Kernel's Input should appear before Attributes."); \ - static_assert(out_idx == 0, \ - "Kernel's Input should appear before Outputs."); \ - static_assert(vec_out_idx == 0, \ - "Kernel's Input should appear before Outputs."); \ - const std::vector& arg = vec_inputs[vec_in_idx]; \ - CustomComputeCallHelper::template Compute( \ - dev_ctx, \ - inputs, \ - vec_inputs, \ - attrs, \ - outputs, \ - vec_outputs, \ - pargs..., \ - arg); \ - } \ - } - -#define PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(attr_type) \ - template \ - struct CustomComputeCallHelper { \ - template \ - static void Compute(const DeviceContext& dev_ctx, \ - const std::vector& inputs, \ - const std::vector>& vec_inputs, \ - const std::vector& attrs, \ - std::vector* outputs, \ - std::vector>* vec_outputs, \ - PreviousArgs... pargs) { \ - static_assert(out_idx == 0, \ - "Kernel's Attributes should appear before Outputs."); \ - static_assert(vec_out_idx == 0, \ - "Kernel's Attributes should appear before Outputs."); \ - try { \ - attr_type arg = paddle::any_cast(attrs[attr_idx]); \ - return CustomComputeCallHelper::template Compute< \ - dev_ctx_idx, \ - in_idx, \ - vec_in_idx, \ - attr_idx + 1, \ - out_idx, \ - vec_out_idx>(dev_ctx, \ - inputs, \ - vec_inputs, \ - attrs, \ - outputs, \ - vec_outputs, \ - pargs..., \ - arg); \ - } catch (paddle::bad_any_cast&) { \ - PD_THROW( \ - "Attribute cast error in custom operator. Expected " #attr_type \ - " value."); \ - } \ - } \ - } - -#define PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(tensor_type) \ - template \ - struct CustomComputeCallHelper { \ - template \ - static void Compute(const DeviceContext& dev_ctx, \ - const std::vector& inputs, \ - const std::vector>& vec_inputs, \ - const std::vector& attrs, \ - std::vector* outputs, \ - std::vector>* vec_outputs, \ - PreviousArgs... pargs) { \ - tensor_type* arg = (*outputs)[out_idx]; \ - CustomComputeCallHelper::template Compute( \ - dev_ctx, \ - inputs, \ - vec_inputs, \ - attrs, \ - outputs, \ - vec_outputs, \ - pargs..., \ - arg); \ - } \ - } - -#define PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(tensor_type) \ - template \ - struct CustomComputeCallHelper, Tail...> { \ - template \ - static void Compute(const DeviceContext& dev_ctx, \ - const std::vector& inputs, \ - const std::vector>& vec_inputs, \ - const std::vector& attrs, \ - std::vector* outputs, \ - std::vector>* vec_outputs, \ - PreviousArgs... pargs) { \ - std::vector arg = (*vec_outputs)[vec_out_idx]; \ - CustomComputeCallHelper::template Compute( \ - dev_ctx, \ - inputs, \ - vec_inputs, \ - attrs, \ - outputs, \ - vec_outputs, \ - pargs..., \ - arg); \ - } \ - } - -template -struct PtenTypeTag {}; - -template -struct CustomKernelFuncImpl; - -template -struct CustomKernelFuncImpl { - static void Compute(const DeviceContext& dev_ctx, - const std::vector& inputs, - const std::vector>& vec_inputs, - const std::vector& attrs, - std::vector* outputs, - std::vector>* vec_outputs) { - CustomComputeCallHelper>:: - template Compute<0, 0, 0, 0, 0, 0>( - dev_ctx, inputs, vec_inputs, attrs, outputs, vec_outputs); - } - - // NOTE: Tensor in args is paddle::Tensor but not DenseTensor - static void VariadicCompute(const DeviceContext& dev_ctx, Args... args) { - return impl_fn(static_cast(dev_ctx), std::forward(args)...); - } - - private: - template - struct CustomComputeCallHelper; - - /* DeviceContext Helpers */ - PD_SPECIALIZE_KernelCallHelper_FOR_DEV_CONTEXT(CPUContext); - - /* Input Helpers */ - PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(Tensor); - PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(Tensor); - - /* Attribute Helpers */ - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(bool); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(float); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(double); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(int64_t); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(phi::dtype::float16); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); - PD_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); - - /* Output Helpers */ - PD_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(Tensor); - PD_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(Tensor); - - // End: base template - template - struct CustomComputeCallHelper> { - template - static void Compute(const DeviceContext& dev_ctx, - const std::vector& inputs, - const std::vector>& vec_inputs, - const std::vector& attrs, - std::vector* outputs, - std::vector>* vec_outputs, - DevCtx device_ctx, - Args... args) { - return impl_fn(device_ctx, args...); - } - }; -}; - -#define PD_PT_KERNEL(...) \ - ::paddle::CustomKernelFuncImpl::Compute - -#define PD_PT_VARIADIC_KERNEL(...) \ - reinterpret_cast( \ - &::paddle::CustomKernelFuncImpl::VariadicCompute) - -////////////////////// Op Kernel Info depended structs ////////////////////// -// TODO(Aganlengzi): Re-define TensorArgDef and AttributeArgDef temporarily. -// TensorArgDef follows phi::TensorArgDef in kernel_factory.h, the -// difference is that custom_kernel needs extra `is_vector' to ensure we can -// deal with case like vector with only one element. -struct TensorArgDef { - phi::Backend backend; - phi::DataLayout layout; - phi::DataType dtype; - bool is_vector{false}; - - TensorArgDef(phi::Backend in_backend, - phi::DataLayout in_layout, - phi::DataType in_dtype, - bool is_vector = false) - : backend(in_backend), - layout(in_layout), - dtype(in_dtype), - is_vector(is_vector) {} - - TensorArgDef& SetBackend(phi::Backend in_backend) { - backend = in_backend; - return *this; - } - - TensorArgDef& SetDataLayout(phi::DataLayout in_layout) { - layout = in_layout; - return *this; - } - - TensorArgDef& SetDataType(phi::DataType in_dtype) { - dtype = in_dtype; - return *this; - } -}; - -// AttributeArgDef follows phi::AttributeArgDef in kernel_factory.h -struct AttributeArgDef { - std::type_index type_index; - - explicit AttributeArgDef(std::type_index type_index) - : type_index(type_index) {} -}; - -////////////////////// Op Kernel Info ////////////////////// -// OpKernelInfo stores all info parsed from user kernel function, includes: -// 0. op_name and kernel key(backend, data_layout and data_type) -// 1. unified custom kernel function -// 2. variadic kernel function(use paddle::Tensor) -// 3. args info and user defined change for specific arg -class PADDLE_API OpKernelInfo { - public: - explicit OpKernelInfo(const std::string& op_name, - phi::Backend backend, - phi::DataLayout data_layout, - phi::DataType data_type) - : op_name_(op_name), - backend_(backend), - layout_(data_layout), - dtype_(data_type) {} - - // format: PD_PT_KERNEL(...) - OpKernelInfo& SetKernelFn(CustomKernelFunc&& func); - // format: PD_PT_VARIADIC_KERNEL(...) - OpKernelInfo& SetVariadicKernelFn(void* func); - - // for Args parsing and storing - void AppendInput(phi::Backend backend, - phi::DataLayout layout, - phi::DataType dtype, - bool is_vector = false) { - input_defs_.emplace_back(TensorArgDef(backend, layout, dtype, is_vector)); - } - - void AppendOutput(phi::Backend backend, - phi::DataLayout layout, - phi::DataType dtype, - bool is_vector = false) { - output_defs_.emplace_back(TensorArgDef(backend, layout, dtype, is_vector)); - } - - void AppendAttribute(std::type_index type_index) { - attribute_defs_.emplace_back(AttributeArgDef(type_index)); - } - - // for Args user-def function - TensorArgDef& InputAt(size_t idx) { return input_defs_.at(idx); } - TensorArgDef& OutputAt(size_t idx) { return output_defs_.at(idx); } - - const phi::Backend& GetBackend() const { return backend_; } - const phi::DataLayout& GetDataLayout() const { return layout_; } - const phi::DataType& GetDataType() const { return dtype_; } - - private: - friend class framework::OpKernelInfoHelper; - - // 1. op info - std::string op_name_; - - // 2. kernel key info - phi::Backend backend_{phi::Backend::UNDEFINED}; - phi::DataLayout layout_{phi::DataLayout::UNDEFINED}; - phi::DataType dtype_{phi::DataType::UNDEFINED}; - - // 3. args info - paddle::SmallVector input_defs_{{}}; - paddle::SmallVector output_defs_{{}}; - paddle::SmallVector attribute_defs_{{}}; - - // 4. func info - CustomKernelFunc kernel_fn_{nullptr}; - void* variadic_kernel_fn_{nullptr}; -}; - -////////////////////// Op Kernel Args Parser ////////////////////// -// Define CustomKernelArgsParseFunctor for args parsing -// We have to store parsed info into OpKernelInfo before -// mapping to phi::KernelArgsDef in phi::Kernel -template -struct CustomKernelArgsParseFunctor; - -template -struct CustomKernelArgsParseFunctor { - using Args = std::tuple; - enum : std::size_t { Arity = sizeof...(Args_) }; - using Indices = std::make_index_sequence; - template - using Arg = typename std::tuple_element::type; - - static void Parse(OpKernelInfo* op_kernel_info) { - const phi::Backend& backend = op_kernel_info->GetBackend(); - const phi::DataLayout& layout = op_kernel_info->GetDataLayout(); - const phi::DataType& dtype = op_kernel_info->GetDataType(); - - auto default_tensor_layout = phi::DataLayout::NCHW; - if (layout != phi::DataLayout::ANY) { - default_tensor_layout = layout; - } - auto args_type = ParseArgType(Indices{}); - for (auto arg_type : args_type) { - if (arg_type == std::type_index(typeid(const CPUContext&))) { - // do nothing, skip context arg now - } else if (arg_type == std::type_index(typeid(const Tensor&))) { - op_kernel_info->AppendInput(backend, default_tensor_layout, dtype); - } else if (arg_type == - std::type_index(typeid(const std::vector&))) { - op_kernel_info->AppendInput( - backend, default_tensor_layout, dtype, true); - } else if (arg_type == std::type_index(typeid(Tensor*))) { - op_kernel_info->AppendOutput(backend, default_tensor_layout, dtype); - } else if (arg_type == std::type_index(typeid(std::vector))) { - op_kernel_info->AppendOutput( - backend, default_tensor_layout, dtype, true); - } else { - op_kernel_info->AppendAttribute(arg_type); - } - } - } - - private: - template - static std::vector ParseArgType( - std::index_sequence) { - return {std::type_index(typeid(Arg))...}; - } -}; - -#define PD_PT_ARGS_PARSE(...) \ - ::paddle::CustomKernelArgsParseFunctor::Parse - -//////////////// Op Kernel Info Map ///////////////// -// all user custom kernels information are stored in this map -class PADDLE_API OpKernelInfoMap { - public: - static OpKernelInfoMap& Instance() { - static OpKernelInfoMap g_custom_kernel_info_map; - return g_custom_kernel_info_map; - } - - std::vector& operator[](const std::string& name); - - const std::unordered_map>& GetMap() - const; - - private: - OpKernelInfoMap() = default; - std::unordered_map> map_; - - PD_DISABLE_COPY_AND_ASSIGN(OpKernelInfoMap); -}; - -//////////////// Op Kernel Info Builder ///////////////// -// format: PD_PT_ARGS_PARSE(...) -using CustomKernelArgsParseFn = void (*)(OpKernelInfo* op_kernel_info); -using CustomKernelArgsDefFn = void (*)(OpKernelInfo* kernel); - -class PADDLE_API OpKernelInfoBuilder { - public: - explicit OpKernelInfoBuilder(std::string&& op_name, - phi::Backend backend, - phi::DataLayout data_layout, - phi::DataType data_type); - - OpKernelInfoBuilder& SetKernelFn(CustomKernelFunc func); - OpKernelInfoBuilder& SetVariadicKernelFn(void* func); - OpKernelInfoBuilder& ArgsParse(CustomKernelArgsParseFn func); - OpKernelInfoBuilder& ArgsDef(CustomKernelArgsDefFn func); - - private: - // op name - std::string op_name_; - - // kernel key info - phi::Backend backend_{phi::Backend::UNDEFINED}; - phi::DataLayout layout_{phi::DataLayout::UNDEFINED}; - phi::DataType dtype_{phi::DataType::UNDEFINED}; - - // ref current info ptr - OpKernelInfo* info_ptr_; -}; -/////////////////////// Custom kernel register API ///////////////////////// -// For inference: compile directly with framework -// Call after PD_REGISTER_BUILTIN_KERNEL(...) -void RegisterAllCustomKernel(); - -//////////////// Custom kernel register macro ///////////////////// -// Refer to paddle/phi/core/kernel_registry.h, we can not use -// PD_REGISTER_KERNEL directly, common macros and functions are -// not ready for custom kernel now. -// Difference: custom_kernel stores all kernels' info into global -// g_custom_kernel_info_map before loading and registering into -// pten kernel management. Only providing PD_REGISTER_BUILTIN_KERNEL which -// supports 2 template arguments. - -#define PD_BACKEND(arg__) phi::Backend::arg__ -#define PD_DATALAYOUT(arg__) phi::DataLayout::arg__ -#define PD_DATATYPE(arg__) phi::DataType::arg__ - -#define PD_NARGS(...) _PD_NARGS((__VA_ARGS__, _PD_RESQ_N())) -#define _PD_NARGS(...) _PD_ARG_N(__VA_ARGS__) -#define _PD_ARG_N_EXPAND( \ - _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, N, ...) \ - N -#define _PD_ARG_N(args) _PD_ARG_N_EXPAND args -#define _PD_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 - -#define PD_CONCATENATE(arg1, arg2) PD_CONCATENATE1(arg1, arg2) -#define PD_CONCATENATE1(arg1, arg2) PD_CONCATENATE2(arg1, arg2) -#define PD_CONCATENATE2(arg1, arg2) arg1##arg2 - -#define PD_EXPAND(x) x - -#ifdef __COUNTER__ -#define PD_ID __COUNTER__ -#else -#define PD_ID __LINE__ -#endif - -#define PD_REGISTER_BUILTIN_KERNEL( \ - kernel_name, backend, layout, func, cpp_dtype, ...) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - _reg_custom_kernel_ns_check_##kernel_name##_##backend##_##layout, \ - "PD_REGISTER_BUILTIN_KERNEL must be called in global namespace."); \ - _PD_REGISTER_2TA_KERNEL( \ - kernel_name, backend, layout, func, cpp_dtype, ##__VA_ARGS__) - -// WIN32 is not supported -#define _PD_REGISTER_2TA_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, cpp_dtype, ...) \ - PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ##__VA_ARGS__); \ - static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::paddle::OpKernelInfo* kernel); \ - PD_KERNEL_REGISTRAR_INIT( \ - kernel_name, \ - backend, \ - layout, \ - &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ - meta_kernel_fn, \ - cpp_dtype, \ - ##__VA_ARGS__); \ - void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ - ::paddle::OpKernelInfo* kernel) - -#define PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, cpp_dtype, ...) \ - _PD_KERNEL_INSTANTIATION(PD_NARGS(cpp_dtype, ##__VA_ARGS__), \ - meta_kernel_fn, \ - backend, \ - cpp_dtype, \ - ##__VA_ARGS__) - -#define _PD_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, cpp_dtype, ...) \ - PD_CONCATENATE(_PD_KERNEL_INSTANTIATION_, N) \ - (meta_kernel_fn, backend, cpp_dtype, ##__VA_ARGS__) - -#define _PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn -#define _PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, ##__VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, ##__VA_ARGS__)) - -#define PD_KERNEL_REGISTRAR_INIT( \ - kernel_name, backend, layout, args_def_fn, meta_kernel_fn, cpp_dtype, ...) \ - _PD_KERNEL_REGISTRAR_INIT(PD_NARGS(cpp_dtype, ##__VA_ARGS__), \ - kernel_name, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ##__VA_ARGS__) - -// clang-format off - -/* The =pre-commit always treats this macro into the wrong format, - and multi-line macros cannot be skipped with NOLINT.*/ -#define _PD_KERNEL_REGISTRAR_INIT(N, \ - kernel_name, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N) ( \ - kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ##__VA_ARGS__) - -// clang-format on - -#define _PD_KERNEL_REGISTRAR_INIT_1(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); - -#define _PD_KERNEL_REGISTRAR_INIT_2(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_3(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_4(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_5(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_6(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_7(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_8(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_9(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_10(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_11(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_12(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_13(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_14(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) - -#define _PD_KERNEL_REGISTRAR_INIT_15(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static ::paddle::OpKernelInfoBuilder PD_CONCATENATE( \ - custom_kernel_info_##kernel_name##_##backend##_##layout##_, \ - registrar_id) = \ - ::paddle::OpKernelInfoBuilder( \ - #kernel_name, \ - PD_BACKEND(backend), \ - PD_DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type()) \ - .SetKernelFn(PD_PT_KERNEL( \ - meta_kernel_fn)) \ - .SetVariadicKernelFn(PD_PT_VARIADIC_KERNEL( \ - meta_kernel_fn)) \ - .ArgsParse(PD_PT_ARGS_PARSE( \ - meta_kernel_fn)) \ - .ArgsDef(args_def_fn); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(kernel_name, \ - backend, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ - ##__VA_ARGS__)) -} // namespace paddle diff --git a/paddle/phi/api/lib/op_kernel_info.cc b/paddle/phi/api/lib/op_kernel_info.cc deleted file mode 100644 index c2aef8288da..00000000000 --- a/paddle/phi/api/lib/op_kernel_info.cc +++ /dev/null @@ -1,108 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/phi/api/ext/op_kernel_info.h" -#include "paddle/fluid/framework/custom_kernel.h" - -namespace paddle { - -////////////////////// Op Kernel Info ////////////////////// - -OpKernelInfo& OpKernelInfo::SetKernelFn(CustomKernelFunc&& func) { - kernel_fn_ = std::forward(func); - return *this; -} - -OpKernelInfo& OpKernelInfo::SetVariadicKernelFn(void* func) { - variadic_kernel_fn_ = func; - return *this; -} - -//////////////// Op Kernel Info Map ///////////////// - -std::vector& OpKernelInfoMap::operator[]( - const std::string& name) { - return map_[name]; -} - -const std::unordered_map>& -OpKernelInfoMap::GetMap() const { - return map_; -} - -//////////////// Op Kernel Info Builder ///////////////// - -OpKernelInfoBuilder::OpKernelInfoBuilder(std::string&& op_name, - phi::Backend backend, - phi::DataLayout data_layout, - phi::DataType data_type) { - // 1. member assign - op_name_ = std::forward(op_name); - backend_ = backend; - layout_ = data_layout; - dtype_ = data_type; - - // 2. info parse - auto& info_vector = OpKernelInfoMap::Instance()[op_name_]; - auto op_kernel_info = OpKernelInfo(op_name_, backend_, layout_, dtype_); - info_vector.emplace_back(std::move(op_kernel_info)); - - // 3. get current info ptr - info_ptr_ = &(info_vector.back()); -} - -OpKernelInfoBuilder& OpKernelInfoBuilder::SetKernelFn(CustomKernelFunc func) { - info_ptr_->SetKernelFn(std::forward(func)); - return *this; -} - -OpKernelInfoBuilder& OpKernelInfoBuilder::SetVariadicKernelFn(void* func) { - info_ptr_->SetVariadicKernelFn(func); - return *this; -} - -OpKernelInfoBuilder& OpKernelInfoBuilder::ArgsParse( - CustomKernelArgsParseFn func) { - func(this->info_ptr_); - return *this; -} - -OpKernelInfoBuilder& OpKernelInfoBuilder::ArgsDef(CustomKernelArgsDefFn func) { - func(this->info_ptr_); - return *this; -} - -/////////////////////// Op register API ///////////////////////// - -// For inference: compile directly with framework -// Call after PD_REGISTER_BUILTIN_KERNEL(...) -void RegisterAllCustomKernel() { - auto& op_kernel_info_map = OpKernelInfoMap::Instance(); - framework::RegisterKernelWithMetaInfoMap(op_kernel_info_map); -} - -} // namespace paddle - -#ifdef __cplusplus -extern "C" { -#endif - -// C-API to get global OpKernelInfoMap. -paddle::OpKernelInfoMap& PD_GetOpKernelInfoMap() { - return paddle::OpKernelInfoMap::Instance(); -} - -#ifdef __cplusplus -} // end extern "C" -#endif diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h index 4603f4123ac..6a1688947b9 100644 --- a/paddle/phi/core/kernel_registry.h +++ b/paddle/phi/core/kernel_registry.h @@ -129,10 +129,10 @@ struct KernelArgsParseFunctor { } }; -// NOTE: used for making a difference between kernels compiled with phi or not. +// NOTE: used for making a difference between inner or outer registration. enum class RegType : uint8_t { - BUILTIN = 0, // compiled with phi - PLUGIN, // separate compiled and registered + INNER = 0, + OUTER, }; // TODO(chenweihang): Polish the kernel selection logic, support the selection @@ -205,7 +205,7 @@ struct KernelRegistrar { Kernel kernel(kernel_fn, variadic_kernel_fn); args_parse_fn(kernel_key, kernel.mutable_args_def()); args_def_fn(kernel_key, &kernel); - if (reg_type == RegType::BUILTIN) { + if (reg_type == RegType::INNER) { KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; } else { CustomKernelMap::Instance().Kernels()[kernel_name][kernel_key] = kernel; @@ -244,7 +244,7 @@ struct KernelRegistrar { * Note: `2TA` means `2 template argument` */ #define PD_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \ - _PD_REGISTER_KERNEL(::phi::RegType::BUILTIN, \ + _PD_REGISTER_KERNEL(::phi::RegType::INNER, \ kernel_name, \ backend, \ ::phi::backend##Context, \ @@ -918,7 +918,7 @@ struct KernelRegistrar { #define PD_REGISTER_GENERAL_KERNEL( \ kernel_name, backend, layout, kernel_fn, dtype) \ _PD_REGISTER_GENERAL_KERNEL( \ - ::phi::RegType::BUILTIN, kernel_name, backend, layout, kernel_fn, dtype) + ::phi::RegType::INNER, kernel_name, backend, layout, kernel_fn, dtype) #define _PD_REGISTER_GENERAL_KERNEL( \ reg_type, kernel_name, backend, layout, kernel_fn, dtype) \ @@ -992,7 +992,7 @@ struct KernelRegistrar { */ #define PD_REGISTER_BUILTIN_KERNEL( \ kernel_name, backend, layout, meta_kernel_fn, ...) \ - _PD_REGISTER_KERNEL(::phi::RegType::PLUGIN, \ + _PD_REGISTER_KERNEL(::phi::RegType::OUTER, \ kernel_name, \ backend, \ ::phi::backend##Context, \ @@ -1007,7 +1007,7 @@ struct KernelRegistrar { */ #define PD_REGISTER_PLUGIN_KERNEL( \ kernel_name, backend, layout, meta_kernel_fn, ...) \ - _PD_REGISTER_KERNEL(::phi::RegType::PLUGIN, \ + _PD_REGISTER_KERNEL(::phi::RegType::OUTER, \ kernel_name, \ backend, \ ::phi::CustomContext, \ -- GitLab