From 68631ed45874c79438a99a18b4415edd9f908dc4 Mon Sep 17 00:00:00 2001 From: Aganlengzi Date: Mon, 21 Feb 2022 21:53:49 +0800 Subject: [PATCH] [PluggableDevice]custom kernel to phi core structs (#39690) * [PluggableDevice]custom kernel to pten core structs * mod extension.h for custom op * compatible python for CI * support custom context * refactor to pten * fix windows and ut --- paddle/fluid/framework/CMakeLists.txt | 4 +- paddle/fluid/framework/custom_kernel.cc | 345 +---- paddle/fluid/framework/custom_kernel.h | 13 +- paddle/fluid/platform/CMakeLists.txt | 4 + paddle/phi/api/all.h | 1 - paddle/phi/api/lib/CMakeLists.txt | 1 - paddle/phi/backends/CMakeLists.txt | 4 + paddle/phi/backends/all_context.h | 3 + paddle/phi/backends/custom/custom_context.cc | 6 +- paddle/phi/backends/custom/custom_context.h | 3 +- paddle/phi/common/backend.h | 26 + paddle/phi/core/CMakeLists.txt | 2 + paddle/phi/core/custom_kernel.cc | 66 + paddle/phi/core/custom_kernel.h | 49 + paddle/phi/core/dense_tensor.h | 3 + paddle/phi/core/kernel_context.h | 1 + paddle/phi/core/kernel_registry.h | 1141 ++++++++++------- paddle/phi/core/kernel_utils.h | 10 + paddle/phi/core/lod_utils.h | 6 + paddle/phi/core/tensor_meta.h | 6 + paddle/phi/core/tensor_utils.h | 24 +- paddle/phi/tests/common/test_backend.cc | 14 + paddle/phi/tests/core/CMakeLists.txt | 1 + .../tests/core/test_custom_kernel.cc} | 151 +-- .../tests/custom_kernel/custom_kernel_dot.cc | 19 +- .../custom_kernel/custom_kernel_dot_setup.py | 36 +- python/setup.py.in | 4 +- 27 files changed, 1001 insertions(+), 942 deletions(-) create mode 100644 paddle/phi/core/custom_kernel.cc create mode 100644 paddle/phi/core/custom_kernel.h rename paddle/{fluid/framework/custom_kernel_test.cc => phi/tests/core/test_custom_kernel.cc} (70%) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 78f5bb077aa..7d527e24a00 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -437,8 +437,7 @@ message(STATUS "branch: ${PADDLE_BRANCH}") configure_file(commit.h.in commit.h) cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper pten_tensor op_meta_info pten_api) -cc_library(custom_kernel SRCS custom_kernel.cc DEPS - tensor attribute framework_proto op_registry operator dynamic_loader string_helper pten_tensor op_kernel_info pten_api) +cc_library(custom_kernel SRCS custom_kernel.cc DEPS op_registry pten_custom_kernel pten_tensor_raw) #cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ) #cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) @@ -459,4 +458,3 @@ else() cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place) endif() cc_test(convert_utils_test SRCS convert_utils_test.cc DEPS fluid_convert_utils) -cc_test(custom_kernel_test SRCS custom_kernel_test.cc DEPS custom_kernel pten_tensor) diff --git a/paddle/fluid/framework/custom_kernel.cc b/paddle/fluid/framework/custom_kernel.cc index 3a00d942464..49a1e0774a6 100644 --- a/paddle/fluid/framework/custom_kernel.cc +++ b/paddle/fluid/framework/custom_kernel.cc @@ -18,355 +18,24 @@ limitations under the License. */ #endif #include "paddle/fluid/framework/custom_kernel.h" -#include -#include -#include -#include "paddle/fluid/framework/op_kernel_info_helper.h" -#include "paddle/fluid/framework/operator.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/phi/api/ext/op_kernel_info.h" -#include "paddle/phi/core/compat/convert_utils.h" -#include "paddle/phi/core/kernel_context.h" -#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/custom_kernel.h" namespace paddle { - namespace framework { -// set phi::Kernel args_def_ from op_kernel_info -// because we can not set directly to phi::Kernel without exposing -// phi::KernelArgsDef when parsing custom user function -static void ParseArgs(const OpKernelInfo& op_kernel_info, - phi::KernelArgsDef* args_def) { - auto& input_defs = OpKernelInfoHelper::GetInputDefs(op_kernel_info); - auto& output_defs = OpKernelInfoHelper::GetOutputDefs(op_kernel_info); - auto& attribute_defs = OpKernelInfoHelper::GetAttributeDefs(op_kernel_info); - - for (auto& input : input_defs) { - auto type_index = - input.is_vector - ? std::type_index(typeid(const std::vector&)) - : std::type_index(typeid(const phi::DenseTensor&)); - args_def->AppendInput(input.backend, input.layout, input.dtype, type_index); - } - for (auto& output : output_defs) { - auto type_index = - output.is_vector - ? std::type_index(typeid(const std::vector&)) - : std::type_index(typeid(const phi::DenseTensor&)); - args_def->AppendOutput(output.backend, output.layout, output.dtype, - type_index); - } - for (auto& attr : attribute_defs) { - args_def->AppendAttribute(attr.type_index); - } -} - -// custom pten kernel call function define -static void RunKernelFunc(phi::KernelContext* ctx, - const OpKernelInfo& op_kernel_info) { - VLOG(3) << "[CUSTOM KERNEL] RunKernelFunc begin..."; - - // input and output size is not params' num - // but actual Tensors' size - size_t input_size = ctx->InputsSize(); - size_t output_size = ctx->OutputsSize(); - size_t attr_size = ctx->AttrsSize(); - - // parameters' num of unified user kernel function - auto& input_defs = OpKernelInfoHelper::GetInputDefs(op_kernel_info); - auto& output_defs = OpKernelInfoHelper::GetOutputDefs(op_kernel_info); - auto& attribute_defs = OpKernelInfoHelper::GetAttributeDefs(op_kernel_info); - - PADDLE_ENFORCE_GE(input_size, input_defs.size(), - platform::errors::InvalidArgument( - "the size of ctx inputs size (%d) must be larger than " - "the size of kernel input_defs (%d).", - input_size, input_defs.size())); - - PADDLE_ENFORCE_GE(output_size, output_defs.size(), - platform::errors::InvalidArgument( - "the size of ctx outputs size (%d) must be larger than " - "the size of kernel output_defs (%d).", - output_size, output_defs.size())); - - PADDLE_ENFORCE_EQ(attr_size, attribute_defs.size(), - platform::errors::InvalidArgument( - "the size of ctx attribute size (%d) must be equal to " - "to the size of kernel attribute_defs (%d).", - attr_size, attribute_defs.size())); - - VLOG(3) << "[CUSTOM KERNEL] Input num: " << input_defs.size() - << "[tensor size:" << input_size << "]" - << " Attribute num: " << attribute_defs.size() - << " Output num: " << output_defs.size() - << "[tensor size:" << output_size << "]."; - - // Inputs mapping - std::vector custom_ins; - std::vector> custom_vec_ins; - for (size_t in_idx = 0; in_idx < input_defs.size(); ++in_idx) { - VLOG(3) << "Mapping Input[" << in_idx << "]"; - const std::pair range = ctx->InputRangeAt(in_idx); - - // is_vector tells if this Input is Tensor or std::vector - if (!input_defs.at(in_idx).is_vector) { - paddle::experimental::Tensor custom_t; - auto& ctx_tensor = ctx->InputAt(range.first); - custom_t.set_impl(std::make_shared(ctx_tensor)); - custom_ins.emplace_back(custom_t); - } else { - std::vector custom_vec_in; - auto ctx_tensor_vec = - ctx->MoveInputsBetween(range.first, range.second); - for (auto& ctx_tensor : ctx_tensor_vec) { - paddle::experimental::Tensor custom_t; - custom_t.set_impl(std::make_shared(ctx_tensor)); - custom_vec_in.emplace_back(custom_t); - } - custom_vec_ins.emplace_back(custom_vec_in); - } - VLOG(3) << "Mapped Input[" << in_idx << "] with range[" << range.first - << "," << range.second << ")."; - } - - // Attributes mapping - std::vector custom_attrs; - for (size_t attr_idx = 0; attr_idx < attribute_defs.size(); ++attr_idx) { - VLOG(3) << "Mapping Attribute[" << attr_idx << "]"; - if (attribute_defs[attr_idx].type_index == std::type_index(typeid(bool))) { - bool arg = ctx->AttrAt(attr_idx); - custom_attrs.emplace_back(arg); - } else if (attribute_defs[attr_idx].type_index == - std::type_index(typeid(int))) { - int arg = ctx->AttrAt(attr_idx); - custom_attrs.emplace_back(arg); - } else if (attribute_defs[attr_idx].type_index == - std::type_index(typeid(float))) { - float arg = ctx->AttrAt(attr_idx); - custom_attrs.emplace_back(arg); - } else if (attribute_defs[attr_idx].type_index == - std::type_index(typeid(double))) { - double arg = ctx->AttrAt(attr_idx); - custom_attrs.emplace_back(arg); - } else if (attribute_defs[attr_idx].type_index == - std::type_index(typeid(int64_t))) { - int64_t arg = ctx->AttrAt(attr_idx); - custom_attrs.emplace_back(arg); - } else if (attribute_defs[attr_idx].type_index == - std::type_index(typeid(phi::dtype::float16))) { - phi::dtype::float16 arg = ctx->AttrAt(attr_idx); - custom_attrs.emplace_back(arg); - } else if (attribute_defs[attr_idx].type_index == - std::type_index(typeid(DataType))) { - DataType arg = ctx->AttrAt(attr_idx); - custom_attrs.emplace_back(arg); - } else if (attribute_defs[attr_idx].type_index == - std::type_index(typeid(const Scalar&))) { - const Scalar& arg = ctx->AttrAt(attr_idx); - custom_attrs.emplace_back(arg); - } else if (attribute_defs[attr_idx].type_index == - std::type_index(typeid(const std::vector&))) { - const std::vector& arg = - ctx->AttrAt&>(attr_idx); - custom_attrs.emplace_back(arg); - } else if (attribute_defs[attr_idx].type_index == - std::type_index(typeid(const ScalarArray&))) { - const ScalarArray& arg = ctx->AttrAt(attr_idx); - custom_attrs.emplace_back(arg); - } else if (attribute_defs[attr_idx].type_index == - std::type_index(typeid(const std::vector&))) { - const std::vector& arg = - ctx->AttrAt&>(attr_idx); - custom_attrs.emplace_back(arg); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupported attribute attribute_defs[%d].type_index", attr_idx)); - } - VLOG(3) << "Mapped Attribute[" << attr_idx << "]"; - } - - // Outputs mapping - std::vector custom_outs; - std::vector> custom_vec_outs; - std::vector> custom_outs_ptr; - std::vector>> - custom_vec_outs_ptr; - - for (size_t out_idx = 0; out_idx < output_defs.size(); ++out_idx) { - VLOG(3) << "Mapping Output[" << out_idx << "]"; - const std::pair range = ctx->OutputRangeAt(out_idx); - - // is_vector tells if this Output is Tensor or std::vector - if (!output_defs.at(out_idx).is_vector) { - auto* ctx_tensor = ctx->MutableOutputAt(range.first); - auto* custom_t = new paddle::experimental::Tensor(); - auto custom_t_ptr = std::make_shared(*ctx_tensor); - custom_t->set_impl(custom_t_ptr); - custom_outs.emplace_back(custom_t); - custom_outs_ptr.emplace_back(custom_t_ptr); - } else { - std::vector custom_vec_out; - std::vector> custom_vec_out_ptr; - auto ctx_tensor_vec = ctx->MutableOutputBetween( - range.first, range.second); - for (auto ctx_tensor : ctx_tensor_vec) { - auto* custom_t = new paddle::experimental::Tensor(); - auto custom_t_ptr = std::make_shared(*ctx_tensor); - custom_t->set_impl(custom_t_ptr); - custom_vec_out.emplace_back(custom_t); - custom_vec_out_ptr.emplace_back(custom_t_ptr); - } - custom_vec_outs.emplace_back(custom_vec_out); - custom_vec_outs_ptr.emplace_back(custom_vec_out_ptr); - } - VLOG(3) << "Mapped Output[" << out_idx << "] with range[" << range.first - << "," << range.second << ")."; - } - - // DeviceContext - // In pten, the first paramter XXContext is decided when registering - // through template param, but custom kernel function use unified - // DeviceContext as first parameter of user_kernel_fn, we use backend - // from OpKernelInfo to decide XXContext. In temporary simple - // DeviceContext, we just set necessary info to dev_ctx(such as stream - // in NPUContext), more related work should be done when - // phi::DeviceContext is exposed to outer. - DeviceContext dev_ctx; - auto& backend = OpKernelInfoHelper::GetBackend(op_kernel_info); - if (backend == phi::Backend::CPU) { - // do nothing - } else { -#ifdef PADDLE_WITH_CUSTOM_DEVICE - size_t device_type_id_ = static_cast(backend) - - static_cast(phi::Backend::ALL_BACKEND); - std::string device_type = phi::GetGlobalDeviceType(device_type_id_); - if (!device_type.empty()) { - auto custom_ctx = - ctx->GetDeviceContext(); - dev_ctx.set_stream(custom_ctx.stream()); - return; - } -#endif - LOG(ERROR) << "[CUSTOM KERNEL] Unsupported kernel backend: " << backend - << " with compiled Paddle."; - return; - } - - auto& user_kernel_fn = OpKernelInfoHelper::GetKernelFn(op_kernel_info); - // call user function - user_kernel_fn(dev_ctx, custom_ins, custom_vec_ins, custom_attrs, - &custom_outs, &custom_vec_outs); - - VLOG(3) << "[CUSTOM KERNEL] finished call user kernel function."; - - // NOTE: Map back the output tensors with stored shared_ptrs. - for (int out_idx = output_defs.size() - 1; out_idx >= 0; --out_idx) { - VLOG(3) << "Mapping Back Output[" << out_idx << "]"; - const std::pair range = ctx->OutputRangeAt(out_idx); - - // is_vector tells if this Output is Tensor or std::vector - if (!output_defs.at(out_idx).is_vector) { - auto* ctx_tensor = ctx->MutableOutputAt(range.first); - *ctx_tensor = *(custom_outs_ptr.back().get()); - custom_outs_ptr.pop_back(); - } else { - auto ctx_tensor_vec = ctx->MutableOutputBetween( - range.first, range.second); - auto custom_vec_ptr_out = custom_vec_outs_ptr.back(); - for (int idx = ctx_tensor_vec.size() - 1; idx >= 0; --idx) { - *(ctx_tensor_vec[idx]) = *(custom_vec_ptr_out.back().get()); - custom_vec_ptr_out.pop_back(); - } - custom_vec_outs_ptr.pop_back(); - } - VLOG(3) << "Mapped Output[" << out_idx << "] with range[" << range.first - << "," << range.second << "]."; - } - - // delete newed paddle::Tensor for outputs while calling user kernel function - for (size_t i = 0; i < custom_outs.size(); ++i) { - delete custom_outs[i]; - } - for (size_t i = 0; i < custom_vec_outs.size(); ++i) { - for (size_t j = 0; j < custom_vec_outs[i].size(); ++j) { - delete custom_vec_outs[i][j]; - } - } -} - -void RegisterKernelWithMetaInfo( - const std::vector& op_kernel_infos) { - for (size_t i = 0; i < op_kernel_infos.size(); ++i) { - auto& kernel_info = op_kernel_infos[i]; - auto op_type = OpKernelInfoHelper::GetOpName(kernel_info); - auto kernel_key = OpKernelInfoHelper::GetKernelKey(kernel_info); - - VLOG(3) << "[CUSTOM KERNEL] registering [" << op_type << "]" << kernel_key; - - // 1.Check whether this kernel is valid for a specific operator - PADDLE_ENFORCE_EQ( - phi::KernelFactory::Instance().HasCompatiblePtenKernel(op_type), true, - platform::errors::InvalidArgument( - "[CUSTOM KERNEL] %s is not ready for custom kernel registering.", - op_type)); - - // 2.Check whether kernel_key has been already registed - PADDLE_ENFORCE_EQ( - phi::KernelFactory::Instance().kernels()[op_type].find(kernel_key), - phi::KernelFactory::Instance().kernels()[op_type].end(), - platform::errors::InvalidArgument( - "[CUSTOM KERNEL] The operator <%s>'s kernel: %s has been " - "already existed in Paddle, please contribute PR if need " - "to optimize the kernel code. Custom kernel do NOT support " - "to replace existing kernel in Paddle.", - op_type, kernel_key)); - - // phi::KernelFn - phi::KernelFn kernel_fn = [kernel_info](phi::KernelContext* ctx) { - VLOG(3) << "[CUSTOM KERNEL] run custom PTEN kernel func in lambda."; - RunKernelFunc(ctx, kernel_info); - }; - // variadic_kernel_fn - void* variadic_kernel_fn = - OpKernelInfoHelper::GetVariadicKernelFn(kernel_info); - phi::Kernel kernel(kernel_fn, variadic_kernel_fn); - // args info - ParseArgs(kernel_info, kernel.mutable_args_def()); - // register custom kernel to phi::KernelFactory - phi::KernelFactory::Instance().kernels()[op_type][kernel_key] = kernel; - VLOG(3) << "[CUSTOM KERNEL] Successed in registering operator <" << op_type - << ">'s kernel " << kernel_key << " to Paddle. " - << "It will be used like native ones."; - } -} - -void RegisterKernelWithMetaInfoMap( - const paddle::OpKernelInfoMap& op_kernel_info_map) { - auto& kernel_info_map = op_kernel_info_map.GetMap(); - VLOG(3) << "[CUSTOM KERNEL] size of op_kernel_info_map: " - << kernel_info_map.size(); - - // pair: {op_type, OpKernelInfo} - for (auto& pair : kernel_info_map) { - VLOG(3) << "[CUSTOM KERNEL] pair first -> op name: " << pair.first; - RegisterKernelWithMetaInfo(pair.second); - } -} - void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle) { #ifdef _LINUX - typedef OpKernelInfoMap& get_op_kernel_info_map_t(); - auto* func = reinterpret_cast( - dlsym(dso_handle, "PD_GetOpKernelInfoMap")); + typedef phi::CustomKernelMap& get_custom_kernel_map_t(); + auto* func = reinterpret_cast( + dlsym(dso_handle, "PD_GetCustomKernelMap")); if (func == nullptr) { LOG(WARNING) << "Skipped lib [" << dso_lib_path << "]: fail to find " - << "PD_GetOpKernelInfoMap symbol in this lib."; + << "PD_GetCustomKernelMap symbol in this lib."; return; } - auto& op_kernel_info_map = func(); - RegisterKernelWithMetaInfoMap(op_kernel_info_map); + auto& custom_kernel_map = func(); + phi::RegisterCustomKernels(custom_kernel_map); LOG(INFO) << "Successed in loading custom kernels in lib: " << dso_lib_path; #else VLOG(3) << "Unsupported: Custom kernel is only implemented on Linux."; diff --git a/paddle/fluid/framework/custom_kernel.h b/paddle/fluid/framework/custom_kernel.h index 30bccc97000..31084a34413 100644 --- a/paddle/fluid/framework/custom_kernel.h +++ b/paddle/fluid/framework/custom_kernel.h @@ -14,22 +14,13 @@ limitations under the License. */ #pragma once -#include "paddle/phi/api/ext/op_kernel_info.h" +#include namespace paddle { namespace framework { +// Load custom kernel lib and register void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle); -// Load custom kernel api: register kernel after user compiled -void LoadOpKernelInfoAndRegister(const std::string& dso_name); - -// Register custom kernel api: register kernel directly -void RegisterKernelWithMetaInfoMap( - const paddle::OpKernelInfoMap& op_kernel_info_map); - -// Interface for selective register custom kernel. -void RegisterKernelWithMetaInfo( - const std::vector& op_kernel_infos); } // namespace framework } // namespace paddle diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index be02bac1aa0..b808e1561b2 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -146,6 +146,10 @@ if(WITH_ASCEND_CL) target_link_libraries(device_context npu_resource_pool) endif() +if(WITH_CUSTOM_DEVICE) + target_link_libraries(device_context custom_context) +endif() + cc_test(init_test SRCS init_test.cc DEPS device_context) # Manage all device event library diff --git a/paddle/phi/api/all.h b/paddle/phi/api/all.h index 8d840214092..06f3cd84476 100644 --- a/paddle/phi/api/all.h +++ b/paddle/phi/api/all.h @@ -41,7 +41,6 @@ limitations under the License. */ #include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/api/ext/dll_decl.h" #include "paddle/phi/api/ext/exception.h" -#include "paddle/phi/api/ext/op_kernel_info.h" #include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/ext/place.h" #include "paddle/phi/api/ext/tensor_compat.h" diff --git a/paddle/phi/api/lib/CMakeLists.txt b/paddle/phi/api/lib/CMakeLists.txt index 175bf34c0da..720c6f54bb0 100644 --- a/paddle/phi/api/lib/CMakeLists.txt +++ b/paddle/phi/api/lib/CMakeLists.txt @@ -90,7 +90,6 @@ cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor_raw pten kernel_dispat cc_library(pten_tensor SRCS tensor_method.cc DEPS pten_tensor_raw pten_function_api) cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor) -cc_library(op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor_raw) cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform) cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform) diff --git a/paddle/phi/backends/CMakeLists.txt b/paddle/phi/backends/CMakeLists.txt index 441bd0a8c30..38366d57841 100644 --- a/paddle/phi/backends/CMakeLists.txt +++ b/paddle/phi/backends/CMakeLists.txt @@ -21,3 +21,7 @@ endif() if(WITH_GPU) add_dependencies(pten_context gpu_context) endif() + +if(WITH_CUSTOM_DEVICE) + add_dependencies(pten_context custom_context) +endif() diff --git a/paddle/phi/backends/all_context.h b/paddle/phi/backends/all_context.h index b53c5ce5c78..3fe03905e42 100644 --- a/paddle/phi/backends/all_context.h +++ b/paddle/phi/backends/all_context.h @@ -21,12 +21,15 @@ limitations under the License. */ // path replacement after implementing pten DeviceContext #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/custom/custom_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/xpu/xpu_context.h" +#ifndef PADDLE_WITH_CUSTOM_KERNEL // TODO(wilber): DeviceContextPool nees include fluid file. #include "paddle/fluid/platform/device_context.h" namespace phi { using DeviceContextPool = paddle::platform::DeviceContextPool; } // namespace phi +#endif diff --git a/paddle/phi/backends/custom/custom_context.cc b/paddle/phi/backends/custom/custom_context.cc index 445f5508391..bde3b6a0853 100644 --- a/paddle/phi/backends/custom/custom_context.cc +++ b/paddle/phi/backends/custom/custom_context.cc @@ -32,8 +32,8 @@ struct CustomContext::Impl { const Place& GetPlace() const { return place_; } - C_Stream stream() const { - return reinterpret_cast(stream_->raw_stream()); + void* stream() const { + return reinterpret_cast(stream_->raw_stream()); } void Wait() const { stream_->Wait(); } @@ -47,7 +47,7 @@ void CustomContext::Init() { impl_->Init(); } const Place& CustomContext::GetPlace() const { return impl_->GetPlace(); } -C_Stream CustomContext::stream() const { return impl_->stream(); } +void* CustomContext::stream() const { return impl_->stream(); } void CustomContext::Wait() const { return impl_->Wait(); } diff --git a/paddle/phi/backends/custom/custom_context.h b/paddle/phi/backends/custom/custom_context.h index 109f5e53707..37b0ee21219 100644 --- a/paddle/phi/backends/custom/custom_context.h +++ b/paddle/phi/backends/custom/custom_context.h @@ -15,7 +15,6 @@ limitations under the License. */ #pragma once #include -#include "paddle/fluid/platform/device/device_ext.h" #include "paddle/phi/common/place.h" #include "paddle/phi/core/device_context.h" @@ -30,7 +29,7 @@ class CustomContext : public DeviceContext { const Place& GetPlace() const override; /*! \brief Return stream in the device context. */ - C_Stream stream() const; + void* stream() const; // Wait for all operations completion in the stream. void Wait() const override; diff --git a/paddle/phi/common/backend.h b/paddle/phi/common/backend.h index f7c39eacae9..62692fb9475 100644 --- a/paddle/phi/common/backend.h +++ b/paddle/phi/common/backend.h @@ -130,6 +130,32 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) { return os; } +inline Backend StringToBackend(const char* backend_cstr) { + std::string s(backend_cstr); + if (s == std::string("Undefined")) { + return Backend::UNDEFINED; + } + for (size_t i = 0; i < s.size(); ++i) { + s[i] = toupper(s[i]); + } + if (s == std::string("CPU")) { + return Backend::CPU; + } else if (s == std::string("GPU")) { + return Backend::GPU; + } else if (s == std::string("XPU")) { + return Backend::XPU; + } else if (s == std::string("NPU")) { + return Backend::NPU; + } else if (s == std::string("MKLDNN")) { + return Backend::MKLDNN; + } else if (s == std::string("CUDNN")) { + return Backend::CUDNN; + } else { + return static_cast(static_cast(Backend::NUM_BACKENDS) + + phi::GetOrRegisterGlobalDeviceTypeId(s)); + } +} + } // namespace experimental } // namespace paddle diff --git a/paddle/phi/core/CMakeLists.txt b/paddle/phi/core/CMakeLists.txt index 18f209377ba..32b9b42f74f 100644 --- a/paddle/phi/core/CMakeLists.txt +++ b/paddle/phi/core/CMakeLists.txt @@ -25,6 +25,8 @@ cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_te cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor) cc_library(selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor mixed_vector pten_enforce ddim) +cc_library(pten_custom_kernel SRCS custom_kernel.cc DEPS kernel_factory convert_utils) + # Will remove once we implemented MKLDNN_Tensor if(WITH_MKLDNN) add_dependencies(dense_tensor mkldnn) diff --git a/paddle/phi/core/custom_kernel.cc b/paddle/phi/core/custom_kernel.cc new file mode 100644 index 00000000000..75ff9cc2860 --- /dev/null +++ b/paddle/phi/core/custom_kernel.cc @@ -0,0 +1,66 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/custom_kernel.h" + +namespace phi { + +void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) { + auto& kernel_info_map = custom_kernel_map.GetMap(); + VLOG(3) << "Size of custom_kernel_map: " << kernel_info_map.size(); + + for (auto& pair : kernel_info_map) { + PADDLE_ENFORCE_EQ( + KernelFactory::Instance().HasCompatiblePtenKernel(pair.first), + true, + phi::errors::InvalidArgument( + "The kernel %s is not ready for custom kernel registering.", + pair.first)); + + for (auto& info_pair : pair.second) { + auto& kernels = KernelFactory::Instance().kernels(); + PADDLE_ENFORCE_EQ( + kernels[pair.first].find(info_pair.first), + kernels[pair.first].end(), + phi::errors::InvalidArgument( + "The operator <%s>'s kernel: %s has been already existed " + "in Paddle, please contribute PR if it is necessary " + "to optimize the kernel code. Custom kernel does NOT support " + "to replace existing kernel in Paddle.", + pair.first, + info_pair.first)); + + kernels[pair.first][info_pair.first] = info_pair.second; + + VLOG(3) << "Successed in registering operator <" << pair.first + << ">'s kernel: " << info_pair.first + << " to Paddle. It will be used like native ones."; + } + } +} + +} // namespace phi + +#ifdef __cplusplus +extern "C" { +#endif + +// C-API to get global CustomKernelMap. +phi::CustomKernelMap& PD_GetCustomKernelMap() { + return phi::CustomKernelMap::Instance(); +} + +#ifdef __cplusplus +} // end extern "C" +#endif diff --git a/paddle/phi/core/custom_kernel.h b/paddle/phi/core/custom_kernel.h new file mode 100644 index 00000000000..20ae2b7bb73 --- /dev/null +++ b/paddle/phi/core/custom_kernel.h @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/kernel_factory.h" +#include "paddle/phi/core/macros.h" + +namespace phi { +/** + * Note: + * Used to store kernels' info before registered to KernelFactory. + */ +class CustomKernelMap { + public: + static CustomKernelMap& Instance() { + static CustomKernelMap g_custom_kernel_info_map; + return g_custom_kernel_info_map; + } + + KernelNameMap& Kernels() { return kernels_; } + + const KernelNameMap& GetMap() const { return kernels_; } + + private: + CustomKernelMap() = default; + DISABLE_COPY_AND_ASSIGN(CustomKernelMap); + + KernelNameMap kernels_; +}; + +/** + * Note: + * Used to register custom kernels to KernelFactory. + */ +void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map); + +} // namespace phi diff --git a/paddle/phi/core/dense_tensor.h b/paddle/phi/core/dense_tensor.h index 622cedf1d7f..0dddd63099b 100644 --- a/paddle/phi/core/dense_tensor.h +++ b/paddle/phi/core/dense_tensor.h @@ -171,6 +171,9 @@ class DenseTensor : public TensorBase, DenseTensorMeta meta_; std::shared_ptr holder_; +#ifndef PADDLE_WITH_CUSTOM_KERNEL #include "paddle/phi/core/dense_tensor.inl" +#endif }; + } // namespace phi diff --git a/paddle/phi/core/kernel_context.h b/paddle/phi/core/kernel_context.h index 0b960004fcb..57e2db60c24 100644 --- a/paddle/phi/core/kernel_context.h +++ b/paddle/phi/core/kernel_context.h @@ -22,6 +22,7 @@ #include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/utils/any.h" +#include "paddle/utils/optional.h" #include "paddle/utils/small_vector.h" namespace phi { diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h index 577e9e28cf3..a93c9a28260 100644 --- a/paddle/phi/core/kernel_registry.h +++ b/paddle/phi/core/kernel_registry.h @@ -21,6 +21,7 @@ #include #include +#include "paddle/phi/core/custom_kernel.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_utils.h" #include "paddle/phi/core/macros.h" @@ -62,6 +63,9 @@ struct KernelArgsParseFunctor { #elif defined(PADDLE_WITH_XPU) || arg_type == std::type_index(typeid(const XPUContext&))) { +#elif defined(PADDLE_WITH_CUSTOM_DEVICE) + || + arg_type == std::type_index(typeid(const CustomContext&))) { #else ) { #endif @@ -83,11 +87,13 @@ struct KernelArgsParseFunctor { default_tensor_layout, default_key.dtype(), arg_type); +#ifndef PADDLE_WITH_CUSTOM_KERNEL } else if (arg_type == std::type_index(typeid(const SelectedRows&))) { args_def->AppendInput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); +#endif } else if (arg_type == std::type_index(typeid(DenseTensor*))) { args_def->AppendOutput(default_key.backend(), default_tensor_layout, @@ -99,11 +105,13 @@ struct KernelArgsParseFunctor { default_tensor_layout, default_key.dtype(), arg_type); +#ifndef PADDLE_WITH_CUSTOM_KERNEL } else if (arg_type == std::type_index(typeid(SelectedRows*))) { args_def->AppendOutput(default_key.backend(), default_tensor_layout, default_key.dtype(), arg_type); +#endif } else { // Attribute deal with // TODO(chenweihang): now here allow any types of attribute, maybe @@ -121,20 +129,28 @@ struct KernelArgsParseFunctor { } }; +// NOTE: used for making a difference between kernels compiled with phi or not. +enum class RegType : uint8_t { + BUILTIN = 0, // compiled with phi + PLUGIN, // separate compiled and registered +}; + // TODO(chenweihang): Polish the kernel selection logic, support the selection // of ALL_DTYPE kernel, and simplify the constructor struct KernelRegistrar { public: - KernelRegistrar(const char* kernel_name_cstr, - Backend backend, + KernelRegistrar(RegType reg_type, + const char* kernel_name_cstr, + const char* backend_cstr, DataLayout layout, DataType dtype, KernelArgsParseFn args_parse_fn, KernelArgsDefFn args_def_fn, KernelFn kernel_fn, void* variadic_kernel_fn) { - ConstructKernel(kernel_name_cstr, - backend, + ConstructKernel(reg_type, + kernel_name_cstr, + backend_cstr, layout, dtype, args_parse_fn, @@ -143,8 +159,9 @@ struct KernelRegistrar { variadic_kernel_fn); } - KernelRegistrar(const char* kernel_name_cstr, - Backend backend, + KernelRegistrar(RegType reg_type, + const char* kernel_name_cstr, + const char* backend_cstr, DataLayout layout, KernelArgsParseFn args_parse_fn, KernelArgsDefFn args_def_fn, @@ -160,8 +177,9 @@ struct KernelRegistrar { dtype == static_cast(DataType::UINT16)) { continue; } - ConstructKernel(kernel_name_cstr, - backend, + ConstructKernel(reg_type, + kernel_name_cstr, + backend_cstr, layout, static_cast(dtype), args_parse_fn, @@ -172,8 +190,9 @@ struct KernelRegistrar { } private: - void ConstructKernel(const char* kernel_name_cstr, - Backend backend, + void ConstructKernel(RegType reg_type, + const char* kernel_name_cstr, + const char* backend_cstr, DataLayout layout, DataType dtype, KernelArgsParseFn args_parse_fn, @@ -181,11 +200,16 @@ struct KernelRegistrar { KernelFn kernel_fn, void* variadic_kernel_fn) { std::string kernel_name(kernel_name_cstr); - KernelKey kernel_key(backend, layout, dtype); + KernelKey kernel_key( + paddle::experimental::StringToBackend(backend_cstr), layout, dtype); Kernel kernel(kernel_fn, variadic_kernel_fn); args_parse_fn(kernel_key, kernel.mutable_args_def()); args_def_fn(kernel_key, &kernel); - KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; + if (reg_type == RegType::BUILTIN) { + KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; + } else { + CustomKernelMap::Instance().Kernels()[kernel_name][kernel_key] = kernel; + } } }; @@ -220,21 +244,38 @@ struct KernelRegistrar { * Note: `2TA` means `2 template argument` */ #define PT_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ - "PT_REGISTER_KERNEL must be called in global namespace."); \ - PT_EXPAND(_PT_REGISTER_2TA_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, __VA_ARGS__)) + _PT_REGISTER_KERNEL(::phi::RegType::BUILTIN, \ + kernel_name, \ + backend, \ + ::phi::backend##Context, \ + layout, \ + meta_kernel_fn, \ + __VA_ARGS__) + +#define _PT_REGISTER_KERNEL( \ + reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ + "PT_REGISTER_KERNEL must be called in global namespace."); \ + PT_EXPAND(_PT_REGISTER_2TA_KERNEL(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + meta_kernel_fn, \ + __VA_ARGS__)) #ifndef _WIN32 #define _PT_REGISTER_2TA_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, ...) \ - PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, __VA_ARGS__); \ + reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...) \ + PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, context, __VA_ARGS__); \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \ PT_KERNEL_REGISTRAR_INIT( \ + reg_type, \ kernel_name, \ backend, \ + context, \ layout, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ meta_kernel_fn, \ @@ -255,12 +296,14 @@ struct KernelRegistrar { * And msvc can work without template instantiation */ #define _PT_REGISTER_2TA_KERNEL( \ - kernel_name, backend, layout, meta_kernel_fn, ...) \ + reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...) \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \ PT_EXPAND(PT_KERNEL_REGISTRAR_INIT( \ + reg_type, \ kernel_name, \ backend, \ + context, \ layout, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ meta_kernel_fn, \ @@ -269,82 +312,119 @@ struct KernelRegistrar { const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel) #endif -#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, ...) \ - _PT_KERNEL_INSTANTIATION( \ - PT_NARGS(__VA_ARGS__), meta_kernel_fn, backend, __VA_ARGS__) - -#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, ...) \ - PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ - (meta_kernel_fn, backend, __VA_ARGS__) - -#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn -#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, __VA_ARGS__)) -#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__)) - -#define PT_KERNEL_REGISTRAR_INIT( \ - kernel_name, backend, layout, args_def_fn, meta_kernel_fn, ...) \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(__VA_ARGS__), \ - kernel_name, \ - backend, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, context, ...) \ + _PT_KERNEL_INSTANTIATION( \ + PT_NARGS(__VA_ARGS__), meta_kernel_fn, backend, context, __VA_ARGS__) + +#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, context, ...) \ + PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ + (meta_kernel_fn, backend, context, __VA_ARGS__) + +#define _PT_KERNEL_INSTANTIATION_1( \ + meta_kernel_fn, backend, context, cpp_dtype) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn +#define _PT_KERNEL_INSTANTIATION_2( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_1( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_3( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_2( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_4( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_3( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_5( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_4( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_6( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_5( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_7( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_6( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_8( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_7( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_9( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_8( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_10( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_9( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_11( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_10( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_12( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_11( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_13( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_12( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_14( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_13( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PT_KERNEL_INSTANTIATION_15( \ + meta_kernel_fn, backend, context, cpp_dtype, ...) \ + template decltype( \ + meta_kernel_fn) meta_kernel_fn; \ + PT_EXPAND(_PT_KERNEL_INSTANTIATION_14( \ + meta_kernel_fn, backend, context, __VA_ARGS__)) + +#define PT_KERNEL_REGISTRAR_INIT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + ...) \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(__VA_ARGS__), \ + reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) // clang-format off @@ -352,15 +432,19 @@ struct KernelRegistrar { /* The =pre-commit always treats this macro into the wrong format, and multi-line macros cannot be skipped with NOLINT.*/ #define _PT_KERNEL_REGISTRAR_INIT(N, \ + reg_type, \ kernel_name, \ backend, \ + context, \ layout, \ args_def_fn, \ meta_kernel_fn, \ ...) \ PT_EXPAND(PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \ + reg_type, \ kernel_name, \ backend, \ + context, \ layout, \ PT_ID, \ args_def_fn, \ @@ -369,413 +453,492 @@ struct KernelRegistrar { // clang-format on -#define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ +#define _PT_KERNEL_REGISTRAR_INIT_1(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; } -#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_2(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_3(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_4(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_5(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_6(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_7(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_8(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_9(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_10(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_11(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_12(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_13(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_14(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) -#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \ - backend, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PT_CONCATENATE( \ - __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - #kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor)>::Parse, \ - args_def_fn, \ - PT_KERNEL(meta_kernel_fn), \ - PT_VARIADIC_KERNEL(meta_kernel_fn)); \ - PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ - backend, \ - layout, \ - PT_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PT_KERNEL_REGISTRAR_INIT_15(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + cpp_dtype, \ + ...) \ + static const ::phi::KernelRegistrar PT_CONCATENATE( \ + __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ + reg_type, \ + #kernel_name, \ + #backend, \ + DATALAYOUT(layout), \ + ::paddle::experimental::CppTypeToDataType::Type(), \ + ::phi::KernelArgsParseFunctor)>::Parse, \ + args_def_fn, \ + PT_KERNEL(meta_kernel_fn), \ + PT_VARIADIC_KERNEL(meta_kernel_fn)); \ + PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PT_ID, \ + args_def_fn, \ + meta_kernel_fn, \ __VA_ARGS__)) - /** PT_REGISTER_GENERAL_KERNEL * * Basic Kernel register marco, used to register a instantiated kernel function * with one template argument. */ -#define PT_REGISTER_GENERAL_KERNEL( \ - kernel_name, backend, layout, kernel_fn, dtype) \ +#define PT_REGISTER_GENERAL_KERNEL( \ + kernel_name, backend, layout, kernel_fn, dtype) \ + _PT_REGISTER_GENERAL_KERNEL( \ + ::phi::RegType::BUILTIN, kernel_name, backend, layout, kernel_fn, dtype) + +#define _PT_REGISTER_GENERAL_KERNEL( \ + reg_type, kernel_name, backend, layout, kernel_fn, dtype) \ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \ "PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace."); \ - _PT_REGISTER_GENERAL_KERNEL(kernel_name, backend, layout, kernel_fn, dtype) + __PT_REGISTER_GENERAL_KERNEL( \ + reg_type, kernel_name, backend, layout, kernel_fn, dtype) #ifndef _WIN32 -#define _PT_REGISTER_GENERAL_KERNEL( \ - kernel_name, backend, layout, kernel_fn, dtype) \ +#define __PT_REGISTER_GENERAL_KERNEL( \ + reg_type, kernel_name, backend, layout, kernel_fn, dtype) \ template decltype(kernel_fn) kernel_fn; \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \ static const ::phi::KernelRegistrar \ __reg_pt_kernel_##kernel_name##_##backend##_##layout( \ + reg_type, \ #kernel_name, \ - BACKEND(backend), \ + #backend, \ DATALAYOUT(layout), \ ::phi::KernelArgsParseFunctor::Parse, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ @@ -787,14 +950,15 @@ struct KernelRegistrar { void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel) #else -#define _PT_REGISTER_GENERAL_KERNEL( \ - kernel_name, backend, layout, kernel_fn, dtype) \ +#define __PT_REGISTER_GENERAL_KERNEL( \ + reg_type, kernel_name, backend, layout, kernel_fn, dtype) \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \ static const ::phi::KernelRegistrar \ __reg_pt_kernel_##kernel_name##_##backend##_##layout( \ + reg_type, \ #kernel_name, \ - BACKEND(backend), \ + #backend, \ DATALAYOUT(layout), \ ::phi::KernelArgsParseFunctor::Parse, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ @@ -821,4 +985,33 @@ struct KernelRegistrar { __declare_kernel_symbol_for_##kernel_name##_##backend##_##layout = \ TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() +/** PD_REGISTER_KERNEL + * + * Used to register kernels for built-in backends. + * Support CPU GPU XPU. + */ +#define PD_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \ + _PT_REGISTER_KERNEL(::phi::RegType::PLUGIN, \ + kernel_name, \ + backend, \ + ::phi::backend##Context, \ + layout, \ + meta_kernel_fn, \ + __VA_ARGS__) + +/** PD_REGISTER_CUSTOM_KERNEL + * + * Used to register kernels for plug-in backends. + * Support user-defined backend such as 'Ascend910'. + */ +#define PD_REGISTER_CUSTOM_KERNEL( \ + kernel_name, backend, layout, meta_kernel_fn, ...) \ + _PT_REGISTER_KERNEL(::phi::RegType::PLUGIN, \ + kernel_name, \ + backend, \ + ::phi::CustomContext, \ + layout, \ + meta_kernel_fn, \ + __VA_ARGS__) + } // namespace phi diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index 8c7d096eab0..862f61b2040 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/backends/custom/custom_context.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/common/scalar.h" @@ -22,7 +23,9 @@ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_context.h" +#ifndef PADDLE_WITH_CUSTOM_KERNEL #include "paddle/phi/core/selected_rows.h" +#endif #include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/type_defs.h" @@ -210,13 +213,18 @@ struct KernelImpl { #ifdef PADDLE_WITH_XPU PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(XPUContext); #endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(CustomContext); +#endif /* Input Helpers */ PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); +#ifndef PADDLE_WITH_CUSTOM_KERNEL PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows); +#endif PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor); PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor); @@ -250,7 +258,9 @@ struct KernelImpl { PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor); +#ifndef PADDLE_WITH_CUSTOM_KERNEL PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRows); +#endif PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCooTensor); PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCooTensor); diff --git a/paddle/phi/core/lod_utils.h b/paddle/phi/core/lod_utils.h index 2b0be4d9342..a5f73b66fb9 100644 --- a/paddle/phi/core/lod_utils.h +++ b/paddle/phi/core/lod_utils.h @@ -15,10 +15,16 @@ #pragma once // See Note [ Why still include the fluid headers? ] +#ifndef PADDLE_WITH_CUSTOM_KERNEL #include "paddle/fluid/framework/mixed_vector.h" +#endif namespace phi { +#ifndef PADDLE_WITH_CUSTOM_KERNEL using LoD = std::vector>; +#else +using LoD = std::vector>; +#endif void AppendLoD(LoD* lod, const LoD& lod_length); diff --git a/paddle/phi/core/tensor_meta.h b/paddle/phi/core/tensor_meta.h index d5e5e2aa001..ede9b43b1f3 100644 --- a/paddle/phi/core/tensor_meta.h +++ b/paddle/phi/core/tensor_meta.h @@ -24,12 +24,18 @@ limitations under the License. */ // Note: mixed_vector include many header now, LoD will be // used on CUDA device? Can we use small_vector here? // @zhanlve: Rollback to original LoD for now +#ifndef PADDLE_WITH_CUSTOM_KERNEL #include "paddle/fluid/framework/mixed_vector.h" +#endif namespace phi { using DDim = phi::DDim; +#ifndef PADDLE_WITH_CUSTOM_KERNEL using LoD = std::vector>; +#else +using LoD = std::vector>; +#endif /// \brief The meta data of dense tensor. Take the structure type /// and use all default operations. /// diff --git a/paddle/phi/core/tensor_utils.h b/paddle/phi/core/tensor_utils.h index 04db7c0877a..676a590ecbc 100644 --- a/paddle/phi/core/tensor_utils.h +++ b/paddle/phi/core/tensor_utils.h @@ -31,25 +31,25 @@ class DenseTensorUtils { size_t bytes = tensor.numel() * SizeOf(tensor.dtype()); PADDLE_ENFORCE_GE(tensor.capacity(), bytes, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The memory size %d should be enough to meet the " "volume required by metadata %d.", tensor.capacity(), bytes)); - PADDLE_ENFORCE_GE(begin_idx, - 0, - paddle::platform::errors::OutOfRange( - "The start row index must be greater than 0." - "But received the start index is d%.", - begin_idx)); - PADDLE_ENFORCE_LE(end_idx, - tensor.dims()[0], - paddle::platform::errors::OutOfRange( - "The end row index is out of bound.")); + PADDLE_ENFORCE_GE( + begin_idx, + 0, + phi::errors::OutOfRange("The start row index must be greater than 0." + "But received the start index is d%.", + begin_idx)); + PADDLE_ENFORCE_LE( + end_idx, + tensor.dims()[0], + phi::errors::OutOfRange("The end row index is out of bound.")); PADDLE_ENFORCE_LT( begin_idx, end_idx, - paddle::platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "The start row index must be less than the end row index." "But received the start index = %d, the end index = %d.", begin_idx, diff --git a/paddle/phi/tests/common/test_backend.cc b/paddle/phi/tests/common/test_backend.cc index 941c00d9fea..d74a35c9eae 100644 --- a/paddle/phi/tests/common/test_backend.cc +++ b/paddle/phi/tests/common/test_backend.cc @@ -52,5 +52,19 @@ TEST(Backend, OStream) { } } +TEST(Backend, StringToBackend) { + namespace pexp = paddle::experimental; + EXPECT_EQ(phi::Backend::UNDEFINED, pexp::StringToBackend("Undefined")); + EXPECT_EQ(phi::Backend::CPU, pexp::StringToBackend("CPU")); + EXPECT_EQ(phi::Backend::GPU, pexp::StringToBackend("GPU")); + EXPECT_EQ(phi::Backend::XPU, pexp::StringToBackend("XPU")); + EXPECT_EQ(phi::Backend::NPU, pexp::StringToBackend("NPU")); + EXPECT_EQ(phi::Backend::MKLDNN, pexp::StringToBackend("MKLDNN")); + EXPECT_EQ(phi::Backend::CUDNN, pexp::StringToBackend("CUDNN")); + EXPECT_EQ(static_cast( + static_cast(phi::Backend::NUM_BACKENDS) + 1), + pexp::StringToBackend("CustomBackend")); +} + } // namespace tests } // namespace phi diff --git a/paddle/phi/tests/core/CMakeLists.txt b/paddle/phi/tests/core/CMakeLists.txt index 971d9112eea..576ab7ffe6a 100644 --- a/paddle/phi/tests/core/CMakeLists.txt +++ b/paddle/phi/tests/core/CMakeLists.txt @@ -1,3 +1,4 @@ +cc_test(test_custom_kernel SRCS test_custom_kernel.cc DEPS pten_custom_kernel) cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor) cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) cc_test(test_type_info SRCS test_type_info.cc) diff --git a/paddle/fluid/framework/custom_kernel_test.cc b/paddle/phi/tests/core/test_custom_kernel.cc similarity index 70% rename from paddle/fluid/framework/custom_kernel_test.cc rename to paddle/phi/tests/core/test_custom_kernel.cc index fb3cc0a35f0..b0957d80aa9 100644 --- a/paddle/fluid/framework/custom_kernel_test.cc +++ b/paddle/phi/tests/core/test_custom_kernel.cc @@ -17,24 +17,21 @@ limitations under the License. */ #define _LINUX #endif -#include "paddle/fluid/framework/custom_kernel.h" - -#include -#include -#include "paddle/extension.h" +#ifdef _LINUX #include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_kernel_info_helper.h" -#include "paddle/fluid/memory/allocation/allocator_facade.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/storage.h" -#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/common/scalar_array.h" #include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_factory.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/meta_tensor.h" #include "paddle/phi/infermeta/binary.h" -#include "paddle/utils/small_vector.h" -#ifdef _LINUX +#include + // user kernel function namespace custom_kernel { @@ -43,17 +40,23 @@ namespace custom_kernel { // attribute 11: fake_attributes // output 2: one Tensor* and one std::vector template -void FakeDot(const Context& dev_ctx, const paddle::Tensor& x, - const paddle::Tensor& y, - const std::vector& fake_input_vec, - bool fake_attr_bool, int fake_attr_int, float fake_attr_float, - double fake_attr_double, int64_t fake_attr_int64, - phi::dtype::float16 fake_attr_f16, phi::DataType fake_attr_dtype, +void FakeDot(const Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + const std::vector& fake_input_vec, + bool fake_attr_bool, + int fake_attr_int, + float fake_attr_float, + double fake_attr_double, + int64_t fake_attr_int64, + phi::dtype::float16 fake_attr_f16, + phi::DataType fake_attr_dtype, const phi::Scalar& fake_attr_scalar, const phi::ScalarArray& fake_attr_scalar_array, const std::vector& fake_attr_int64_vec, - const std::vector& fake_attr_int_vec, paddle::Tensor* out, - std::vector fake_out_vec) { + const std::vector& fake_attr_int_vec, + phi::DenseTensor* out, + std::vector fake_out_vec) { // print param info std::cout << "fake_input_vec.size: " << fake_input_vec.size() << std::endl; std::cout << "fake_attr_bool: " << fake_attr_bool << std::endl; @@ -83,10 +86,10 @@ void FakeDot(const Context& dev_ctx, const paddle::Tensor& x, auto const *x_ptr = x.data(), *x_ptr_ = &x_ptr[0]; auto const *y_ptr = y.data(), *y_ptr_ = &y_ptr[0]; - auto* z = out->mutable_data(paddle::PlaceType::kCPU); - auto shape = x.shape(); + T* z = dev_ctx.template Alloc(out); + auto&& d = x.dims(); auto const N = x.numel(); - auto const B = shape[shape.size() - 1]; + auto const B = d[d.size() - 1]; for (int j = 0; j < N / B; j++) { T ss = 0; for (int i = 0; i < B; i++) ss += (*x_ptr_++) * (*y_ptr_++); @@ -95,8 +98,19 @@ void FakeDot(const Context& dev_ctx, const paddle::Tensor& x, } } // namespace custom_kernel -PD_REGISTER_KERNEL(fake_dot, CPU, ALL_LAYOUT, custom_kernel::FakeDot, float, - double, int, int64_t, int8_t, uint8_t) {} +PD_REGISTER_KERNEL(fake_dot, + CPU, + ALL_LAYOUT, + custom_kernel::FakeDot, + float, + double, + int, + int64_t, + int8_t, + uint8_t) {} + +namespace phi { +namespace tests { // Upper code will store dot kernels info into OpKernelInfoMap TEST(CustomKernel, custom_kernel_dot) { @@ -105,33 +119,38 @@ TEST(CustomKernel, custom_kernel_dot) { phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT; // 1.custom kernel info parsed and store - EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find(op_name) != - paddle::OpKernelInfoMap::Instance().GetMap().end()); + EXPECT_TRUE(phi::CustomKernelMap::Instance().GetMap().find(op_name) != + phi::CustomKernelMap::Instance().GetMap().end()); + auto& custom_kernels = phi::CustomKernelMap::Instance().Kernels(); // 2.info check - EXPECT_EQ( - 6, static_cast(paddle::OpKernelInfoMap::Instance()[op_name].size())); - // index 0 - EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetBackend() == - backend); - EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataLayout() == - layout); - EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataType() == - phi::DataType::FLOAT32); - // index 5 - EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetBackend() == - backend); - EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataLayout() == - layout); - EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataType() == - phi::DataType::UINT8); + EXPECT_EQ(6, static_cast(custom_kernels[op_name].size())); + auto& custom_fake_dot_kernels = custom_kernels[op_name]; + EXPECT_TRUE(custom_fake_dot_kernels.find( + phi::KernelKey(backend, layout, phi::DataType::FLOAT32)) != + custom_fake_dot_kernels.end()); + EXPECT_TRUE(custom_fake_dot_kernels.find( + phi::KernelKey(backend, layout, phi::DataType::FLOAT64)) != + custom_fake_dot_kernels.end()); + EXPECT_TRUE(custom_fake_dot_kernels.find( + phi::KernelKey(backend, layout, phi::DataType::INT32)) != + custom_fake_dot_kernels.end()); + EXPECT_TRUE(custom_fake_dot_kernels.find( + phi::KernelKey(backend, layout, phi::DataType::INT64)) != + custom_fake_dot_kernels.end()); + EXPECT_TRUE(custom_fake_dot_kernels.find( + phi::KernelKey(backend, layout, phi::DataType::INT8)) != + custom_fake_dot_kernels.end()); + EXPECT_TRUE(custom_fake_dot_kernels.find( + phi::KernelKey(backend, layout, phi::DataType::UINT8)) != + custom_fake_dot_kernels.end()); // 3.before register auto& kernel_factory_instance = phi::KernelFactory::Instance(); auto& kernels = phi::KernelFactory::Instance().kernels(); EXPECT_TRUE(!kernel_factory_instance.HasCompatiblePtenKernel(op_name)); - // mock fake_dot is supported by pten for HasCompatiblePtenKernel check while + // mock fake_dot is supported by phi for HasCompatiblePtenKernel check while // registering auto& fake_dot_kernels = kernels[op_name]; @@ -155,8 +174,7 @@ TEST(CustomKernel, custom_kernel_dot) { fake_dot_kernels.end()); // register - paddle::framework::RegisterKernelWithMetaInfoMap( - paddle::OpKernelInfoMap::Instance()); + phi::RegisterCustomKernels(phi::CustomKernelMap::Instance()); EXPECT_TRUE(fake_dot_kernels.find( phi::KernelKey(backend, layout, phi::DataType::FLOAT32)) != @@ -186,15 +204,15 @@ TEST(CustomKernel, custom_kernel_dot) { paddle::platform::CPUPlace()); auto dense_x = std::make_shared( alloc.get(), - phi::DenseTensorMeta(phi::DataType::UINT8, phi::make_ddim({2, 3}), - phi::DataLayout::NCHW)); + phi::DenseTensorMeta( + phi::DataType::UINT8, phi::make_ddim({2, 3}), phi::DataLayout::NCHW)); auto* dense_x_data = dense_x->mutable_data(paddle::platform::CPUPlace()); auto dense_y = std::make_shared( alloc.get(), - phi::DenseTensorMeta(phi::DataType::UINT8, phi::make_ddim({2, 3}), - phi::DataLayout::NCHW)); + phi::DenseTensorMeta( + phi::DataType::UINT8, phi::make_ddim({2, 3}), phi::DataLayout::NCHW)); auto* dense_y_data = dense_y->mutable_data(paddle::platform::CPUPlace()); @@ -288,38 +306,7 @@ TEST(CustomKernel, custom_kernel_dot) { ASSERT_EQ(expect_result[1], actual_result1); } -// test OpKernelInfoHelper -TEST(OpKernelInfoHelper, op_kernel_info_help_getters) { - using OpKernelInfoHelper = paddle::framework::OpKernelInfoHelper; - std::string op_name = "fake_dot"; - phi::Backend backend = phi::Backend::CPU; - phi::DataLayout layout = phi::DataLayout::ANY; - phi::DataType dtype = phi::DataType::FLOAT32; - - auto op_kernel_info = paddle::OpKernelInfoMap::Instance()[op_name][0]; - - EXPECT_EQ(op_name, OpKernelInfoHelper::GetOpName(op_kernel_info)); - EXPECT_EQ(backend, OpKernelInfoHelper::GetBackend(op_kernel_info)); - EXPECT_EQ(layout, OpKernelInfoHelper::GetDataLayout(op_kernel_info)); - EXPECT_EQ(dtype, OpKernelInfoHelper::GetDataType(op_kernel_info)); - - EXPECT_EQ(phi::KernelKey(backend, layout, dtype), - OpKernelInfoHelper::GetKernelKey(op_kernel_info)); - - paddle::CustomKernelFunc kernel_fn = - PD_PT_KERNEL(custom_kernel::FakeDot); - EXPECT_EQ(kernel_fn, OpKernelInfoHelper::GetKernelFn(op_kernel_info)); - - void* variadic_func = - PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot); - EXPECT_EQ(variadic_func, - OpKernelInfoHelper::GetVariadicKernelFn(op_kernel_info)); - - auto& input_defs = OpKernelInfoHelper::GetInputDefs(op_kernel_info); - auto& output_defs = OpKernelInfoHelper::GetOutputDefs(op_kernel_info); - auto& attribute_defs = OpKernelInfoHelper::GetAttributeDefs(op_kernel_info); - EXPECT_EQ(3, static_cast(input_defs.size())); - EXPECT_EQ(2, static_cast(output_defs.size())); - EXPECT_EQ(11, static_cast(attribute_defs.size())); -} +} // namespace tests +} // namespace phi + #endif diff --git a/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc b/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc index 3ae30c2f305..68393cba57e 100644 --- a/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc +++ b/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/extension.h" +#include "paddle/phi/core/kernel_registry.h" namespace paddle { @@ -21,19 +21,19 @@ namespace custom_kernel { // Here we use dot for test // This test will fail when this kernel is supported in framework template -void Dot(const Context& dev_ctx, - const paddle::Tensor& x, - const paddle::Tensor& y, - paddle::Tensor* out) { +void DotKernel(const Context& dev_ctx, + const phi::DenseTensor& x, + const phi::DenseTensor& y, + phi::DenseTensor* out) { auto const *x_ptr = x.data(), *x_ptr_ = &x_ptr[0]; auto const *y_ptr = y.data(), *y_ptr_ = &y_ptr[0]; - auto* z = out->mutable_data(paddle::PlaceType::kCPU); + T* z = dev_ctx.template Alloc(out); // Loop over the total N elements of both operands while sum-reducing every // B pairs along the way where B is the dimension of the least ordered axis - auto shape = x.shape(); + auto&& d = x.dims(); auto const N = x.numel(); - auto const B = shape[shape.size() - 1]; + auto const B = d[d.size() - 1]; for (int j = 0; j < N / B; j++) { T ss = 0; @@ -45,6 +45,7 @@ void Dot(const Context& dev_ctx, } // namespace custom_kernel } // namespace paddle -PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, paddle::custom_kernel::Dot, int8_t) { +PD_REGISTER_KERNEL( + dot, CPU, ALL_LAYOUT, paddle::custom_kernel::DotKernel, int8_t) { kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT8); } diff --git a/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot_setup.py b/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot_setup.py index 5e3bd2f8ed9..3cef228d14d 100644 --- a/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot_setup.py +++ b/python/paddle/fluid/tests/custom_kernel/custom_kernel_dot_setup.py @@ -1,11 +1,11 @@ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -16,9 +16,28 @@ import os from paddle.fluid import core from distutils.sysconfig import get_python_lib from distutils.core import setup, Extension +from setuptools.command.build_ext import build_ext + + +# refer: https://note.qidong.name/2018/03/setup-warning-strict-prototypes +# Avoid a gcc warning below: +# cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid +# for C/ObjC but not for C++ +class BuildExt(build_ext): + def build_extensions(self): + if '-Wstrict-prototypes' in self.compiler.compiler_so: + self.compiler.compiler_so.remove('-Wstrict-prototypes') + super(BuildExt, self).build_extensions() + # cc flags -paddle_extra_compile_args = ['-std=c++14', '-shared', '-fPIC'] +paddle_extra_compile_args = [ + '-std=c++14', + '-shared', + '-fPIC', + '-Wno-parentheses', + '-DPADDLE_WITH_CUSTOM_KERNEL', +] if core.is_compiled_with_npu(): paddle_extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI=0'] @@ -27,6 +46,14 @@ site_packages_path = get_python_lib() paddle_custom_kernel_include = [ os.path.join(site_packages_path, 'paddle', 'include'), ] +# include path third_party +compile_third_party_path = os.path.join(os.environ['PADDLE_ROOT'], + 'build/third_party') +paddle_custom_kernel_include += [ + os.path.join(compile_third_party_path, 'boost/src/extern_boost'), # boost + os.path.join(compile_third_party_path, 'install/gflags/include'), # gflags + os.path.join(compile_third_party_path, 'install/glog/include'), # glog +] # libs path paddle_custom_kernel_library_dir = [ @@ -50,4 +77,5 @@ setup( name='custom_kernel_dot', version='1.0', description='custom kernel fot compiling', + cmdclass={'build_ext': BuildExt}, ext_modules=[custom_kernel_dot_module]) diff --git a/python/setup.py.in b/python/setup.py.in index 7b3909d40a0..f39429387db 100755 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -577,9 +577,9 @@ headers = ( list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/common')) + # pten common headers # pten level api headers (low level api) list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/core', recursive=True)) + # pten core headers + list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/backends', recursive=True)) + # pten backends headers # utila api headers - ['@PADDLE_SOURCE_DIR@/paddle/utils/any.h'] + - ['@PADDLE_SOURCE_DIR@/paddle/utils/small_vector.h'] + + list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/utils', recursive=True)) + # paddle utils headers ['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/device/device_ext.h']) if '${WITH_MKLDNN}' == 'ON': -- GitLab