未验证 提交 68631ed4 编写于 作者: A Aganlengzi 提交者: GitHub

[PluggableDevice]custom kernel to phi core structs (#39690)

* [PluggableDevice]custom kernel to pten core structs

* mod extension.h for custom op

* compatible python for CI

* support custom context

* refactor to pten

* fix windows and ut
上级 9c51eee1
...@@ -437,8 +437,7 @@ message(STATUS "branch: ${PADDLE_BRANCH}") ...@@ -437,8 +437,7 @@ message(STATUS "branch: ${PADDLE_BRANCH}")
configure_file(commit.h.in commit.h) configure_file(commit.h.in commit.h)
cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper pten_tensor op_meta_info pten_api) cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framework_proto op_registry operator dynamic_loader string_helper pten_tensor op_meta_info pten_api)
cc_library(custom_kernel SRCS custom_kernel.cc DEPS cc_library(custom_kernel SRCS custom_kernel.cc DEPS op_registry pten_custom_kernel pten_tensor_raw)
tensor attribute framework_proto op_registry operator dynamic_loader string_helper pten_tensor op_kernel_info pten_api)
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} ) #cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
#cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler) #cc_binary(new_executor SRCS new_exec_test.cc DEPS operator op_registry executor ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} profiler)
...@@ -459,4 +458,3 @@ else() ...@@ -459,4 +458,3 @@ else()
cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place) cc_library(fluid_convert_utils SRCS convert_utils.cc DEPS data_type place)
endif() endif()
cc_test(convert_utils_test SRCS convert_utils_test.cc DEPS fluid_convert_utils) cc_test(convert_utils_test SRCS convert_utils_test.cc DEPS fluid_convert_utils)
cc_test(custom_kernel_test SRCS custom_kernel_test.cc DEPS custom_kernel pten_tensor)
...@@ -18,355 +18,24 @@ limitations under the License. */ ...@@ -18,355 +18,24 @@ limitations under the License. */
#endif #endif
#include "paddle/fluid/framework/custom_kernel.h" #include "paddle/fluid/framework/custom_kernel.h"
#include <dirent.h> #include "paddle/phi/core/custom_kernel.h"
#include <algorithm>
#include <regex>
#include "paddle/fluid/framework/op_kernel_info_helper.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/api/ext/op_kernel_info.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// set phi::Kernel args_def_ from op_kernel_info
// because we can not set directly to phi::Kernel without exposing
// phi::KernelArgsDef when parsing custom user function
static void ParseArgs(const OpKernelInfo& op_kernel_info,
phi::KernelArgsDef* args_def) {
auto& input_defs = OpKernelInfoHelper::GetInputDefs(op_kernel_info);
auto& output_defs = OpKernelInfoHelper::GetOutputDefs(op_kernel_info);
auto& attribute_defs = OpKernelInfoHelper::GetAttributeDefs(op_kernel_info);
for (auto& input : input_defs) {
auto type_index =
input.is_vector
? std::type_index(typeid(const std::vector<phi::DenseTensor>&))
: std::type_index(typeid(const phi::DenseTensor&));
args_def->AppendInput(input.backend, input.layout, input.dtype, type_index);
}
for (auto& output : output_defs) {
auto type_index =
output.is_vector
? std::type_index(typeid(const std::vector<phi::DenseTensor>&))
: std::type_index(typeid(const phi::DenseTensor&));
args_def->AppendOutput(output.backend, output.layout, output.dtype,
type_index);
}
for (auto& attr : attribute_defs) {
args_def->AppendAttribute(attr.type_index);
}
}
// custom pten kernel call function define
static void RunKernelFunc(phi::KernelContext* ctx,
const OpKernelInfo& op_kernel_info) {
VLOG(3) << "[CUSTOM KERNEL] RunKernelFunc begin...";
// input and output size is not params' num
// but actual Tensors' size
size_t input_size = ctx->InputsSize();
size_t output_size = ctx->OutputsSize();
size_t attr_size = ctx->AttrsSize();
// parameters' num of unified user kernel function
auto& input_defs = OpKernelInfoHelper::GetInputDefs(op_kernel_info);
auto& output_defs = OpKernelInfoHelper::GetOutputDefs(op_kernel_info);
auto& attribute_defs = OpKernelInfoHelper::GetAttributeDefs(op_kernel_info);
PADDLE_ENFORCE_GE(input_size, input_defs.size(),
platform::errors::InvalidArgument(
"the size of ctx inputs size (%d) must be larger than "
"the size of kernel input_defs (%d).",
input_size, input_defs.size()));
PADDLE_ENFORCE_GE(output_size, output_defs.size(),
platform::errors::InvalidArgument(
"the size of ctx outputs size (%d) must be larger than "
"the size of kernel output_defs (%d).",
output_size, output_defs.size()));
PADDLE_ENFORCE_EQ(attr_size, attribute_defs.size(),
platform::errors::InvalidArgument(
"the size of ctx attribute size (%d) must be equal to "
"to the size of kernel attribute_defs (%d).",
attr_size, attribute_defs.size()));
VLOG(3) << "[CUSTOM KERNEL] Input num: " << input_defs.size()
<< "[tensor size:" << input_size << "]"
<< " Attribute num: " << attribute_defs.size()
<< " Output num: " << output_defs.size()
<< "[tensor size:" << output_size << "].";
// Inputs mapping
std::vector<paddle::experimental::Tensor> custom_ins;
std::vector<std::vector<paddle::experimental::Tensor>> custom_vec_ins;
for (size_t in_idx = 0; in_idx < input_defs.size(); ++in_idx) {
VLOG(3) << "Mapping Input[" << in_idx << "]";
const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
// is_vector tells if this Input is Tensor or std::vector<Tensor>
if (!input_defs.at(in_idx).is_vector) {
paddle::experimental::Tensor custom_t;
auto& ctx_tensor = ctx->InputAt<phi::DenseTensor>(range.first);
custom_t.set_impl(std::make_shared<phi::DenseTensor>(ctx_tensor));
custom_ins.emplace_back(custom_t);
} else {
std::vector<paddle::experimental::Tensor> custom_vec_in;
auto ctx_tensor_vec =
ctx->MoveInputsBetween<phi::DenseTensor>(range.first, range.second);
for (auto& ctx_tensor : ctx_tensor_vec) {
paddle::experimental::Tensor custom_t;
custom_t.set_impl(std::make_shared<phi::DenseTensor>(ctx_tensor));
custom_vec_in.emplace_back(custom_t);
}
custom_vec_ins.emplace_back(custom_vec_in);
}
VLOG(3) << "Mapped Input[" << in_idx << "] with range[" << range.first
<< "," << range.second << ").";
}
// Attributes mapping
std::vector<paddle::any> custom_attrs;
for (size_t attr_idx = 0; attr_idx < attribute_defs.size(); ++attr_idx) {
VLOG(3) << "Mapping Attribute[" << attr_idx << "]";
if (attribute_defs[attr_idx].type_index == std::type_index(typeid(bool))) {
bool arg = ctx->AttrAt<bool>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(int))) {
int arg = ctx->AttrAt<int>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(float))) {
float arg = ctx->AttrAt<float>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(double))) {
double arg = ctx->AttrAt<double>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(int64_t))) {
int64_t arg = ctx->AttrAt<int64_t>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(phi::dtype::float16))) {
phi::dtype::float16 arg = ctx->AttrAt<phi::dtype::float16>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(DataType))) {
DataType arg = ctx->AttrAt<DataType>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(const Scalar&))) {
const Scalar& arg = ctx->AttrAt<const Scalar&>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(const std::vector<int64_t>&))) {
const std::vector<int64_t>& arg =
ctx->AttrAt<const std::vector<int64_t>&>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(const ScalarArray&))) {
const ScalarArray& arg = ctx->AttrAt<const ScalarArray&>(attr_idx);
custom_attrs.emplace_back(arg);
} else if (attribute_defs[attr_idx].type_index ==
std::type_index(typeid(const std::vector<int>&))) {
const std::vector<int>& arg =
ctx->AttrAt<const std::vector<int>&>(attr_idx);
custom_attrs.emplace_back(arg);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported attribute attribute_defs[%d].type_index", attr_idx));
}
VLOG(3) << "Mapped Attribute[" << attr_idx << "]";
}
// Outputs mapping
std::vector<paddle::experimental::Tensor*> custom_outs;
std::vector<std::vector<paddle::experimental::Tensor*>> custom_vec_outs;
std::vector<std::shared_ptr<phi::DenseTensor>> custom_outs_ptr;
std::vector<std::vector<std::shared_ptr<phi::DenseTensor>>>
custom_vec_outs_ptr;
for (size_t out_idx = 0; out_idx < output_defs.size(); ++out_idx) {
VLOG(3) << "Mapping Output[" << out_idx << "]";
const std::pair<int, int> range = ctx->OutputRangeAt(out_idx);
// is_vector tells if this Output is Tensor or std::vector<Tensor>
if (!output_defs.at(out_idx).is_vector) {
auto* ctx_tensor = ctx->MutableOutputAt<phi::DenseTensor>(range.first);
auto* custom_t = new paddle::experimental::Tensor();
auto custom_t_ptr = std::make_shared<phi::DenseTensor>(*ctx_tensor);
custom_t->set_impl(custom_t_ptr);
custom_outs.emplace_back(custom_t);
custom_outs_ptr.emplace_back(custom_t_ptr);
} else {
std::vector<paddle::experimental::Tensor*> custom_vec_out;
std::vector<std::shared_ptr<phi::DenseTensor>> custom_vec_out_ptr;
auto ctx_tensor_vec = ctx->MutableOutputBetween<phi::DenseTensor>(
range.first, range.second);
for (auto ctx_tensor : ctx_tensor_vec) {
auto* custom_t = new paddle::experimental::Tensor();
auto custom_t_ptr = std::make_shared<phi::DenseTensor>(*ctx_tensor);
custom_t->set_impl(custom_t_ptr);
custom_vec_out.emplace_back(custom_t);
custom_vec_out_ptr.emplace_back(custom_t_ptr);
}
custom_vec_outs.emplace_back(custom_vec_out);
custom_vec_outs_ptr.emplace_back(custom_vec_out_ptr);
}
VLOG(3) << "Mapped Output[" << out_idx << "] with range[" << range.first
<< "," << range.second << ").";
}
// DeviceContext
// In pten, the first paramter XXContext is decided when registering
// through template param, but custom kernel function use unified
// DeviceContext as first parameter of user_kernel_fn, we use backend
// from OpKernelInfo to decide XXContext. In temporary simple
// DeviceContext, we just set necessary info to dev_ctx(such as stream
// in NPUContext), more related work should be done when
// phi::DeviceContext is exposed to outer.
DeviceContext dev_ctx;
auto& backend = OpKernelInfoHelper::GetBackend(op_kernel_info);
if (backend == phi::Backend::CPU) {
// do nothing
} else {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
size_t device_type_id_ = static_cast<size_t>(backend) -
static_cast<size_t>(phi::Backend::ALL_BACKEND);
std::string device_type = phi::GetGlobalDeviceType(device_type_id_);
if (!device_type.empty()) {
auto custom_ctx =
ctx->GetDeviceContext<paddle::platform::CustomDeviceContext>();
dev_ctx.set_stream(custom_ctx.stream());
return;
}
#endif
LOG(ERROR) << "[CUSTOM KERNEL] Unsupported kernel backend: " << backend
<< " with compiled Paddle.";
return;
}
auto& user_kernel_fn = OpKernelInfoHelper::GetKernelFn(op_kernel_info);
// call user function
user_kernel_fn(dev_ctx, custom_ins, custom_vec_ins, custom_attrs,
&custom_outs, &custom_vec_outs);
VLOG(3) << "[CUSTOM KERNEL] finished call user kernel function.";
// NOTE: Map back the output tensors with stored shared_ptrs.
for (int out_idx = output_defs.size() - 1; out_idx >= 0; --out_idx) {
VLOG(3) << "Mapping Back Output[" << out_idx << "]";
const std::pair<int, int> range = ctx->OutputRangeAt(out_idx);
// is_vector tells if this Output is Tensor or std::vector<Tensor>
if (!output_defs.at(out_idx).is_vector) {
auto* ctx_tensor = ctx->MutableOutputAt<phi::DenseTensor>(range.first);
*ctx_tensor = *(custom_outs_ptr.back().get());
custom_outs_ptr.pop_back();
} else {
auto ctx_tensor_vec = ctx->MutableOutputBetween<phi::DenseTensor>(
range.first, range.second);
auto custom_vec_ptr_out = custom_vec_outs_ptr.back();
for (int idx = ctx_tensor_vec.size() - 1; idx >= 0; --idx) {
*(ctx_tensor_vec[idx]) = *(custom_vec_ptr_out.back().get());
custom_vec_ptr_out.pop_back();
}
custom_vec_outs_ptr.pop_back();
}
VLOG(3) << "Mapped Output[" << out_idx << "] with range[" << range.first
<< "," << range.second << "].";
}
// delete newed paddle::Tensor for outputs while calling user kernel function
for (size_t i = 0; i < custom_outs.size(); ++i) {
delete custom_outs[i];
}
for (size_t i = 0; i < custom_vec_outs.size(); ++i) {
for (size_t j = 0; j < custom_vec_outs[i].size(); ++j) {
delete custom_vec_outs[i][j];
}
}
}
void RegisterKernelWithMetaInfo(
const std::vector<OpKernelInfo>& op_kernel_infos) {
for (size_t i = 0; i < op_kernel_infos.size(); ++i) {
auto& kernel_info = op_kernel_infos[i];
auto op_type = OpKernelInfoHelper::GetOpName(kernel_info);
auto kernel_key = OpKernelInfoHelper::GetKernelKey(kernel_info);
VLOG(3) << "[CUSTOM KERNEL] registering [" << op_type << "]" << kernel_key;
// 1.Check whether this kernel is valid for a specific operator
PADDLE_ENFORCE_EQ(
phi::KernelFactory::Instance().HasCompatiblePtenKernel(op_type), true,
platform::errors::InvalidArgument(
"[CUSTOM KERNEL] %s is not ready for custom kernel registering.",
op_type));
// 2.Check whether kernel_key has been already registed
PADDLE_ENFORCE_EQ(
phi::KernelFactory::Instance().kernels()[op_type].find(kernel_key),
phi::KernelFactory::Instance().kernels()[op_type].end(),
platform::errors::InvalidArgument(
"[CUSTOM KERNEL] The operator <%s>'s kernel: %s has been "
"already existed in Paddle, please contribute PR if need "
"to optimize the kernel code. Custom kernel do NOT support "
"to replace existing kernel in Paddle.",
op_type, kernel_key));
// phi::KernelFn
phi::KernelFn kernel_fn = [kernel_info](phi::KernelContext* ctx) {
VLOG(3) << "[CUSTOM KERNEL] run custom PTEN kernel func in lambda.";
RunKernelFunc(ctx, kernel_info);
};
// variadic_kernel_fn
void* variadic_kernel_fn =
OpKernelInfoHelper::GetVariadicKernelFn(kernel_info);
phi::Kernel kernel(kernel_fn, variadic_kernel_fn);
// args info
ParseArgs(kernel_info, kernel.mutable_args_def());
// register custom kernel to phi::KernelFactory
phi::KernelFactory::Instance().kernels()[op_type][kernel_key] = kernel;
VLOG(3) << "[CUSTOM KERNEL] Successed in registering operator <" << op_type
<< ">'s kernel " << kernel_key << " to Paddle. "
<< "It will be used like native ones.";
}
}
void RegisterKernelWithMetaInfoMap(
const paddle::OpKernelInfoMap& op_kernel_info_map) {
auto& kernel_info_map = op_kernel_info_map.GetMap();
VLOG(3) << "[CUSTOM KERNEL] size of op_kernel_info_map: "
<< kernel_info_map.size();
// pair: {op_type, OpKernelInfo}
for (auto& pair : kernel_info_map) {
VLOG(3) << "[CUSTOM KERNEL] pair first -> op name: " << pair.first;
RegisterKernelWithMetaInfo(pair.second);
}
}
void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle) { void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle) {
#ifdef _LINUX #ifdef _LINUX
typedef OpKernelInfoMap& get_op_kernel_info_map_t(); typedef phi::CustomKernelMap& get_custom_kernel_map_t();
auto* func = reinterpret_cast<get_op_kernel_info_map_t*>( auto* func = reinterpret_cast<get_custom_kernel_map_t*>(
dlsym(dso_handle, "PD_GetOpKernelInfoMap")); dlsym(dso_handle, "PD_GetCustomKernelMap"));
if (func == nullptr) { if (func == nullptr) {
LOG(WARNING) << "Skipped lib [" << dso_lib_path << "]: fail to find " LOG(WARNING) << "Skipped lib [" << dso_lib_path << "]: fail to find "
<< "PD_GetOpKernelInfoMap symbol in this lib."; << "PD_GetCustomKernelMap symbol in this lib.";
return; return;
} }
auto& op_kernel_info_map = func(); auto& custom_kernel_map = func();
RegisterKernelWithMetaInfoMap(op_kernel_info_map); phi::RegisterCustomKernels(custom_kernel_map);
LOG(INFO) << "Successed in loading custom kernels in lib: " << dso_lib_path; LOG(INFO) << "Successed in loading custom kernels in lib: " << dso_lib_path;
#else #else
VLOG(3) << "Unsupported: Custom kernel is only implemented on Linux."; VLOG(3) << "Unsupported: Custom kernel is only implemented on Linux.";
......
...@@ -14,22 +14,13 @@ limitations under the License. */ ...@@ -14,22 +14,13 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/phi/api/ext/op_kernel_info.h" #include <string>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Load custom kernel lib and register
void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle); void LoadCustomKernelLib(const std::string& dso_lib_path, void* dso_handle);
// Load custom kernel api: register kernel after user compiled
void LoadOpKernelInfoAndRegister(const std::string& dso_name);
// Register custom kernel api: register kernel directly
void RegisterKernelWithMetaInfoMap(
const paddle::OpKernelInfoMap& op_kernel_info_map);
// Interface for selective register custom kernel.
void RegisterKernelWithMetaInfo(
const std::vector<OpKernelInfo>& op_kernel_infos);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -146,6 +146,10 @@ if(WITH_ASCEND_CL) ...@@ -146,6 +146,10 @@ if(WITH_ASCEND_CL)
target_link_libraries(device_context npu_resource_pool) target_link_libraries(device_context npu_resource_pool)
endif() endif()
if(WITH_CUSTOM_DEVICE)
target_link_libraries(device_context custom_context)
endif()
cc_test(init_test SRCS init_test.cc DEPS device_context) cc_test(init_test SRCS init_test.cc DEPS device_context)
# Manage all device event library # Manage all device event library
......
...@@ -41,7 +41,6 @@ limitations under the License. */ ...@@ -41,7 +41,6 @@ limitations under the License. */
#include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/api/ext/dll_decl.h" #include "paddle/phi/api/ext/dll_decl.h"
#include "paddle/phi/api/ext/exception.h" #include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/api/ext/op_kernel_info.h"
#include "paddle/phi/api/ext/op_meta_info.h" #include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/phi/api/ext/place.h" #include "paddle/phi/api/ext/place.h"
#include "paddle/phi/api/ext/tensor_compat.h" #include "paddle/phi/api/ext/tensor_compat.h"
...@@ -90,7 +90,6 @@ cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor_raw pten kernel_dispat ...@@ -90,7 +90,6 @@ cc_library(manual_api SRCS manual_api.cc DEPS pten_tensor_raw pten kernel_dispat
cc_library(pten_tensor SRCS tensor_method.cc DEPS pten_tensor_raw pten_function_api) cc_library(pten_tensor SRCS tensor_method.cc DEPS pten_tensor_raw pten_function_api)
cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor) cc_library(op_meta_info SRCS op_meta_info.cc DEPS pten_tensor)
cc_library(op_kernel_info SRCS op_kernel_info.cc DEPS pten_tensor_raw)
cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform) cc_library(sparse_api SRCS sparse_api.cc DEPS pten_tensor pten kernel_dispatch pten_data_transform)
cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform) cc_library(pten_function_api SRCS ${api_source_file} DEPS pten_tensor pten kernel_dispatch pten_data_transform)
......
...@@ -21,3 +21,7 @@ endif() ...@@ -21,3 +21,7 @@ endif()
if(WITH_GPU) if(WITH_GPU)
add_dependencies(pten_context gpu_context) add_dependencies(pten_context gpu_context)
endif() endif()
if(WITH_CUSTOM_DEVICE)
add_dependencies(pten_context custom_context)
endif()
...@@ -21,12 +21,15 @@ limitations under the License. */ ...@@ -21,12 +21,15 @@ limitations under the License. */
// path replacement after implementing pten DeviceContext // path replacement after implementing pten DeviceContext
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/backends/xpu/xpu_context.h"
#ifndef PADDLE_WITH_CUSTOM_KERNEL
// TODO(wilber): DeviceContextPool nees include fluid file. // TODO(wilber): DeviceContextPool nees include fluid file.
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
namespace phi { namespace phi {
using DeviceContextPool = paddle::platform::DeviceContextPool; using DeviceContextPool = paddle::platform::DeviceContextPool;
} // namespace phi } // namespace phi
#endif
...@@ -32,8 +32,8 @@ struct CustomContext::Impl { ...@@ -32,8 +32,8 @@ struct CustomContext::Impl {
const Place& GetPlace() const { return place_; } const Place& GetPlace() const { return place_; }
C_Stream stream() const { void* stream() const {
return reinterpret_cast<C_Stream>(stream_->raw_stream()); return reinterpret_cast<void*>(stream_->raw_stream());
} }
void Wait() const { stream_->Wait(); } void Wait() const { stream_->Wait(); }
...@@ -47,7 +47,7 @@ void CustomContext::Init() { impl_->Init(); } ...@@ -47,7 +47,7 @@ void CustomContext::Init() { impl_->Init(); }
const Place& CustomContext::GetPlace() const { return impl_->GetPlace(); } const Place& CustomContext::GetPlace() const { return impl_->GetPlace(); }
C_Stream CustomContext::stream() const { return impl_->stream(); } void* CustomContext::stream() const { return impl_->stream(); }
void CustomContext::Wait() const { return impl_->Wait(); } void CustomContext::Wait() const { return impl_->Wait(); }
......
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#pragma once #pragma once
#include <memory> #include <memory>
#include "paddle/fluid/platform/device/device_ext.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h" #include "paddle/phi/core/device_context.h"
...@@ -30,7 +29,7 @@ class CustomContext : public DeviceContext { ...@@ -30,7 +29,7 @@ class CustomContext : public DeviceContext {
const Place& GetPlace() const override; const Place& GetPlace() const override;
/*! \brief Return stream in the device context. */ /*! \brief Return stream in the device context. */
C_Stream stream() const; void* stream() const;
// Wait for all operations completion in the stream. // Wait for all operations completion in the stream.
void Wait() const override; void Wait() const override;
......
...@@ -130,6 +130,32 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) { ...@@ -130,6 +130,32 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) {
return os; return os;
} }
inline Backend StringToBackend(const char* backend_cstr) {
std::string s(backend_cstr);
if (s == std::string("Undefined")) {
return Backend::UNDEFINED;
}
for (size_t i = 0; i < s.size(); ++i) {
s[i] = toupper(s[i]);
}
if (s == std::string("CPU")) {
return Backend::CPU;
} else if (s == std::string("GPU")) {
return Backend::GPU;
} else if (s == std::string("XPU")) {
return Backend::XPU;
} else if (s == std::string("NPU")) {
return Backend::NPU;
} else if (s == std::string("MKLDNN")) {
return Backend::MKLDNN;
} else if (s == std::string("CUDNN")) {
return Backend::CUDNN;
} else {
return static_cast<Backend>(static_cast<size_t>(Backend::NUM_BACKENDS) +
phi::GetOrRegisterGlobalDeviceTypeId(s));
}
}
} // namespace experimental } // namespace experimental
} // namespace paddle } // namespace paddle
......
...@@ -25,6 +25,8 @@ cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_te ...@@ -25,6 +25,8 @@ cc_library(meta_tensor SRCS meta_tensor.cc DEPS tensor_base tensor_meta dense_te
cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor) cc_library(infermeta_utils SRCS infermeta_utils.cc DEPS meta_tensor)
cc_library(selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor mixed_vector pten_enforce ddim) cc_library(selected_rows SRCS selected_rows_impl.cc DEPS dense_tensor mixed_vector pten_enforce ddim)
cc_library(pten_custom_kernel SRCS custom_kernel.cc DEPS kernel_factory convert_utils)
# Will remove once we implemented MKLDNN_Tensor # Will remove once we implemented MKLDNN_Tensor
if(WITH_MKLDNN) if(WITH_MKLDNN)
add_dependencies(dense_tensor mkldnn) add_dependencies(dense_tensor mkldnn)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/custom_kernel.h"
namespace phi {
void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) {
auto& kernel_info_map = custom_kernel_map.GetMap();
VLOG(3) << "Size of custom_kernel_map: " << kernel_info_map.size();
for (auto& pair : kernel_info_map) {
PADDLE_ENFORCE_EQ(
KernelFactory::Instance().HasCompatiblePtenKernel(pair.first),
true,
phi::errors::InvalidArgument(
"The kernel %s is not ready for custom kernel registering.",
pair.first));
for (auto& info_pair : pair.second) {
auto& kernels = KernelFactory::Instance().kernels();
PADDLE_ENFORCE_EQ(
kernels[pair.first].find(info_pair.first),
kernels[pair.first].end(),
phi::errors::InvalidArgument(
"The operator <%s>'s kernel: %s has been already existed "
"in Paddle, please contribute PR if it is necessary "
"to optimize the kernel code. Custom kernel does NOT support "
"to replace existing kernel in Paddle.",
pair.first,
info_pair.first));
kernels[pair.first][info_pair.first] = info_pair.second;
VLOG(3) << "Successed in registering operator <" << pair.first
<< ">'s kernel: " << info_pair.first
<< " to Paddle. It will be used like native ones.";
}
}
}
} // namespace phi
#ifdef __cplusplus
extern "C" {
#endif
// C-API to get global CustomKernelMap.
phi::CustomKernelMap& PD_GetCustomKernelMap() {
return phi::CustomKernelMap::Instance();
}
#ifdef __cplusplus
} // end extern "C"
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/macros.h"
namespace phi {
/**
* Note:
* Used to store kernels' info before registered to KernelFactory.
*/
class CustomKernelMap {
public:
static CustomKernelMap& Instance() {
static CustomKernelMap g_custom_kernel_info_map;
return g_custom_kernel_info_map;
}
KernelNameMap& Kernels() { return kernels_; }
const KernelNameMap& GetMap() const { return kernels_; }
private:
CustomKernelMap() = default;
DISABLE_COPY_AND_ASSIGN(CustomKernelMap);
KernelNameMap kernels_;
};
/**
* Note:
* Used to register custom kernels to KernelFactory.
*/
void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map);
} // namespace phi
...@@ -171,6 +171,9 @@ class DenseTensor : public TensorBase, ...@@ -171,6 +171,9 @@ class DenseTensor : public TensorBase,
DenseTensorMeta meta_; DenseTensorMeta meta_;
std::shared_ptr<phi::Allocation> holder_; std::shared_ptr<phi::Allocation> holder_;
#ifndef PADDLE_WITH_CUSTOM_KERNEL
#include "paddle/phi/core/dense_tensor.inl" #include "paddle/phi/core/dense_tensor.inl"
#endif
}; };
} // namespace phi } // namespace phi
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "paddle/phi/core/tensor_base.h" #include "paddle/phi/core/tensor_base.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/utils/any.h" #include "paddle/utils/any.h"
#include "paddle/utils/optional.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
namespace phi { namespace phi {
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include <typeinfo> #include <typeinfo>
#include <vector> #include <vector>
#include "paddle/phi/core/custom_kernel.h"
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/kernel_utils.h" #include "paddle/phi/core/kernel_utils.h"
#include "paddle/phi/core/macros.h" #include "paddle/phi/core/macros.h"
...@@ -62,6 +63,9 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -62,6 +63,9 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
#elif defined(PADDLE_WITH_XPU) #elif defined(PADDLE_WITH_XPU)
|| ||
arg_type == std::type_index(typeid(const XPUContext&))) { arg_type == std::type_index(typeid(const XPUContext&))) {
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
||
arg_type == std::type_index(typeid(const CustomContext&))) {
#else #else
) { ) {
#endif #endif
...@@ -83,11 +87,13 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -83,11 +87,13 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout, default_tensor_layout,
default_key.dtype(), default_key.dtype(),
arg_type); arg_type);
#ifndef PADDLE_WITH_CUSTOM_KERNEL
} else if (arg_type == std::type_index(typeid(const SelectedRows&))) { } else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
args_def->AppendInput(default_key.backend(), args_def->AppendInput(default_key.backend(),
default_tensor_layout, default_tensor_layout,
default_key.dtype(), default_key.dtype(),
arg_type); arg_type);
#endif
} else if (arg_type == std::type_index(typeid(DenseTensor*))) { } else if (arg_type == std::type_index(typeid(DenseTensor*))) {
args_def->AppendOutput(default_key.backend(), args_def->AppendOutput(default_key.backend(),
default_tensor_layout, default_tensor_layout,
...@@ -99,11 +105,13 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -99,11 +105,13 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout, default_tensor_layout,
default_key.dtype(), default_key.dtype(),
arg_type); arg_type);
#ifndef PADDLE_WITH_CUSTOM_KERNEL
} else if (arg_type == std::type_index(typeid(SelectedRows*))) { } else if (arg_type == std::type_index(typeid(SelectedRows*))) {
args_def->AppendOutput(default_key.backend(), args_def->AppendOutput(default_key.backend(),
default_tensor_layout, default_tensor_layout,
default_key.dtype(), default_key.dtype(),
arg_type); arg_type);
#endif
} else { } else {
// Attribute deal with // Attribute deal with
// TODO(chenweihang): now here allow any types of attribute, maybe // TODO(chenweihang): now here allow any types of attribute, maybe
...@@ -121,20 +129,28 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> { ...@@ -121,20 +129,28 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
} }
}; };
// NOTE: used for making a difference between kernels compiled with phi or not.
enum class RegType : uint8_t {
BUILTIN = 0, // compiled with phi
PLUGIN, // separate compiled and registered
};
// TODO(chenweihang): Polish the kernel selection logic, support the selection // TODO(chenweihang): Polish the kernel selection logic, support the selection
// of ALL_DTYPE kernel, and simplify the constructor // of ALL_DTYPE kernel, and simplify the constructor
struct KernelRegistrar { struct KernelRegistrar {
public: public:
KernelRegistrar(const char* kernel_name_cstr, KernelRegistrar(RegType reg_type,
Backend backend, const char* kernel_name_cstr,
const char* backend_cstr,
DataLayout layout, DataLayout layout,
DataType dtype, DataType dtype,
KernelArgsParseFn args_parse_fn, KernelArgsParseFn args_parse_fn,
KernelArgsDefFn args_def_fn, KernelArgsDefFn args_def_fn,
KernelFn kernel_fn, KernelFn kernel_fn,
void* variadic_kernel_fn) { void* variadic_kernel_fn) {
ConstructKernel(kernel_name_cstr, ConstructKernel(reg_type,
backend, kernel_name_cstr,
backend_cstr,
layout, layout,
dtype, dtype,
args_parse_fn, args_parse_fn,
...@@ -143,8 +159,9 @@ struct KernelRegistrar { ...@@ -143,8 +159,9 @@ struct KernelRegistrar {
variadic_kernel_fn); variadic_kernel_fn);
} }
KernelRegistrar(const char* kernel_name_cstr, KernelRegistrar(RegType reg_type,
Backend backend, const char* kernel_name_cstr,
const char* backend_cstr,
DataLayout layout, DataLayout layout,
KernelArgsParseFn args_parse_fn, KernelArgsParseFn args_parse_fn,
KernelArgsDefFn args_def_fn, KernelArgsDefFn args_def_fn,
...@@ -160,8 +177,9 @@ struct KernelRegistrar { ...@@ -160,8 +177,9 @@ struct KernelRegistrar {
dtype == static_cast<size_t>(DataType::UINT16)) { dtype == static_cast<size_t>(DataType::UINT16)) {
continue; continue;
} }
ConstructKernel(kernel_name_cstr, ConstructKernel(reg_type,
backend, kernel_name_cstr,
backend_cstr,
layout, layout,
static_cast<DataType>(dtype), static_cast<DataType>(dtype),
args_parse_fn, args_parse_fn,
...@@ -172,8 +190,9 @@ struct KernelRegistrar { ...@@ -172,8 +190,9 @@ struct KernelRegistrar {
} }
private: private:
void ConstructKernel(const char* kernel_name_cstr, void ConstructKernel(RegType reg_type,
Backend backend, const char* kernel_name_cstr,
const char* backend_cstr,
DataLayout layout, DataLayout layout,
DataType dtype, DataType dtype,
KernelArgsParseFn args_parse_fn, KernelArgsParseFn args_parse_fn,
...@@ -181,11 +200,16 @@ struct KernelRegistrar { ...@@ -181,11 +200,16 @@ struct KernelRegistrar {
KernelFn kernel_fn, KernelFn kernel_fn,
void* variadic_kernel_fn) { void* variadic_kernel_fn) {
std::string kernel_name(kernel_name_cstr); std::string kernel_name(kernel_name_cstr);
KernelKey kernel_key(backend, layout, dtype); KernelKey kernel_key(
paddle::experimental::StringToBackend(backend_cstr), layout, dtype);
Kernel kernel(kernel_fn, variadic_kernel_fn); Kernel kernel(kernel_fn, variadic_kernel_fn);
args_parse_fn(kernel_key, kernel.mutable_args_def()); args_parse_fn(kernel_key, kernel.mutable_args_def());
args_def_fn(kernel_key, &kernel); args_def_fn(kernel_key, &kernel);
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; if (reg_type == RegType::BUILTIN) {
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
} else {
CustomKernelMap::Instance().Kernels()[kernel_name][kernel_key] = kernel;
}
} }
}; };
...@@ -220,21 +244,38 @@ struct KernelRegistrar { ...@@ -220,21 +244,38 @@ struct KernelRegistrar {
* Note: `2TA` means `2 template argument` * Note: `2TA` means `2 template argument`
*/ */
#define PT_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \ #define PT_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ _PT_REGISTER_KERNEL(::phi::RegType::BUILTIN, \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ kernel_name, \
"PT_REGISTER_KERNEL must be called in global namespace."); \ backend, \
PT_EXPAND(_PT_REGISTER_2TA_KERNEL( \ ::phi::backend##Context, \
kernel_name, backend, layout, meta_kernel_fn, __VA_ARGS__)) layout, \
meta_kernel_fn, \
__VA_ARGS__)
#define _PT_REGISTER_KERNEL( \
reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_KERNEL must be called in global namespace."); \
PT_EXPAND(_PT_REGISTER_2TA_KERNEL(reg_type, \
kernel_name, \
backend, \
context, \
layout, \
meta_kernel_fn, \
__VA_ARGS__))
#ifndef _WIN32 #ifndef _WIN32
#define _PT_REGISTER_2TA_KERNEL( \ #define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, ...) \ reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...) \
PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, __VA_ARGS__); \ PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, context, __VA_ARGS__); \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \
PT_KERNEL_REGISTRAR_INIT( \ PT_KERNEL_REGISTRAR_INIT( \
reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \
layout, \ layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \ meta_kernel_fn, \
...@@ -255,12 +296,14 @@ struct KernelRegistrar { ...@@ -255,12 +296,14 @@ struct KernelRegistrar {
* And msvc can work without template instantiation * And msvc can work without template instantiation
*/ */
#define _PT_REGISTER_2TA_KERNEL( \ #define _PT_REGISTER_2TA_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, ...) \ reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \
PT_EXPAND(PT_KERNEL_REGISTRAR_INIT( \ PT_EXPAND(PT_KERNEL_REGISTRAR_INIT( \
reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \
layout, \ layout, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \ meta_kernel_fn, \
...@@ -269,82 +312,119 @@ struct KernelRegistrar { ...@@ -269,82 +312,119 @@ struct KernelRegistrar {
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel) const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#endif #endif
#define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, ...) \ #define PT_KERNEL_INSTANTIATION(meta_kernel_fn, backend, context, ...) \
_PT_KERNEL_INSTANTIATION( \ _PT_KERNEL_INSTANTIATION( \
PT_NARGS(__VA_ARGS__), meta_kernel_fn, backend, __VA_ARGS__) PT_NARGS(__VA_ARGS__), meta_kernel_fn, backend, context, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, ...) \ #define _PT_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, context, ...) \
PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \ PT_CONCATENATE(_PT_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, backend, __VA_ARGS__) (meta_kernel_fn, backend, context, __VA_ARGS__)
#define _PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, cpp_dtype) \ #define _PT_KERNEL_INSTANTIATION_1( \
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ meta_kernel_fn, backend, context, cpp_dtype) \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context> template decltype( \
#define _PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ #define _PT_KERNEL_INSTANTIATION_2( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn, backend, context, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_1(meta_kernel_fn, backend, __VA_ARGS__)) template decltype( \
#define _PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_1( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn, backend, context, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_2(meta_kernel_fn, backend, __VA_ARGS__)) #define _PT_KERNEL_INSTANTIATION_3( \
#define _PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn, backend, context, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ template decltype( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_3(meta_kernel_fn, backend, __VA_ARGS__)) PT_EXPAND(_PT_KERNEL_INSTANTIATION_2( \
#define _PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn, backend, context, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ #define _PT_KERNEL_INSTANTIATION_4( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn, backend, context, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_4(meta_kernel_fn, backend, __VA_ARGS__)) template decltype( \
#define _PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_3( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn, backend, context, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_5(meta_kernel_fn, backend, __VA_ARGS__)) #define _PT_KERNEL_INSTANTIATION_5( \
#define _PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn, backend, context, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ template decltype( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_6(meta_kernel_fn, backend, __VA_ARGS__)) PT_EXPAND(_PT_KERNEL_INSTANTIATION_4( \
#define _PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn, backend, context, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ #define _PT_KERNEL_INSTANTIATION_6( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn, backend, context, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_7(meta_kernel_fn, backend, __VA_ARGS__)) template decltype( \
#define _PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_5( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn, backend, context, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_8(meta_kernel_fn, backend, __VA_ARGS__)) #define _PT_KERNEL_INSTANTIATION_7( \
#define _PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn, backend, context, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ template decltype( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_9(meta_kernel_fn, backend, __VA_ARGS__)) PT_EXPAND(_PT_KERNEL_INSTANTIATION_6( \
#define _PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn, backend, context, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ #define _PT_KERNEL_INSTANTIATION_8( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn, backend, context, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_10(meta_kernel_fn, backend, __VA_ARGS__)) template decltype( \
#define _PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_7( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn, backend, context, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_11(meta_kernel_fn, backend, __VA_ARGS__)) #define _PT_KERNEL_INSTANTIATION_9( \
#define _PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn, backend, context, cpp_dtype, ...) \
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ template decltype( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_12(meta_kernel_fn, backend, __VA_ARGS__)) PT_EXPAND(_PT_KERNEL_INSTANTIATION_8( \
#define _PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn, backend, context, __VA_ARGS__))
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ #define _PT_KERNEL_INSTANTIATION_10( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn, backend, context, cpp_dtype, ...) \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_13(meta_kernel_fn, backend, __VA_ARGS__)) template decltype( \
#define _PT_KERNEL_INSTANTIATION_15(meta_kernel_fn, backend, cpp_dtype, ...) \ meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
template decltype(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>) \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_9( \
meta_kernel_fn<cpp_dtype, ::phi::backend##Context>; \ meta_kernel_fn, backend, context, __VA_ARGS__))
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14(meta_kernel_fn, backend, __VA_ARGS__)) #define _PT_KERNEL_INSTANTIATION_11( \
meta_kernel_fn, backend, context, cpp_dtype, ...) \
#define PT_KERNEL_REGISTRAR_INIT( \ template decltype( \
kernel_name, backend, layout, args_def_fn, meta_kernel_fn, ...) \ meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(__VA_ARGS__), \ PT_EXPAND(_PT_KERNEL_INSTANTIATION_10( \
kernel_name, \ meta_kernel_fn, backend, context, __VA_ARGS__))
backend, \ #define _PT_KERNEL_INSTANTIATION_12( \
layout, \ meta_kernel_fn, backend, context, cpp_dtype, ...) \
args_def_fn, \ template decltype( \
meta_kernel_fn, \ meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_11( \
meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_13( \
meta_kernel_fn, backend, context, cpp_dtype, ...) \
template decltype( \
meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_12( \
meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_14( \
meta_kernel_fn, backend, context, cpp_dtype, ...) \
template decltype( \
meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_13( \
meta_kernel_fn, backend, context, __VA_ARGS__))
#define _PT_KERNEL_INSTANTIATION_15( \
meta_kernel_fn, backend, context, cpp_dtype, ...) \
template decltype( \
meta_kernel_fn<cpp_dtype, context>) meta_kernel_fn<cpp_dtype, context>; \
PT_EXPAND(_PT_KERNEL_INSTANTIATION_14( \
meta_kernel_fn, backend, context, __VA_ARGS__))
#define PT_KERNEL_REGISTRAR_INIT(reg_type, \
kernel_name, \
backend, \
context, \
layout, \
args_def_fn, \
meta_kernel_fn, \
...) \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT(PT_NARGS(__VA_ARGS__), \
reg_type, \
kernel_name, \
backend, \
context, \
layout, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
// clang-format off // clang-format off
...@@ -352,15 +432,19 @@ struct KernelRegistrar { ...@@ -352,15 +432,19 @@ struct KernelRegistrar {
/* The =pre-commit always treats this macro into the wrong format, /* The =pre-commit always treats this macro into the wrong format,
and multi-line macros cannot be skipped with NOLINT.*/ and multi-line macros cannot be skipped with NOLINT.*/
#define _PT_KERNEL_REGISTRAR_INIT(N, \ #define _PT_KERNEL_REGISTRAR_INIT(N, \
reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \
layout, \ layout, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
...) \ ...) \
PT_EXPAND(PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \ PT_EXPAND(PT_CONCATENATE(_PT_KERNEL_REGISTRAR_INIT_, N) ( \
reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \
layout, \ layout, \
PT_ID, \ PT_ID, \
args_def_fn, \ args_def_fn, \
...@@ -369,413 +453,492 @@ struct KernelRegistrar { ...@@ -369,413 +453,492 @@ struct KernelRegistrar {
// clang-format on // clang-format on
#define _PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_1(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype) \ args_def_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ meta_kernel_fn, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
args_def_fn, \
PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; } int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_2(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_1(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_3(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_2(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_4(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_3(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_5(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_4(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_6(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_5(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_7(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_6(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_8(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_7(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_9(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_8(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_10(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_9(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_11(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_10(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_12(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_11(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_13(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_12(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_14(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_13(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
#define _PT_KERNEL_REGISTRAR_INIT_15(kernel_name, \ #define _PT_KERNEL_REGISTRAR_INIT_15(reg_type, \
backend, \ kernel_name, \
layout, \ backend, \
registrar_id, \ context, \
args_def_fn, \ layout, \
meta_kernel_fn, \ registrar_id, \
cpp_dtype, \ args_def_fn, \
...) \ meta_kernel_fn, \
static const ::phi::KernelRegistrar PT_CONCATENATE( \ cpp_dtype, \
__reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ ...) \
#kernel_name, \ static const ::phi::KernelRegistrar PT_CONCATENATE( \
BACKEND(backend), \ __reg_pt_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
DATALAYOUT(layout), \ reg_type, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ #kernel_name, \
::phi::KernelArgsParseFunctor<decltype( \ #backend, \
&meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)>::Parse, \ DATALAYOUT(layout), \
args_def_fn, \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
PT_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>), \ ::phi::KernelArgsParseFunctor<decltype( \
PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, ::phi::backend##Context>)); \ &meta_kernel_fn<cpp_dtype, context>)>::Parse, \
PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(kernel_name, \ args_def_fn, \
backend, \ PT_KERNEL(meta_kernel_fn<cpp_dtype, context>), \
layout, \ PT_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \
PT_ID, \ PT_EXPAND(_PT_KERNEL_REGISTRAR_INIT_14(reg_type, \
args_def_fn, \ kernel_name, \
meta_kernel_fn, \ backend, \
context, \
layout, \
PT_ID, \
args_def_fn, \
meta_kernel_fn, \
__VA_ARGS__)) __VA_ARGS__))
/** PT_REGISTER_GENERAL_KERNEL /** PT_REGISTER_GENERAL_KERNEL
* *
* Basic Kernel register marco, used to register a instantiated kernel function * Basic Kernel register marco, used to register a instantiated kernel function
* with one template argument. * with one template argument.
*/ */
#define PT_REGISTER_GENERAL_KERNEL( \ #define PT_REGISTER_GENERAL_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \ kernel_name, backend, layout, kernel_fn, dtype) \
_PT_REGISTER_GENERAL_KERNEL( \
::phi::RegType::BUILTIN, kernel_name, backend, layout, kernel_fn, dtype)
#define _PT_REGISTER_GENERAL_KERNEL( \
reg_type, kernel_name, backend, layout, kernel_fn, dtype) \
PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \
pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \ pt_register_no_t_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace."); \ "PT_REGISTER_NO_TEMPLATE_KERNEL must be called in global namespace."); \
_PT_REGISTER_GENERAL_KERNEL(kernel_name, backend, layout, kernel_fn, dtype) __PT_REGISTER_GENERAL_KERNEL( \
reg_type, kernel_name, backend, layout, kernel_fn, dtype)
#ifndef _WIN32 #ifndef _WIN32
#define _PT_REGISTER_GENERAL_KERNEL( \ #define __PT_REGISTER_GENERAL_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \ reg_type, kernel_name, backend, layout, kernel_fn, dtype) \
template decltype(kernel_fn) kernel_fn; \ template decltype(kernel_fn) kernel_fn; \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \
static const ::phi::KernelRegistrar \ static const ::phi::KernelRegistrar \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout( \
reg_type, \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ #backend, \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \ ::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
...@@ -787,14 +950,15 @@ struct KernelRegistrar { ...@@ -787,14 +950,15 @@ struct KernelRegistrar {
void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel) const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#else #else
#define _PT_REGISTER_GENERAL_KERNEL( \ #define __PT_REGISTER_GENERAL_KERNEL( \
kernel_name, backend, layout, kernel_fn, dtype) \ reg_type, kernel_name, backend, layout, kernel_fn, dtype) \
static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \
static const ::phi::KernelRegistrar \ static const ::phi::KernelRegistrar \
__reg_pt_kernel_##kernel_name##_##backend##_##layout( \ __reg_pt_kernel_##kernel_name##_##backend##_##layout( \
reg_type, \
#kernel_name, \ #kernel_name, \
BACKEND(backend), \ #backend, \
DATALAYOUT(layout), \ DATALAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \ ::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ &__PT_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
...@@ -821,4 +985,33 @@ struct KernelRegistrar { ...@@ -821,4 +985,33 @@ struct KernelRegistrar {
__declare_kernel_symbol_for_##kernel_name##_##backend##_##layout = \ __declare_kernel_symbol_for_##kernel_name##_##backend##_##layout = \
TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() TouchKernelSymbolFor_##kernel_name##_##backend##_##layout()
/** PD_REGISTER_KERNEL
*
* Used to register kernels for built-in backends.
* Support CPU GPU XPU.
*/
#define PD_REGISTER_KERNEL(kernel_name, backend, layout, meta_kernel_fn, ...) \
_PT_REGISTER_KERNEL(::phi::RegType::PLUGIN, \
kernel_name, \
backend, \
::phi::backend##Context, \
layout, \
meta_kernel_fn, \
__VA_ARGS__)
/** PD_REGISTER_CUSTOM_KERNEL
*
* Used to register kernels for plug-in backends.
* Support user-defined backend such as 'Ascend910'.
*/
#define PD_REGISTER_CUSTOM_KERNEL( \
kernel_name, backend, layout, meta_kernel_fn, ...) \
_PT_REGISTER_KERNEL(::phi::RegType::PLUGIN, \
kernel_name, \
backend, \
::phi::CustomContext, \
layout, \
meta_kernel_fn, \
__VA_ARGS__)
} // namespace phi } // namespace phi
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/scalar.h" #include "paddle/phi/common/scalar.h"
...@@ -22,7 +23,9 @@ ...@@ -22,7 +23,9 @@
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_context.h"
#ifndef PADDLE_WITH_CUSTOM_KERNEL
#include "paddle/phi/core/selected_rows.h" #include "paddle/phi/core/selected_rows.h"
#endif
#include "paddle/phi/core/sparse_coo_tensor.h" #include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h" #include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/type_defs.h" #include "paddle/phi/core/type_defs.h"
...@@ -210,13 +213,18 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> { ...@@ -210,13 +213,18 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(XPUContext); PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(XPUContext);
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PT_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(CustomContext);
#endif
/* Input Helpers */ /* Input Helpers */
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
#ifndef PADDLE_WITH_CUSTOM_KERNEL
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows); PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);
#endif
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor); PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor); PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor);
...@@ -250,7 +258,9 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> { ...@@ -250,7 +258,9 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor);
#ifndef PADDLE_WITH_CUSTOM_KERNEL
PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRows); PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRows);
#endif
PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCooTensor); PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCooTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCooTensor); PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCooTensor);
......
...@@ -15,10 +15,16 @@ ...@@ -15,10 +15,16 @@
#pragma once #pragma once
// See Note [ Why still include the fluid headers? ] // See Note [ Why still include the fluid headers? ]
#ifndef PADDLE_WITH_CUSTOM_KERNEL
#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/mixed_vector.h"
#endif
namespace phi { namespace phi {
#ifndef PADDLE_WITH_CUSTOM_KERNEL
using LoD = std::vector<paddle::framework::Vector<size_t>>; using LoD = std::vector<paddle::framework::Vector<size_t>>;
#else
using LoD = std::vector<std::vector<size_t>>;
#endif
void AppendLoD(LoD* lod, const LoD& lod_length); void AppendLoD(LoD* lod, const LoD& lod_length);
......
...@@ -24,12 +24,18 @@ limitations under the License. */ ...@@ -24,12 +24,18 @@ limitations under the License. */
// Note: mixed_vector include many header now, LoD will be // Note: mixed_vector include many header now, LoD will be
// used on CUDA device? Can we use small_vector here? // used on CUDA device? Can we use small_vector here?
// @zhanlve: Rollback to original LoD for now // @zhanlve: Rollback to original LoD for now
#ifndef PADDLE_WITH_CUSTOM_KERNEL
#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/mixed_vector.h"
#endif
namespace phi { namespace phi {
using DDim = phi::DDim; using DDim = phi::DDim;
#ifndef PADDLE_WITH_CUSTOM_KERNEL
using LoD = std::vector<paddle::framework::Vector<size_t>>; using LoD = std::vector<paddle::framework::Vector<size_t>>;
#else
using LoD = std::vector<std::vector<size_t>>;
#endif
/// \brief The meta data of dense tensor. Take the structure type /// \brief The meta data of dense tensor. Take the structure type
/// and use all default operations. /// and use all default operations.
/// ///
......
...@@ -31,25 +31,25 @@ class DenseTensorUtils { ...@@ -31,25 +31,25 @@ class DenseTensorUtils {
size_t bytes = tensor.numel() * SizeOf(tensor.dtype()); size_t bytes = tensor.numel() * SizeOf(tensor.dtype());
PADDLE_ENFORCE_GE(tensor.capacity(), PADDLE_ENFORCE_GE(tensor.capacity(),
bytes, bytes,
paddle::platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The memory size %d should be enough to meet the " "The memory size %d should be enough to meet the "
"volume required by metadata %d.", "volume required by metadata %d.",
tensor.capacity(), tensor.capacity(),
bytes)); bytes));
PADDLE_ENFORCE_GE(begin_idx, PADDLE_ENFORCE_GE(
0, begin_idx,
paddle::platform::errors::OutOfRange( 0,
"The start row index must be greater than 0." phi::errors::OutOfRange("The start row index must be greater than 0."
"But received the start index is d%.", "But received the start index is d%.",
begin_idx)); begin_idx));
PADDLE_ENFORCE_LE(end_idx, PADDLE_ENFORCE_LE(
tensor.dims()[0], end_idx,
paddle::platform::errors::OutOfRange( tensor.dims()[0],
"The end row index is out of bound.")); phi::errors::OutOfRange("The end row index is out of bound."));
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
begin_idx, begin_idx,
end_idx, end_idx,
paddle::platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"The start row index must be less than the end row index." "The start row index must be less than the end row index."
"But received the start index = %d, the end index = %d.", "But received the start index = %d, the end index = %d.",
begin_idx, begin_idx,
......
...@@ -52,5 +52,19 @@ TEST(Backend, OStream) { ...@@ -52,5 +52,19 @@ TEST(Backend, OStream) {
} }
} }
TEST(Backend, StringToBackend) {
namespace pexp = paddle::experimental;
EXPECT_EQ(phi::Backend::UNDEFINED, pexp::StringToBackend("Undefined"));
EXPECT_EQ(phi::Backend::CPU, pexp::StringToBackend("CPU"));
EXPECT_EQ(phi::Backend::GPU, pexp::StringToBackend("GPU"));
EXPECT_EQ(phi::Backend::XPU, pexp::StringToBackend("XPU"));
EXPECT_EQ(phi::Backend::NPU, pexp::StringToBackend("NPU"));
EXPECT_EQ(phi::Backend::MKLDNN, pexp::StringToBackend("MKLDNN"));
EXPECT_EQ(phi::Backend::CUDNN, pexp::StringToBackend("CUDNN"));
EXPECT_EQ(static_cast<phi::Backend>(
static_cast<size_t>(phi::Backend::NUM_BACKENDS) + 1),
pexp::StringToBackend("CustomBackend"));
}
} // namespace tests } // namespace tests
} // namespace phi } // namespace phi
cc_test(test_custom_kernel SRCS test_custom_kernel.cc DEPS pten_custom_kernel)
cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor) cc_test(test_dense_tensor SRCS test_dense_tensor.cc DEPS dense_tensor)
cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc) cc_test(test_intrusive_ptr SRCS test_intrusive_ptr.cc)
cc_test(test_type_info SRCS test_type_info.cc) cc_test(test_type_info SRCS test_type_info.cc)
......
...@@ -17,24 +17,21 @@ limitations under the License. */ ...@@ -17,24 +17,21 @@ limitations under the License. */
#define _LINUX #define _LINUX
#endif #endif
#include "paddle/fluid/framework/custom_kernel.h" #ifdef _LINUX
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/extension.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_kernel_info_helper.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/meta_tensor.h"
#include "paddle/phi/infermeta/binary.h" #include "paddle/phi/infermeta/binary.h"
#include "paddle/utils/small_vector.h"
#ifdef _LINUX #include <gtest/gtest.h>
// user kernel function // user kernel function
namespace custom_kernel { namespace custom_kernel {
...@@ -43,17 +40,23 @@ namespace custom_kernel { ...@@ -43,17 +40,23 @@ namespace custom_kernel {
// attribute 11: fake_attributes // attribute 11: fake_attributes
// output 2: one Tensor* and one std::vector<Tensor*> // output 2: one Tensor* and one std::vector<Tensor*>
template <typename T, typename Context> template <typename T, typename Context>
void FakeDot(const Context& dev_ctx, const paddle::Tensor& x, void FakeDot(const Context& dev_ctx,
const paddle::Tensor& y, const phi::DenseTensor& x,
const std::vector<paddle::Tensor>& fake_input_vec, const phi::DenseTensor& y,
bool fake_attr_bool, int fake_attr_int, float fake_attr_float, const std::vector<phi::DenseTensor>& fake_input_vec,
double fake_attr_double, int64_t fake_attr_int64, bool fake_attr_bool,
phi::dtype::float16 fake_attr_f16, phi::DataType fake_attr_dtype, int fake_attr_int,
float fake_attr_float,
double fake_attr_double,
int64_t fake_attr_int64,
phi::dtype::float16 fake_attr_f16,
phi::DataType fake_attr_dtype,
const phi::Scalar& fake_attr_scalar, const phi::Scalar& fake_attr_scalar,
const phi::ScalarArray& fake_attr_scalar_array, const phi::ScalarArray& fake_attr_scalar_array,
const std::vector<int64_t>& fake_attr_int64_vec, const std::vector<int64_t>& fake_attr_int64_vec,
const std::vector<int>& fake_attr_int_vec, paddle::Tensor* out, const std::vector<int>& fake_attr_int_vec,
std::vector<paddle::Tensor*> fake_out_vec) { phi::DenseTensor* out,
std::vector<phi::DenseTensor*> fake_out_vec) {
// print param info // print param info
std::cout << "fake_input_vec.size: " << fake_input_vec.size() << std::endl; std::cout << "fake_input_vec.size: " << fake_input_vec.size() << std::endl;
std::cout << "fake_attr_bool: " << fake_attr_bool << std::endl; std::cout << "fake_attr_bool: " << fake_attr_bool << std::endl;
...@@ -83,10 +86,10 @@ void FakeDot(const Context& dev_ctx, const paddle::Tensor& x, ...@@ -83,10 +86,10 @@ void FakeDot(const Context& dev_ctx, const paddle::Tensor& x,
auto const *x_ptr = x.data<T>(), *x_ptr_ = &x_ptr[0]; auto const *x_ptr = x.data<T>(), *x_ptr_ = &x_ptr[0];
auto const *y_ptr = y.data<T>(), *y_ptr_ = &y_ptr[0]; auto const *y_ptr = y.data<T>(), *y_ptr_ = &y_ptr[0];
auto* z = out->mutable_data<T>(paddle::PlaceType::kCPU); T* z = dev_ctx.template Alloc<T>(out);
auto shape = x.shape(); auto&& d = x.dims();
auto const N = x.numel(); auto const N = x.numel();
auto const B = shape[shape.size() - 1]; auto const B = d[d.size() - 1];
for (int j = 0; j < N / B; j++) { for (int j = 0; j < N / B; j++) {
T ss = 0; T ss = 0;
for (int i = 0; i < B; i++) ss += (*x_ptr_++) * (*y_ptr_++); for (int i = 0; i < B; i++) ss += (*x_ptr_++) * (*y_ptr_++);
...@@ -95,8 +98,19 @@ void FakeDot(const Context& dev_ctx, const paddle::Tensor& x, ...@@ -95,8 +98,19 @@ void FakeDot(const Context& dev_ctx, const paddle::Tensor& x,
} }
} // namespace custom_kernel } // namespace custom_kernel
PD_REGISTER_KERNEL(fake_dot, CPU, ALL_LAYOUT, custom_kernel::FakeDot, float, PD_REGISTER_KERNEL(fake_dot,
double, int, int64_t, int8_t, uint8_t) {} CPU,
ALL_LAYOUT,
custom_kernel::FakeDot,
float,
double,
int,
int64_t,
int8_t,
uint8_t) {}
namespace phi {
namespace tests {
// Upper code will store dot kernels info into OpKernelInfoMap // Upper code will store dot kernels info into OpKernelInfoMap
TEST(CustomKernel, custom_kernel_dot) { TEST(CustomKernel, custom_kernel_dot) {
...@@ -105,33 +119,38 @@ TEST(CustomKernel, custom_kernel_dot) { ...@@ -105,33 +119,38 @@ TEST(CustomKernel, custom_kernel_dot) {
phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT; phi::DataLayout layout = phi::DataLayout::ALL_LAYOUT;
// 1.custom kernel info parsed and store // 1.custom kernel info parsed and store
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find(op_name) != EXPECT_TRUE(phi::CustomKernelMap::Instance().GetMap().find(op_name) !=
paddle::OpKernelInfoMap::Instance().GetMap().end()); phi::CustomKernelMap::Instance().GetMap().end());
auto& custom_kernels = phi::CustomKernelMap::Instance().Kernels();
// 2.info check // 2.info check
EXPECT_EQ( EXPECT_EQ(6, static_cast<int>(custom_kernels[op_name].size()));
6, static_cast<int>(paddle::OpKernelInfoMap::Instance()[op_name].size())); auto& custom_fake_dot_kernels = custom_kernels[op_name];
// index 0 EXPECT_TRUE(custom_fake_dot_kernels.find(
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetBackend() == phi::KernelKey(backend, layout, phi::DataType::FLOAT32)) !=
backend); custom_fake_dot_kernels.end());
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataLayout() == EXPECT_TRUE(custom_fake_dot_kernels.find(
layout); phi::KernelKey(backend, layout, phi::DataType::FLOAT64)) !=
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataType() == custom_fake_dot_kernels.end());
phi::DataType::FLOAT32); EXPECT_TRUE(custom_fake_dot_kernels.find(
// index 5 phi::KernelKey(backend, layout, phi::DataType::INT32)) !=
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetBackend() == custom_fake_dot_kernels.end());
backend); EXPECT_TRUE(custom_fake_dot_kernels.find(
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataLayout() == phi::KernelKey(backend, layout, phi::DataType::INT64)) !=
layout); custom_fake_dot_kernels.end());
EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataType() == EXPECT_TRUE(custom_fake_dot_kernels.find(
phi::DataType::UINT8); phi::KernelKey(backend, layout, phi::DataType::INT8)) !=
custom_fake_dot_kernels.end());
EXPECT_TRUE(custom_fake_dot_kernels.find(
phi::KernelKey(backend, layout, phi::DataType::UINT8)) !=
custom_fake_dot_kernels.end());
// 3.before register // 3.before register
auto& kernel_factory_instance = phi::KernelFactory::Instance(); auto& kernel_factory_instance = phi::KernelFactory::Instance();
auto& kernels = phi::KernelFactory::Instance().kernels(); auto& kernels = phi::KernelFactory::Instance().kernels();
EXPECT_TRUE(!kernel_factory_instance.HasCompatiblePtenKernel(op_name)); EXPECT_TRUE(!kernel_factory_instance.HasCompatiblePtenKernel(op_name));
// mock fake_dot is supported by pten for HasCompatiblePtenKernel check while // mock fake_dot is supported by phi for HasCompatiblePtenKernel check while
// registering // registering
auto& fake_dot_kernels = kernels[op_name]; auto& fake_dot_kernels = kernels[op_name];
...@@ -155,8 +174,7 @@ TEST(CustomKernel, custom_kernel_dot) { ...@@ -155,8 +174,7 @@ TEST(CustomKernel, custom_kernel_dot) {
fake_dot_kernels.end()); fake_dot_kernels.end());
// register // register
paddle::framework::RegisterKernelWithMetaInfoMap( phi::RegisterCustomKernels(phi::CustomKernelMap::Instance());
paddle::OpKernelInfoMap::Instance());
EXPECT_TRUE(fake_dot_kernels.find( EXPECT_TRUE(fake_dot_kernels.find(
phi::KernelKey(backend, layout, phi::DataType::FLOAT32)) != phi::KernelKey(backend, layout, phi::DataType::FLOAT32)) !=
...@@ -186,15 +204,15 @@ TEST(CustomKernel, custom_kernel_dot) { ...@@ -186,15 +204,15 @@ TEST(CustomKernel, custom_kernel_dot) {
paddle::platform::CPUPlace()); paddle::platform::CPUPlace());
auto dense_x = std::make_shared<phi::DenseTensor>( auto dense_x = std::make_shared<phi::DenseTensor>(
alloc.get(), alloc.get(),
phi::DenseTensorMeta(phi::DataType::UINT8, phi::make_ddim({2, 3}), phi::DenseTensorMeta(
phi::DataLayout::NCHW)); phi::DataType::UINT8, phi::make_ddim({2, 3}), phi::DataLayout::NCHW));
auto* dense_x_data = auto* dense_x_data =
dense_x->mutable_data<uint8_t>(paddle::platform::CPUPlace()); dense_x->mutable_data<uint8_t>(paddle::platform::CPUPlace());
auto dense_y = std::make_shared<phi::DenseTensor>( auto dense_y = std::make_shared<phi::DenseTensor>(
alloc.get(), alloc.get(),
phi::DenseTensorMeta(phi::DataType::UINT8, phi::make_ddim({2, 3}), phi::DenseTensorMeta(
phi::DataLayout::NCHW)); phi::DataType::UINT8, phi::make_ddim({2, 3}), phi::DataLayout::NCHW));
auto* dense_y_data = auto* dense_y_data =
dense_y->mutable_data<uint8_t>(paddle::platform::CPUPlace()); dense_y->mutable_data<uint8_t>(paddle::platform::CPUPlace());
...@@ -288,38 +306,7 @@ TEST(CustomKernel, custom_kernel_dot) { ...@@ -288,38 +306,7 @@ TEST(CustomKernel, custom_kernel_dot) {
ASSERT_EQ(expect_result[1], actual_result1); ASSERT_EQ(expect_result[1], actual_result1);
} }
// test OpKernelInfoHelper } // namespace tests
TEST(OpKernelInfoHelper, op_kernel_info_help_getters) { } // namespace phi
using OpKernelInfoHelper = paddle::framework::OpKernelInfoHelper;
std::string op_name = "fake_dot";
phi::Backend backend = phi::Backend::CPU;
phi::DataLayout layout = phi::DataLayout::ANY;
phi::DataType dtype = phi::DataType::FLOAT32;
auto op_kernel_info = paddle::OpKernelInfoMap::Instance()[op_name][0];
EXPECT_EQ(op_name, OpKernelInfoHelper::GetOpName(op_kernel_info));
EXPECT_EQ(backend, OpKernelInfoHelper::GetBackend(op_kernel_info));
EXPECT_EQ(layout, OpKernelInfoHelper::GetDataLayout(op_kernel_info));
EXPECT_EQ(dtype, OpKernelInfoHelper::GetDataType(op_kernel_info));
EXPECT_EQ(phi::KernelKey(backend, layout, dtype),
OpKernelInfoHelper::GetKernelKey(op_kernel_info));
paddle::CustomKernelFunc kernel_fn =
PD_PT_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
EXPECT_EQ(kernel_fn, OpKernelInfoHelper::GetKernelFn(op_kernel_info));
void* variadic_func =
PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
EXPECT_EQ(variadic_func,
OpKernelInfoHelper::GetVariadicKernelFn(op_kernel_info));
auto& input_defs = OpKernelInfoHelper::GetInputDefs(op_kernel_info);
auto& output_defs = OpKernelInfoHelper::GetOutputDefs(op_kernel_info);
auto& attribute_defs = OpKernelInfoHelper::GetAttributeDefs(op_kernel_info);
EXPECT_EQ(3, static_cast<int>(input_defs.size()));
EXPECT_EQ(2, static_cast<int>(output_defs.size()));
EXPECT_EQ(11, static_cast<int>(attribute_defs.size()));
}
#endif #endif
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/extension.h" #include "paddle/phi/core/kernel_registry.h"
namespace paddle { namespace paddle {
...@@ -21,19 +21,19 @@ namespace custom_kernel { ...@@ -21,19 +21,19 @@ namespace custom_kernel {
// Here we use dot <CPU, ANY, INT8> for test // Here we use dot <CPU, ANY, INT8> for test
// This test will fail when this kernel is supported in framework // This test will fail when this kernel is supported in framework
template <typename T, typename Context> template <typename T, typename Context>
void Dot(const Context& dev_ctx, void DotKernel(const Context& dev_ctx,
const paddle::Tensor& x, const phi::DenseTensor& x,
const paddle::Tensor& y, const phi::DenseTensor& y,
paddle::Tensor* out) { phi::DenseTensor* out) {
auto const *x_ptr = x.data<T>(), *x_ptr_ = &x_ptr[0]; auto const *x_ptr = x.data<T>(), *x_ptr_ = &x_ptr[0];
auto const *y_ptr = y.data<T>(), *y_ptr_ = &y_ptr[0]; auto const *y_ptr = y.data<T>(), *y_ptr_ = &y_ptr[0];
auto* z = out->mutable_data<T>(paddle::PlaceType::kCPU); T* z = dev_ctx.template Alloc<T>(out);
// Loop over the total N elements of both operands while sum-reducing every // Loop over the total N elements of both operands while sum-reducing every
// B pairs along the way where B is the dimension of the least ordered axis // B pairs along the way where B is the dimension of the least ordered axis
auto shape = x.shape(); auto&& d = x.dims();
auto const N = x.numel(); auto const N = x.numel();
auto const B = shape[shape.size() - 1]; auto const B = d[d.size() - 1];
for (int j = 0; j < N / B; j++) { for (int j = 0; j < N / B; j++) {
T ss = 0; T ss = 0;
...@@ -45,6 +45,7 @@ void Dot(const Context& dev_ctx, ...@@ -45,6 +45,7 @@ void Dot(const Context& dev_ctx,
} // namespace custom_kernel } // namespace custom_kernel
} // namespace paddle } // namespace paddle
PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, paddle::custom_kernel::Dot, int8_t) { PD_REGISTER_KERNEL(
dot, CPU, ALL_LAYOUT, paddle::custom_kernel::DotKernel, int8_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT8); kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT8);
} }
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -16,9 +16,28 @@ import os ...@@ -16,9 +16,28 @@ import os
from paddle.fluid import core from paddle.fluid import core
from distutils.sysconfig import get_python_lib from distutils.sysconfig import get_python_lib
from distutils.core import setup, Extension from distutils.core import setup, Extension
from setuptools.command.build_ext import build_ext
# refer: https://note.qidong.name/2018/03/setup-warning-strict-prototypes
# Avoid a gcc warning below:
# cc1plus: warning: command line option ‘-Wstrict-prototypes’ is valid
# for C/ObjC but not for C++
class BuildExt(build_ext):
def build_extensions(self):
if '-Wstrict-prototypes' in self.compiler.compiler_so:
self.compiler.compiler_so.remove('-Wstrict-prototypes')
super(BuildExt, self).build_extensions()
# cc flags # cc flags
paddle_extra_compile_args = ['-std=c++14', '-shared', '-fPIC'] paddle_extra_compile_args = [
'-std=c++14',
'-shared',
'-fPIC',
'-Wno-parentheses',
'-DPADDLE_WITH_CUSTOM_KERNEL',
]
if core.is_compiled_with_npu(): if core.is_compiled_with_npu():
paddle_extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI=0'] paddle_extra_compile_args += ['-D_GLIBCXX_USE_CXX11_ABI=0']
...@@ -27,6 +46,14 @@ site_packages_path = get_python_lib() ...@@ -27,6 +46,14 @@ site_packages_path = get_python_lib()
paddle_custom_kernel_include = [ paddle_custom_kernel_include = [
os.path.join(site_packages_path, 'paddle', 'include'), os.path.join(site_packages_path, 'paddle', 'include'),
] ]
# include path third_party
compile_third_party_path = os.path.join(os.environ['PADDLE_ROOT'],
'build/third_party')
paddle_custom_kernel_include += [
os.path.join(compile_third_party_path, 'boost/src/extern_boost'), # boost
os.path.join(compile_third_party_path, 'install/gflags/include'), # gflags
os.path.join(compile_third_party_path, 'install/glog/include'), # glog
]
# libs path # libs path
paddle_custom_kernel_library_dir = [ paddle_custom_kernel_library_dir = [
...@@ -50,4 +77,5 @@ setup( ...@@ -50,4 +77,5 @@ setup(
name='custom_kernel_dot', name='custom_kernel_dot',
version='1.0', version='1.0',
description='custom kernel fot compiling', description='custom kernel fot compiling',
cmdclass={'build_ext': BuildExt},
ext_modules=[custom_kernel_dot_module]) ext_modules=[custom_kernel_dot_module])
...@@ -577,9 +577,9 @@ headers = ( ...@@ -577,9 +577,9 @@ headers = (
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/common')) + # pten common headers list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/common')) + # pten common headers
# pten level api headers (low level api) # pten level api headers (low level api)
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/core', recursive=True)) + # pten core headers list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/core', recursive=True)) + # pten core headers
list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/phi/backends', recursive=True)) + # pten backends headers
# utila api headers # utila api headers
['@PADDLE_SOURCE_DIR@/paddle/utils/any.h'] + list(find_files('*.h', '@PADDLE_SOURCE_DIR@/paddle/utils', recursive=True)) + # paddle utils headers
['@PADDLE_SOURCE_DIR@/paddle/utils/small_vector.h'] +
['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/device/device_ext.h']) ['@PADDLE_SOURCE_DIR@/paddle/fluid/platform/device/device_ext.h'])
if '${WITH_MKLDNN}' == 'ON': if '${WITH_MKLDNN}' == 'ON':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册