未验证 提交 e92e3aab 编写于 作者: Y YuanRisheng 提交者: GitHub

[PHI]Unify Fluid and PHI kernel (#49328)

* unify_kernel

* fix compile bugs

* modify macro name

* polish code according to review comments

* fix compile bugs

* fix compile bugs

* fix ci bugs

* fix ci bug

* fix ci bugs

* fix ci bugs

* modify code according comment

* rm conv_fusion_op
上级 766a4ca9
...@@ -28,11 +28,7 @@ endfunction() ...@@ -28,11 +28,7 @@ endfunction()
function(find_phi_register FILENAME ADD_PATH PATTERN) function(find_phi_register FILENAME ADD_PATH PATTERN)
# set op_name to OUTPUT # set op_name to OUTPUT
set(options "")
set(oneValueArgs "")
set(multiValueArgs "")
file(READ ${FILENAME} CONTENT) file(READ ${FILENAME} CONTENT)
string( string(
REGEX REGEX
MATCH MATCH
...@@ -402,6 +398,7 @@ function(op_library TARGET) ...@@ -402,6 +398,7 @@ function(op_library TARGET)
set(op_name "") set(op_name "")
# Add PHI Kernel Registry Message # Add PHI Kernel Registry Message
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_KERNEL") find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_KERNEL")
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL")
find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL") find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL")
find_register(${cc_src} "REGISTER_OPERATOR" op_name) find_register(${cc_src} "REGISTER_OPERATOR" op_name)
if(NOT ${op_name} EQUAL "") if(NOT ${op_name} EQUAL "")
...@@ -453,6 +450,7 @@ function(op_library TARGET) ...@@ -453,6 +450,7 @@ function(op_library TARGET)
set(op_name "") set(op_name "")
# Add PHI Kernel Registry Message # Add PHI Kernel Registry Message
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_KERNEL") find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_KERNEL")
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL")
find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL") find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL")
find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name) find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name)
if(NOT ${op_name} EQUAL "") if(NOT ${op_name} EQUAL "")
......
...@@ -827,7 +827,9 @@ bool BuildOpFuncList(const platform::Place& place, ...@@ -827,7 +827,9 @@ bool BuildOpFuncList(const platform::Place& place,
} }
// step 5. run kernel // step 5. run kernel
if (run_phi_kernel) { if (run_phi_kernel &&
op_func_node.phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
phi::KernelContext phi_kernel_context; phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext( op_with_kernel->BuildPhiKernelContext(
runtime_context, dev_ctx, &phi_kernel_context); runtime_context, dev_ctx, &phi_kernel_context);
...@@ -838,6 +840,12 @@ bool BuildOpFuncList(const platform::Place& place, ...@@ -838,6 +840,12 @@ bool BuildOpFuncList(const platform::Place& place,
op_with_kernel->PhiKernelSignature(), op_with_kernel->PhiKernelSignature(),
&phi_kernel_context); &phi_kernel_context);
} }
} else if (run_phi_kernel &&
op_func_node.phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::STRUCTURE) {
ExecutionContext execution_context(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
(*op_func_node.phi_kernel_)(&execution_context);
} else { } else {
// the place of exec_ctx maybe has changed. // the place of exec_ctx maybe has changed.
if (!skip_run) { if (!skip_run) {
......
...@@ -34,6 +34,7 @@ limitations under the License. */ ...@@ -34,6 +34,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/shape_inference.h" #include "paddle/fluid/framework/shape_inference.h"
#include "paddle/phi/core/kernel_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -484,5 +485,36 @@ struct OpKernelRegistrarFunctorEx<PlaceType, ...@@ -484,5 +485,36 @@ struct OpKernelRegistrarFunctorEx<PlaceType,
USE_OP_KERNEL(op_type) USE_OP_KERNEL(op_type)
// clang-format on // clang-format on
// Adapter that exposes a fluid "structure" kernel (which implements
// Compute(const ExecutionContext&)) through the phi kernel calling
// convention, which hands kernels a phi::KernelContext*.
template <typename StructureKernel>
struct StructKernelImpl {
  static void Compute(phi::KernelContext* ctx) {
    // ExecutionContext derives from phi::KernelContext, and the runtime
    // invokes structure kernels only with an ExecutionContext, so this
    // downcast is expected to be safe by construction — verify at call sites.
    auto* execution_ctx =
        static_cast<paddle::framework::ExecutionContext*>(ctx);
    StructureKernel kernel;
    kernel.Compute(*execution_ctx);
  }
};
// Expands to the static adapter function that bridges a structure kernel
// into the phi kernel function-pointer slot.
#define PHI_STRUCTURE_KERNEL(...) \
::paddle::framework::StructKernelImpl<__VA_ARGS__>::Compute
// Structure kernels provide no variadic kernel form; register a nullptr.
#define PHI_STRUCTURE_VARIADIC_KERNEL(...) nullptr
// Structure kernels pull their inputs from the ExecutionContext themselves,
// so no phi argument-parsing functor is registered.
#define STRUCTURE_ARG_PARSE_FUNCTOR(...) nullptr
// Explicitly instantiates the kernel class template for one
// (dtype, device-context) combination.
#define STRUCTURE_KERNEL_INSTANTIATION( \
meta_kernel_structure, cpp_dtype, context) \
template class meta_kernel_structure<cpp_dtype, context>;
// Registers a fluid structure kernel in the phi kernel registry. The
// trailing variadic arguments are the supported data types; `backend` is
// token-pasted into `::phi::backend##Context` (e.g. CPU -> phi::CPUContext).
#define PD_REGISTER_STRUCT_KERNEL( \
kernel_name, backend, layout, meta_kernel_structure, ...) \
_PD_REGISTER_KERNEL(::phi::RegType::INNER, \
kernel_name, \
backend, \
::phi::backend##Context, \
layout, \
meta_kernel_structure, \
STRUCTURE_KERNEL_INSTANTIATION, \
STRUCTURE_ARG_PARSE_FUNCTOR, \
PHI_STRUCTURE_KERNEL, \
PHI_STRUCTURE_VARIADIC_KERNEL, \
__VA_ARGS__)
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -1689,15 +1689,18 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1689,15 +1689,18 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
std::string phi_kernel_name; std::string phi_kernel_name;
if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) { if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) {
if (kernel_signature_ == nullptr || phi_kernel_ == nullptr) { if (kernel_signature_ == nullptr || phi_kernel_ == nullptr) {
kernel_signature_.reset(new phi::KernelSignature( if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) {
std::move(GetExpectedPhiKernelArgs(exe_ctx)))); kernel_signature_.reset(new phi::KernelSignature(type_.c_str()));
VLOG(6) << *kernel_signature_.get(); } else {
kernel_signature_.reset(new phi::KernelSignature(
std::move(GetExpectedPhiKernelArgs(exe_ctx))));
}
VLOG(6) << *kernel_signature_.get();
phi_kernel_name = kernel_signature_->name;
kernel_type_.reset( kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx)))); new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx))));
dev_ctx = pool.Get(kernel_type_->place_); dev_ctx = pool.Get(kernel_type_->place_);
phi_kernel_name = kernel_signature_->name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the // But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work. // library_type here, otherwise it can't work.
...@@ -1753,7 +1756,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1753,7 +1756,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
} }
} else { } else {
phi_kernel_name = kernel_signature_->name; phi_kernel_name = kernel_signature_->name;
// NOTE(jiahongyu): The registered MKLDNN kernel have library_type = // NOTE(jiahongyu): The registered MKLDNN kernel have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::ONEDNN. But the default // LibraryType::kMKLDNN and data_layout_ = DataLayout::ONEDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_ // values are kPlain, so we need to modify the library_type and data_layout_
...@@ -1939,7 +1941,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1939,7 +1941,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::TracerEventType::OperatorInner, platform::TracerEventType::OperatorInner,
1, 1,
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
if (run_phi_kernel_) { if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
phi::KernelContext phi_kernel_context; phi::KernelContext phi_kernel_context;
if (enable_cache_runtime_context_ && !need_prepare_phi_data_ && if (enable_cache_runtime_context_ && !need_prepare_phi_data_ &&
!need_prepare_data_) { !need_prepare_data_) {
...@@ -1977,6 +1980,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope, ...@@ -1977,6 +1980,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
BuildPhiKernelContext(*runtime_ctx, dev_ctx, &phi_kernel_context); BuildPhiKernelContext(*runtime_ctx, dev_ctx, &phi_kernel_context);
(*phi_kernel_)(&phi_kernel_context); (*phi_kernel_)(&phi_kernel_context);
} }
} else if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::STRUCTURE) {
ExecutionContext execution_context(
*this, exec_scope, *dev_ctx, *runtime_ctx);
(*phi_kernel_)(&execution_context);
} else { } else {
(*kernel_func_)( (*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx)); ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
...@@ -2147,14 +2155,18 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( ...@@ -2147,14 +2155,18 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
phi::KernelKey OperatorWithKernel::ChoosePhiKernel( phi::KernelKey OperatorWithKernel::ChoosePhiKernel(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
kernel_signature_.reset( std::string phi_kernel_name;
new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx)))); if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) {
kernel_signature_.reset(new phi::KernelSignature(type_.c_str()));
} else {
kernel_signature_.reset(
new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx))));
}
VLOG(6) << *kernel_signature_.get(); VLOG(6) << *kernel_signature_.get();
phi_kernel_name = kernel_signature_->name;
kernel_type_.reset( kernel_type_.reset(
new OpKernelType(std::move(InnerGetExpectedKernelType(ctx)))); new OpKernelType(std::move(InnerGetExpectedKernelType(ctx))));
auto phi_kernel_name = kernel_signature_->name;
auto phi_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get()); auto phi_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get());
phi_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( phi_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel(
phi_kernel_name, phi_kernel_key))); phi_kernel_name, phi_kernel_key)));
...@@ -2616,7 +2628,8 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -2616,7 +2628,8 @@ Scope* OperatorWithKernel::PrepareData(
} }
}; };
if (run_phi_kernel_) { if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
const auto& input_names = kernel_signature_->input_names; const auto& input_names = kernel_signature_->input_names;
const auto& input_defs = phi_kernel_->args_def().input_defs(); const auto& input_defs = phi_kernel_->args_def().input_defs();
PADDLE_ENFORCE_EQ(input_names.size(), PADDLE_ENFORCE_EQ(input_names.size(),
......
...@@ -41,6 +41,7 @@ limitations under the License. */ ...@@ -41,6 +41,7 @@ limitations under the License. */
#include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/arg_map_context.h"
#include "paddle/phi/core/compat/op_utils.h" #include "paddle/phi/core/compat/op_utils.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h" #include "paddle/phi/core/kernel_factory.h"
#include "paddle/utils/flat_hash_map.h" #include "paddle/utils/flat_hash_map.h"
...@@ -290,7 +291,7 @@ class OperatorBase { ...@@ -290,7 +291,7 @@ class OperatorBase {
const platform::Place& place) const = 0; const platform::Place& place) const = 0;
}; };
class ExecutionContext { class ExecutionContext : public phi::KernelContext {
public: public:
ExecutionContext(const OperatorBase& op, ExecutionContext(const OperatorBase& op,
const Scope& scope, const Scope& scope,
......
...@@ -273,17 +273,23 @@ PreparedOp PrepareImpl( ...@@ -273,17 +273,23 @@ PreparedOp PrepareImpl(
kernel_signature = (*arg_map_fn)( kernel_signature = (*arg_map_fn)(
framework::ExecutionArgumentMappingContext(dygraph_exe_ctx)); framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
} else { } else {
default_kernel_signature = if (phi::KernelFactory::Instance().HasStructuredKernel(op.Type())) {
default_phi_kernel_sig_map.GetNullable(op.Type());
if (default_kernel_signature) {
has_phi_kernel = true; has_phi_kernel = true;
kernel_signature = *default_kernel_signature; kernel_signature = phi::KernelSignature(op.Type().c_str());
} else {
default_kernel_signature =
default_phi_kernel_sig_map.GetNullable(op.Type());
if (default_kernel_signature) {
has_phi_kernel = true;
kernel_signature = *default_kernel_signature;
}
} }
} }
if (has_phi_kernel) { if (has_phi_kernel) {
VLOG(6) << kernel_signature; VLOG(6) << kernel_signature;
phi_kernel_name = kernel_signature.name; phi_kernel_name = kernel_signature.name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the // But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work. // library_type here, otherwise it can't work.
...@@ -648,6 +654,7 @@ static void PreparedOpRunPtImpl( ...@@ -648,6 +654,7 @@ static void PreparedOpRunPtImpl(
const phi::KernelSignature* default_kernel_signature, const phi::KernelSignature* default_kernel_signature,
const phi::KernelSignature& kernel_signature, const phi::KernelSignature& kernel_signature,
const phi::Kernel& phi_kernel, const phi::Kernel& phi_kernel,
const framework::RuntimeContext& ctx,
platform::DeviceContext* dev_ctx, platform::DeviceContext* dev_ctx,
const NameVarMap<VarType>& ins, const NameVarMap<VarType>& ins,
const NameVarMap<VarType>& outs, const NameVarMap<VarType>& outs,
...@@ -678,19 +685,25 @@ static void PreparedOpRunPtImpl( ...@@ -678,19 +685,25 @@ static void PreparedOpRunPtImpl(
1, 1,
platform::EventRole::kInnerOp); platform::EventRole::kInnerOp);
PreparePhiData<VarType>(phi_kernel, kernel_signature, ins); if (phi_kernel.GetKernelRegisteredType() ==
phi::KernelRegisteredType::FUNCTION) {
phi::KernelContext phi_kernel_context; PreparePhiData<VarType>(phi_kernel, kernel_signature, ins);
BuildDygraphPhiKernelContext<VarType>(kernel_signature, phi::KernelContext phi_kernel_context;
phi_kernel, BuildDygraphPhiKernelContext<VarType>(kernel_signature,
ins, phi_kernel,
outs, ins,
attrs, outs,
default_attrs, attrs,
dev_ctx, default_attrs,
&phi_kernel_context); dev_ctx,
&phi_kernel_context);
phi_kernel(&phi_kernel_context); phi_kernel(&phi_kernel_context);
} else {
DygraphExecutionContext<VarType> exe_ctx(
op, empty_scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs);
phi_kernel(&exe_ctx);
}
} }
if (FLAGS_check_nan_inf) { if (FLAGS_check_nan_inf) {
...@@ -722,6 +735,7 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins, ...@@ -722,6 +735,7 @@ void PreparedOp::Run(const NameVarMap<VarBase>& ins,
default_kernel_signature_, default_kernel_signature_,
kernel_signature_, kernel_signature_,
phi_kernel_, phi_kernel_,
ctx_,
dev_ctx_, dev_ctx_,
ins, ins,
outs, outs,
...@@ -753,6 +767,7 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins, ...@@ -753,6 +767,7 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
default_kernel_signature_, default_kernel_signature_,
kernel_signature_, kernel_signature_,
phi_kernel_, phi_kernel_,
ctx_,
dev_ctx_, dev_ctx_,
ins, ins,
outs, outs,
...@@ -784,6 +799,7 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins, ...@@ -784,6 +799,7 @@ void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
default_kernel_signature_, default_kernel_signature_,
kernel_signature_, kernel_signature_,
phi_kernel_, phi_kernel_,
ctx_,
dev_ctx_, dev_ctx_,
ins, ins,
outs, outs,
......
...@@ -530,8 +530,12 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature( ...@@ -530,8 +530,12 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature(
"This op type:`%s` is not a OperatorWithKernel, only " "This op type:`%s` is not a OperatorWithKernel, only "
"OperatorWithKernel can get KernelSignature", "OperatorWithKernel can get KernelSignature",
type)); type));
return phi::KernelSignature( if (phi::KernelFactory::Instance().HasStructuredKernel(type)) {
std::move(opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx))); return phi::KernelSignature(op->Type().c_str());
} else {
return phi::KernelSignature(std::move(
opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx)));
}
} }
} // namespace imperative } // namespace imperative
......
...@@ -34,7 +34,7 @@ limitations under the License. */ ...@@ -34,7 +34,7 @@ limitations under the License. */
phi::RegType::INNER, \ phi::RegType::INNER, \
#kernel_name, \ #kernel_name, \
dev_type, \ dev_type, \
DATALAYOUT(layout), \ DATA_LAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \ ::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
[](const phi::KernelKey& kernel_key, phi::Kernel* kernel) {}, \ [](const phi::KernelKey& kernel_key, phi::Kernel* kernel) {}, \
PHI_KERNEL(kernel_fn), \ PHI_KERNEL(kernel_fn), \
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <array>
#include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/kernels/funcs/padding.h"
#include "paddle/phi/kernels/gpudnn/conv_gpudnn_info.h"
DECLARE_int64(cudnn_exhaustive_search_times);
namespace paddle {
namespace operators {
#if PADDLE_WITH_HIP || CUDNN_VERSION >= 7100
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
using DataLayout = platform::DataLayout;
using framework::AlgorithmsCache;
using framework::ConvSearchCache;
using framework::SearchFuseResult;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
template <typename T>
class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
auto* input = ctx.Input<phi::DenseTensor>("Input");
auto* filter = ctx.Input<phi::DenseTensor>("Filter");
auto* bias = ctx.Input<phi::DenseTensor>("Bias");
auto* residual = ctx.Input<phi::DenseTensor>("ResidualData");
auto* output = ctx.Output<phi::DenseTensor>("Output");
dev_ctx.template Alloc<T>(output, output->numel() * sizeof(T));
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const std::string activation = ctx.Attr<std::string>("activation");
std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format,
"NHWC",
platform::errors::PermissionDenied(
"Operator(Conv2DFusion) in cuDNN only supports data format of "
"channel first (NCHW) now. But received: data_format = '%s'.",
data_format));
int groups = ctx.Attr<int>("groups");
int64_t user_workspace_size =
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
bool exhaustive_search =
FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
const T* filter_data = filter->data<T>();
const T* bias_data = bias->data<T>();
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
phi::DenseTensor transformed_input_channel(input->dtype());
phi::DenseTensor transformed_output(output->dtype());
transformed_input_channel = *input;
transformed_output = *output;
T* output_data = transformed_output.data<T>();
const T* residual_data = residual ? residual->data<T>() : output_data;
// update padding and dilation
auto in_dims = transformed_input_channel.dims();
auto filter_dims = filter->dims();
framework::DDim in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size());
framework::DDim filter_data_dims =
phi::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> ksize = phi::vectorize<int>(filter_data_dims);
UpdatePaddingAndDilation(
&paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize);
int data_dim = strides.size(); // 2d or 3d
bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim);
phi::DenseTensor transformed_input;
std::vector<int> padding_common(data_dim, 0);
if (!is_sys_pad) {
std::vector<int> padding_diff(data_dim);
std::vector<int> new_input_shape_vec(data_dim + 2);
new_input_shape_vec[0] = transformed_input_channel.dims()[0];
new_input_shape_vec[1] = transformed_input_channel.dims()[1];
std::vector<int> input_pad(transformed_input_channel.dims().size() * 2,
0);
for (size_t i = 0; i < data_dim; ++i) {
padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
new_input_shape_vec[i + 2] =
transformed_input_channel.dims()[i + 2] + padding_diff[i];
input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
}
framework::DDim new_input_shape(phi::make_ddim(new_input_shape_vec));
transformed_input.Resize(new_input_shape);
auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
transformed_input =
ctx.AllocateTmpTensor<T, phi::GPUContext>(new_input_shape, dev_ctx);
const int rank = transformed_input_channel.dims().size();
T pad_value(0.0);
switch (rank) {
case 4: {
phi::funcs::PadFunction<phi::GPUContext, T, 4>(
dev_ctx,
input_pad,
transformed_input_channel,
pad_value,
&transformed_input);
} break;
case 5: {
phi::funcs::PadFunction<phi::GPUContext, T, 5>(
dev_ctx,
input_pad,
transformed_input_channel,
pad_value,
&transformed_input);
} break;
default:
PADDLE_THROW(platform::errors::PermissionDenied(
"Operator Conv2DFusion expects Input to be a 4-D or 5-D "
"phi::DenseTensor. "
"But received the actual dimension = %d, shape = [%s].",
rank,
transformed_input_channel.dims()));
}
} else {
transformed_input = transformed_input_channel;
if (paddings.size() == data_dim) {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[i];
}
} else {
for (size_t i = 0; i < data_dim; ++i) {
padding_common[i] = paddings[2 * i];
}
}
}
const T* input_data = transformed_input.data<T>();
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedFilterDescriptor filter_desc;
ScopedTensorDescriptor bias_desc;
ScopedConvolutionDescriptor conv_desc;
ScopedActivationDescriptor act_desc;
DataLayout layout = DataLayout::kNCHW;
if (input->dims().size() == 5) {
layout = DataLayout::kNCDHW;
}
#ifdef PADDLE_WITH_HIP
miopenConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(padding_common, strides, dilations);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::miopenSetConvolutionGroupCount(cudnn_conv_desc,
groups));
// Now only support NCHW
std::vector<int> bias_dim = {
1, static_cast<int>(transformed_output.dims()[1]), 1, 1};
miopenTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, phi::vectorize<int>(transformed_input.dims()));
miopenTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, phi::vectorize<int>(transformed_output.dims()));
miopenTensorDescriptor_t cudnn_filter_desc =
filter_desc.descriptor<T>(layout, phi::vectorize<int>(filter->dims()));
miopenTensorDescriptor_t cudnn_bias_desc =
bias_desc.descriptor<T>(layout, bias_dim);
miopenActivationDescriptor_t cudnn_act_desc =
act_desc.descriptor<T>(activation);
miopenConvFwdAlgorithm_t algo;
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto x_dims = phi::vectorize(transformed_input.dims());
auto f_dims = phi::vectorize(filter->dims());
size_t workspace_size = 0;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::miopenConvolutionForwardGetWorkSpaceSize(
handle,
cudnn_filter_desc,
cudnn_input_desc,
cudnn_conv_desc,
cudnn_output_desc,
&workspace_size));
int find_count;
miopenConvAlgoPerf_t find_result;
auto cudnn_find_func = [&](void* cudnn_workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::miopenFindConvolutionForwardAlgorithm(
handle,
cudnn_input_desc,
input_data,
cudnn_filter_desc,
filter_data,
cudnn_conv_desc,
cudnn_output_desc,
output_data,
phi::kNUM_CUDNN_FWD_ALGS,
&find_count,
&find_result,
cudnn_workspace_ptr,
workspace_size,
false));
};
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size);
algo = find_result.fwd_algo;
VLOG(3) << "cuDNN forward algo " << algo;
{
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto cudnn_func = [&](void* cudnn_workspace) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::miopenConvolutionForward(handle,
&alpha,
cudnn_input_desc,
input_data,
cudnn_filter_desc,
filter_data,
cudnn_conv_desc,
algo,
&beta,
cudnn_output_desc,
output_data,
cudnn_workspace,
workspace_size));
};
workspace_handle.RunFunc(cudnn_func, workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::miopenConvolutionForwardBias(handle,
&alpha,
cudnn_bias_desc,
bias_data,
&beta,
cudnn_output_desc,
output_data));
if (activation != "identity") {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::miopenActivationForward(handle,
cudnn_act_desc,
&alpha,
cudnn_output_desc,
output_data,
&beta,
cudnn_output_desc,
output_data));
}
if (residual) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::miopenOpTensor(handle,
miopenTensorOpAdd,
&alpha,
cudnn_output_desc,
output_data,
&alpha,
cudnn_output_desc,
residual_data,
&beta,
cudnn_output_desc,
output_data));
}
}
#else // PADDLE_WITH_HIP
cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(padding_common, strides, dilations);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionGroupCount(
cudnn_conv_desc, groups));
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, phi::vectorize<int>(transformed_input.dims()));
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
layout, phi::vectorize<int>(transformed_output.dims()));
cudnnFilterDescriptor_t cudnn_filter_desc =
filter_desc.descriptor<T>(layout, phi::vectorize<int>(filter->dims()));
// Now only support NCHW
std::vector<int> bias_dim = {
1, static_cast<int>(transformed_output.dims()[1]), 1, 1};
cudnnTensorDescriptor_t cudnn_bias_desc =
bias_desc.descriptor<T>(layout, bias_dim);
cudnnActivationDescriptor_t cudnn_act_desc =
act_desc.descriptor<T>(activation);
// ------------------- cudnn conv workspace ---------------------
size_t workspace_size_in_bytes; // final workspace to allocate.
size_t workspace_size_limit = 0;
if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
int64_t max_user_size =
std::min(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
user_workspace_size);
workspace_size_limit = max_user_size * 1024 * 1024;
}
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t algo;
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto dtype = platform::CudnnDataType<T>::type;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
if (dtype == CUDNN_DATA_HALF) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_TENSOR_OP_MATH));
}
#if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
if (!platform::allow_tf32_cudnn) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_FMA_MATH));
}
#endif // CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000
auto x_dims = phi::vectorize(transformed_input.dims());
auto f_dims = phi::vectorize(filter->dims());
if (!exhaustive_search) {
#if CUDNN_VERSION >= 8000
int perf_count;
int best_algo_idx = 0;
size_t tmp_size = 0;
std::unique_ptr<cudnnConvolutionFwdAlgoPerf_t[]> perf_results(
new cudnnConvolutionFwdAlgoPerf_t[phi::kNUM_CUDNN_FWD_ALGS]);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7(
handle,
cudnn_input_desc,
cudnn_filter_desc,
cudnn_conv_desc,
cudnn_output_desc,
phi::kNUM_CUDNN_FWD_ALGS,
&perf_count,
perf_results.get()));
algo = (perf_results.get())[best_algo_idx].algo;
#else
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle,
cudnn_input_desc,
cudnn_filter_desc,
cudnn_conv_desc,
cudnn_output_desc,
CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit,
&algo));
#endif
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle,
cudnn_input_desc,
cudnn_filter_desc,
cudnn_conv_desc,
cudnn_output_desc,
algo,
&workspace_size_in_bytes));
if (workspace_size_in_bytes > workspace_size_limit)
workspace_size_limit = workspace_size_in_bytes;
VLOG(3) << "cuDNN forward algo " << algo;
} else {
std::function<SearchFuseResult<cudnnConvolutionFwdAlgo_t>()> search_func =
[&]() -> SearchFuseResult<cudnnConvolutionFwdAlgo_t> {
int returned_algo_count;
// NOTE(review): this excerpt starts mid-way through
// CUDNNConvFusionOpKernel<T>::Compute; the code below is the tail of the
// `search_func` lambda that exhaustively searches for the fastest cuDNN
// forward-convolution algorithm (the lambda header is above this excerpt).
SearchFuseResult<cudnnConvolutionFwdAlgo_t> fwd_result;
// Per-algorithm timing/memory statistics filled in by the exhaustive search.
std::array<cudnnConvolutionFwdAlgoPerf_t, phi::kNUM_CUDNN_FWD_ALGS>
fwd_perf_stat;
// Run the exhaustive algorithm search inside a borrowed workspace buffer.
auto cudnn_find_func = [&](void* cudnn_workspace) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
handle,
cudnn_input_desc,
input_data,
cudnn_filter_desc,
filter_data,
cudnn_conv_desc,
cudnn_output_desc,
output_data,
phi::kNUM_CUDNN_FWD_ALGS,
&returned_algo_count,
fwd_perf_stat.data(),
cudnn_workspace,
workspace_size_limit));
};
// Synchronous run: the perf results must be complete before reading them.
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = fwd_perf_stat[i];
VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time << " "
<< stat.memory;
}
// fwd_perf_stat[0] is taken as the winner (cuDNN returns the candidates
// sorted by performance); query the workspace size that algorithm needs.
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle,
cudnn_input_desc,
cudnn_filter_desc,
cudnn_conv_desc,
cudnn_output_desc,
fwd_perf_stat[0].algo,
&workspace_size_in_bytes));
// The workspace-limit check below is intentionally disabled; kept for
// reference.
// PADDLE_ENFORCE_LE(
//     workspace_size_in_bytes,
//     workspace_size_limit,
//     platform::errors::InvalidArgument(
//         "The actual workspace size to be allocated for cuDNN is
//         expected " "to be less than the limit. But received: the
//         actual workspace " "size = %d, limit = %d.",
//         workspace_size_in_bytes,
//         workspace_size_limit));
fwd_result.algo = fwd_perf_stat[0].algo;
fwd_result.workspace_size = workspace_size_in_bytes;
return fwd_result;
};
// Process-wide cache so the expensive exhaustive search is not re-run for
// configurations that were already searched.
AlgorithmsCache<SearchFuseResult<cudnnConvolutionFwdAlgo_t>>& algo_cache =
*(framework::ConvSearchCache::Instance().GetConvFusion());
int search_times = ctx.Attr<int>("search_times");
SearchFuseResult<cudnnConvolutionFwdAlgo_t> algo_result;
// The gflag may raise the search budget above the op attribute.
search_times = std::max(
static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
// TODO(dangqingqing): Unify this if-else.
if (search_times > 0) {
// The searched algo will be cached by `search_times` times for
// different input dimension. For other dimensions, select the algo
// of closest area.
// (The cache key here is the input spatial area x_dims[2] * x_dims[3].)
algo_result = algo_cache.GetAlgorithm(
x_dims[2] * x_dims[3], search_times, 0, search_func);
algo = algo_result.algo;
workspace_size_in_bytes = algo_result.workspace_size;
} else {
// Cache keyed on the full convolution configuration.
algo_result = algo_cache.GetAlgorithm(x_dims,
f_dims,
strides,
paddings,
dilations,
0,
dtype,
search_func);
algo = algo_result.algo;
workspace_size_in_bytes = algo_result.workspace_size;
}
VLOG(3) << "choose algo " << algo;
}
if ((activation == "identity") && (!residual)) {
// Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
// enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib.
// But test in some case, the speed is slower, change to use
// cudnnConvolutionForward and cudnnAddTensor
// ------------- cudnn conv forward and bias add ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto cudnn_func = [&](void* cudnn_workspace) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnConvolutionForward(handle,
&alpha,
cudnn_input_desc,
input_data,
cudnn_filter_desc,
filter_data,
cudnn_conv_desc,
algo,
cudnn_workspace,
workspace_size_in_bytes,
&beta,
cudnn_output_desc,
output_data));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
// Bias is added with a second launch rather than the fused API.
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnAddTensor(handle,
&alpha,
cudnn_bias_desc,
bias_data,
&alpha,
cudnn_output_desc,
output_data));
} else {
if (activation == "identity") {
// Fused path with identity activation only supports this algo (see the
// comment above), so the searched algo is overridden.
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
}
// ------------------- cudnn conv+bias+act forward --------------------
// alpha2 selects whether the residual input is accumulated into the
// output (1.0) or ignored (0.0).
ScalingParamType<T> alpha1 = 1.0f;
ScalingParamType<T> alpha2 = residual ? 1.0f : 0.0f;
auto cudnn_func = [&](void* cudnn_workspace) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnConvolutionBiasActivationForward(
handle,
&alpha1,
cudnn_input_desc,
input_data,
cudnn_filter_desc,
filter_data,
cudnn_conv_desc,
algo,
cudnn_workspace,
workspace_size_in_bytes,
&alpha2,
cudnn_output_desc,
residual_data,
cudnn_bias_desc,
bias_data,
cudnn_act_desc,
cudnn_output_desc,
output_data));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
#endif
// Optionally split the fused output into the "Outputs" tensors along the
// channel axis.
std::vector<int> channels = ctx.Attr<std::vector<int>>("split_channels");
if (channels.size()) {
auto outs = ctx.MultiOutput<phi::DenseTensor>("Outputs");
if (x_dims[0] == 1) {
// share data with Output
phi::DenseTensor t;
t.ShareDataWith(*output);
auto y_dims = output->dims();
t.Resize({y_dims[1], y_dims[2], y_dims[3]});
int s = 0;
// Each split view shares memory with the fused output (no copy).
for (size_t i = 0; i < channels.size(); ++i) {
int e = s + channels[i];
outs[i]->ShareDataWith(t.Slice(s, e));
outs[i]->Resize({x_dims[0], channels[i], y_dims[2], y_dims[3]});
s = e;
}
} else {
// TODO(qingiqng): do copy when batch size large than 1
PADDLE_THROW(platform::errors::Unimplemented(
"Input with batch size greater than 1 is unsupported. The received "
"batch size is %d, Input's shape is [%s].",
x_dims[0],
phi::make_ddim(x_dims)));
}
}
}
};
#endif
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
// The kernel above calls cudnnConvolutionBiasActivationForward, so CUDA
// registration is guarded on cuDNN >= 7.1 (version code 7100).
#if CUDNN_VERSION >= 7100
REGISTER_OP_CUDA_KERNEL(
conv2d_fusion,
ops::CUDNNConvFusionOpKernel<float>,
ops::CUDNNConvFusionOpKernel<double>,
ops::CUDNNConvFusionOpKernel<paddle::platform::float16>);
#endif
// ROCm (HIP) builds register only the float kernel.
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>);
#endif
...@@ -240,12 +240,15 @@ REGISTER_OPERATOR(rank_loss, ...@@ -240,12 +240,15 @@ REGISTER_OPERATOR(rank_loss,
ops::RankLossGradMaker<paddle::framework::OpDesc>, ops::RankLossGradMaker<paddle::framework::OpDesc>,
ops::RankLossGradMaker<paddle::imperative::OpBase>); ops::RankLossGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(rank_loss_grad, ops::RankLossGradOp); REGISTER_OPERATOR(rank_loss_grad, ops::RankLossGradOp);
REGISTER_OP_CPU_KERNEL(rank_loss, ops::RankLossKernel<phi::CPUContext, float>);
REGISTER_OP_CPU_KERNEL(rank_loss_grad, PD_REGISTER_STRUCT_KERNEL(
ops::RankLossGradKernel<phi::CPUContext, float>); rank_loss, CPU, ALL_LAYOUT, ops::RankLossKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
REGISTER_OP_CUDA_KERNEL( rank_loss_grad, CPU, ALL_LAYOUT, ops::RankLossGradKernel, float) {}
rank_loss, paddle::operators::RankLossKernel<phi::GPUContext, float>);
REGISTER_OP_CUDA_KERNEL( #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
rank_loss_grad, PD_REGISTER_STRUCT_KERNEL(
paddle::operators::RankLossGradKernel<phi::GPUContext, float>); rank_loss, GPU, ALL_LAYOUT, ops::RankLossKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
rank_loss_grad, GPU, ALL_LAYOUT, ops::RankLossGradKernel, float) {}
#endif
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class RankLossKernel : public framework::OpKernel<T> { class RankLossKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
...@@ -42,7 +42,7 @@ class RankLossKernel : public framework::OpKernel<T> { ...@@ -42,7 +42,7 @@ class RankLossKernel : public framework::OpKernel<T> {
} }
}; };
template <typename DeviceContext, typename T> template <typename T, typename DeviceContext>
class RankLossGradKernel : public framework::OpKernel<T> { class RankLossGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
......
...@@ -56,6 +56,8 @@ struct KernelSignature { ...@@ -56,6 +56,8 @@ struct KernelSignature {
attr_names(attrs), attr_names(attrs),
output_names(outputs) {} output_names(outputs) {}
// Convenience constructor: sets only the kernel name; the input/attr/output
// name lists are left default-constructed.
explicit KernelSignature(const char* kernel_name) : name(kernel_name) {}
// TODO(chenweihang): add assign constructor to solve windows compile // TODO(chenweihang): add assign constructor to solve windows compile
// problem, remove it later // problem, remove it later
KernelSignature(const KernelSignature& other) KernelSignature(const KernelSignature& other)
......
...@@ -62,6 +62,21 @@ bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const { ...@@ -62,6 +62,21 @@ bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const {
return false; return false;
} }
// Returns true iff `op_type` maps to a registered kernel that was added as a
// structured kernel (KernelRegisteredType::STRUCTURE). Deprecated op names
// never match.
bool KernelFactory::HasStructuredKernel(const std::string& op_type) const {
  if (deprecated_op_names.find(op_type) != deprecated_op_names.end()) {
    return false;
  }
  auto base_name = phi::OpUtilsMap::Instance().GetBaseKernelName(op_type);
  auto found = kernels_.find(base_name);
  if (found == kernels_.end()) {
    return false;
  }
  // Scan every (KernelKey, Kernel) pair registered under the base name.
  for (const auto& key_kernel : found->second) {
    if (key_kernel.second.GetKernelRegisteredType() ==
        KernelRegisteredType::STRUCTURE) {
      return true;
    }
  }
  return false;
}
const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name, const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
const KernelKey& kernel_key) const { const KernelKey& kernel_key) const {
auto iter = kernels_.find(kernel_name); auto iter = kernels_.find(kernel_name);
......
...@@ -238,13 +238,21 @@ class KernelArgsDef { ...@@ -238,13 +238,21 @@ class KernelArgsDef {
{}}; {}};
}; };
enum class KernelRegisteredType { FUNCTION, STRUCTURE };
class Kernel { class Kernel {
public: public:
// for map element construct // for map element construct
Kernel() = default; Kernel() = default;
explicit Kernel(KernelFn fn, void* variadic_fn) explicit Kernel(KernelFn fn, void* variadic_fn)
: fn_(fn), variadic_fn_(variadic_fn) {} : fn_(fn), variadic_fn_(variadic_fn) {
if (variadic_fn == nullptr) {
kernel_registered_type_ = KernelRegisteredType::STRUCTURE;
} else {
kernel_registered_type_ = KernelRegisteredType::FUNCTION;
}
}
void operator()(KernelContext* ctx) const { fn_(ctx); } void operator()(KernelContext* ctx) const { fn_(ctx); }
...@@ -272,10 +280,15 @@ class Kernel { ...@@ -272,10 +280,15 @@ class Kernel {
bool IsValid() const { return fn_ != nullptr; } bool IsValid() const { return fn_ != nullptr; }
// Reports how this kernel was registered: FUNCTION (functional PHI kernel)
// or STRUCTURE (presumably registered via PD_REGISTER_STRUCT_KERNEL — the
// constructor marks kernels with a null variadic fn as STRUCTURE).
KernelRegisteredType GetKernelRegisteredType() const {
return kernel_registered_type_;
}
private: private:
KernelFn fn_{nullptr}; KernelFn fn_{nullptr};
void* variadic_fn_ = nullptr; void* variadic_fn_ = nullptr;
KernelArgsDef args_def_; KernelArgsDef args_def_;
KernelRegisteredType kernel_registered_type_ = KernelRegisteredType::FUNCTION;
}; };
using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>; using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;
...@@ -304,6 +317,8 @@ class KernelFactory { ...@@ -304,6 +317,8 @@ class KernelFactory {
bool HasCompatiblePhiKernel(const std::string& op_type) const; bool HasCompatiblePhiKernel(const std::string& op_type) const;
bool HasStructuredKernel(const std::string& op_type) const;
KernelResult SelectKernelOrThrowError(const std::string& kernel_name, KernelResult SelectKernelOrThrowError(const std::string& kernel_name,
const KernelKey& kernel_key) const; const KernelKey& kernel_key) const;
......
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
namespace phi { namespace phi {
#define BACKEND(arg__) phi::Backend::arg__ #define BACKEND(arg__) phi::Backend::arg__
#define DATALAYOUT(arg__) phi::DataLayout::arg__ #define DATA_LAYOUT(arg__) phi::DataLayout::arg__
#define DATATYPE(arg__) phi::DataType::arg__ #define DATATYPE(arg__) phi::DataType::arg__
template <typename Func> template <typename Func>
...@@ -348,7 +348,9 @@ struct KernelRegistrar { ...@@ -348,7 +348,9 @@ struct KernelRegistrar {
KernelKey kernel_key( KernelKey kernel_key(
paddle::experimental::StringToBackend(backend_cstr), layout, dtype); paddle::experimental::StringToBackend(backend_cstr), layout, dtype);
Kernel kernel(kernel_fn, variadic_kernel_fn); Kernel kernel(kernel_fn, variadic_kernel_fn);
args_parse_fn(kernel_key, kernel.mutable_args_def()); if (kernel.GetKernelRegisteredType() == KernelRegisteredType::FUNCTION) {
args_parse_fn(kernel_key, kernel.mutable_args_def());
}
args_def_fn(kernel_key, &kernel); args_def_fn(kernel_key, &kernel);
if (reg_type == RegType::INNER) { if (reg_type == RegType::INNER) {
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
...@@ -380,6 +382,16 @@ struct KernelRegistrar { ...@@ -380,6 +382,16 @@ struct KernelRegistrar {
#define _PD_ARG_N(args) _PD_ARG_N_EXPAND args #define _PD_ARG_N(args) _PD_ARG_N_EXPAND args
#define _PD_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 #define _PD_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
// Expands to KernelArgsParseFunctor<...>::Parse for a function kernel; passed
// to the registrar so argument metadata can be parsed at registration time.
#define ARG_PARSE_FUNCTOR(meta_kernel_fn, cpp_dtype, context) \
::phi::KernelArgsParseFunctor< \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse
// Explicitly instantiates a function-kernel template for one
// (cpp_dtype, context) pair.
#define FUNCTION_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, context) \
template decltype(meta_kernel_fn<cpp_dtype, context>) \
meta_kernel_fn<cpp_dtype, context>;
/** PD_REGISTER_KERNEL /** PD_REGISTER_KERNEL
* *
* The most frequently used kernel registration macro, used for kernel * The most frequently used kernel registration macro, used for kernel
...@@ -396,10 +408,23 @@ struct KernelRegistrar { ...@@ -396,10 +408,23 @@ struct KernelRegistrar {
::phi::backend##Context, \ ::phi::backend##Context, \
layout, \ layout, \
meta_kernel_fn, \ meta_kernel_fn, \
FUNCTION_KERNEL_INSTANTIATION, \
ARG_PARSE_FUNCTOR, \
PHI_KERNEL, \
PHI_VARIADIC_KERNEL, \
__VA_ARGS__) __VA_ARGS__)
#define _PD_REGISTER_KERNEL( \ #define _PD_REGISTER_KERNEL(reg_type, \
reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...) \ kernel_name, \
backend, \
context, \
layout, \
meta_kernel_fn, \
kernel_instantiation_macro, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
...) \
PD_STATIC_ASSERT_GLOBAL_NAMESPACE( \ PD_STATIC_ASSERT_GLOBAL_NAMESPACE( \
PD_REGISTER_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ PD_REGISTER_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \
"PD_REGISTER_KERNEL must be called in global namespace."); \ "PD_REGISTER_KERNEL must be called in global namespace."); \
...@@ -409,12 +434,29 @@ struct KernelRegistrar { ...@@ -409,12 +434,29 @@ struct KernelRegistrar {
context, \ context, \
layout, \ layout, \
meta_kernel_fn, \ meta_kernel_fn, \
kernel_instantiation_macro, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#ifndef _WIN32 #ifndef _WIN32
#define _PD_REGISTER_2TA_KERNEL( \ #define _PD_REGISTER_2TA_KERNEL(reg_type, \
reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...) \ kernel_name, \
PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, context, __VA_ARGS__); \ backend, \
context, \
layout, \
meta_kernel_fn, \
kernel_instantiation_macro, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
...) \
PD_KERNEL_INSTANTIATION(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
__VA_ARGS__); \
static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \
PD_KERNEL_REGISTRAR_INIT( \ PD_KERNEL_REGISTRAR_INIT( \
...@@ -425,6 +467,9 @@ struct KernelRegistrar { ...@@ -425,6 +467,9 @@ struct KernelRegistrar {
layout, \ layout, \
&__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \ meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__); \ __VA_ARGS__); \
void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel) const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
...@@ -441,8 +486,17 @@ struct KernelRegistrar { ...@@ -441,8 +486,17 @@ struct KernelRegistrar {
* *
* And msvc can work without template instantiation * And msvc can work without template instantiation
*/ */
#define _PD_REGISTER_2TA_KERNEL( \ #define _PD_REGISTER_2TA_KERNEL(reg_type, \
reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...) \ kernel_name, \
backend, \
context, \
layout, \
meta_kernel_fn, \
kernel_instantiation_macro, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
...) \
static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \
PD_EXPAND(PD_KERNEL_REGISTRAR_INIT( \ PD_EXPAND(PD_KERNEL_REGISTRAR_INIT( \
...@@ -453,124 +507,222 @@ struct KernelRegistrar { ...@@ -453,124 +507,222 @@ struct KernelRegistrar {
layout, \ layout, \
&__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
meta_kernel_fn, \ meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)); \ __VA_ARGS__)); \
void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \
const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel) const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel)
#endif #endif
#define PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, context, ...) \ #define PD_KERNEL_INSTANTIATION( \
_PD_KERNEL_INSTANTIATION( \ meta_kernel_fn, backend, context, kernel_instantiation_macro, ...) \
PD_NARGS(__VA_ARGS__), meta_kernel_fn, backend, context, __VA_ARGS__) _PD_KERNEL_INSTANTIATION(PD_NARGS(__VA_ARGS__), \
meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
__VA_ARGS__)
#define _PD_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, context, ...) \ #define _PD_KERNEL_INSTANTIATION( \
PD_CONCATENATE(_PD_KERNEL_INSTANTIATION_, N) \ N, meta_kernel_fn, backend, context, kernel_instantiation_macro, ...) \
(meta_kernel_fn, backend, context, __VA_ARGS__) PD_CONCATENATE(_PD_KERNEL_INSTANTIATION_, N) \
(meta_kernel_fn, backend, context, kernel_instantiation_macro, __VA_ARGS__)
#define _PD_KERNEL_INSTANTIATION_1( \ #define _PD_KERNEL_INSTANTIATION_1( \
meta_kernel_fn, backend, context, cpp_dtype) \ meta_kernel_fn, backend, context, kernel_instantiation_macro, cpp_dtype) \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context)
meta_kernel_fn<cpp_dtype, context> #define _PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, \
#define _PD_KERNEL_INSTANTIATION_2( \ backend, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ context, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ kernel_instantiation_macro, \
meta_kernel_fn<cpp_dtype, context>; \ cpp_dtype, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_1( \ ...) \
meta_kernel_fn, backend, context, __VA_ARGS__)) kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
#define _PD_KERNEL_INSTANTIATION_3( \ PD_EXPAND(_PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ backend, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ context, \
meta_kernel_fn<cpp_dtype, context>; \ kernel_instantiation_macro, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_2( \ __VA_ARGS__))
meta_kernel_fn, backend, context, __VA_ARGS__)) #define _PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, \
#define _PD_KERNEL_INSTANTIATION_4( \ backend, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ context, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ kernel_instantiation_macro, \
meta_kernel_fn<cpp_dtype, context>; \ cpp_dtype, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_3( \ ...) \
meta_kernel_fn, backend, context, __VA_ARGS__)) kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
#define _PD_KERNEL_INSTANTIATION_5( \ PD_EXPAND(_PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ backend, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ context, \
meta_kernel_fn<cpp_dtype, context>; \ kernel_instantiation_macro, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_4( \ __VA_ARGS__))
meta_kernel_fn, backend, context, __VA_ARGS__)) #define _PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, \
#define _PD_KERNEL_INSTANTIATION_6( \ backend, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ context, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ kernel_instantiation_macro, \
meta_kernel_fn<cpp_dtype, context>; \ cpp_dtype, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_5( \ ...) \
meta_kernel_fn, backend, context, __VA_ARGS__)) kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
#define _PD_KERNEL_INSTANTIATION_7( \ PD_EXPAND(_PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ backend, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ context, \
meta_kernel_fn<cpp_dtype, context>; \ kernel_instantiation_macro, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_6( \ __VA_ARGS__))
meta_kernel_fn, backend, context, __VA_ARGS__)) #define _PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, \
#define _PD_KERNEL_INSTANTIATION_8( \ backend, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ context, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ kernel_instantiation_macro, \
meta_kernel_fn<cpp_dtype, context>; \ cpp_dtype, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_7( \ ...) \
meta_kernel_fn, backend, context, __VA_ARGS__)) kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
#define _PD_KERNEL_INSTANTIATION_9( \ PD_EXPAND(_PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ backend, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ context, \
meta_kernel_fn<cpp_dtype, context>; \ kernel_instantiation_macro, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_8( \ __VA_ARGS__))
meta_kernel_fn, backend, context, __VA_ARGS__)) #define _PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, \
#define _PD_KERNEL_INSTANTIATION_10( \ backend, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ context, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ kernel_instantiation_macro, \
meta_kernel_fn<cpp_dtype, context>; \ cpp_dtype, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_9( \ ...) \
meta_kernel_fn, backend, context, __VA_ARGS__)) kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
#define _PD_KERNEL_INSTANTIATION_11( \ PD_EXPAND(_PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ backend, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ context, \
meta_kernel_fn<cpp_dtype, context>; \ kernel_instantiation_macro, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_10( \ __VA_ARGS__))
meta_kernel_fn, backend, context, __VA_ARGS__)) #define _PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, \
#define _PD_KERNEL_INSTANTIATION_12( \ backend, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ context, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ kernel_instantiation_macro, \
meta_kernel_fn<cpp_dtype, context>; \ cpp_dtype, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_11( \ ...) \
meta_kernel_fn, backend, context, __VA_ARGS__)) kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
#define _PD_KERNEL_INSTANTIATION_13( \ PD_EXPAND(_PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ backend, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ context, \
meta_kernel_fn<cpp_dtype, context>; \ kernel_instantiation_macro, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_12( \ __VA_ARGS__))
meta_kernel_fn, backend, context, __VA_ARGS__)) #define _PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, \
#define _PD_KERNEL_INSTANTIATION_14( \ backend, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ context, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ kernel_instantiation_macro, \
meta_kernel_fn<cpp_dtype, context>; \ cpp_dtype, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_13( \ ...) \
meta_kernel_fn, backend, context, __VA_ARGS__)) kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
#define _PD_KERNEL_INSTANTIATION_15( \ PD_EXPAND(_PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, \
meta_kernel_fn, backend, context, cpp_dtype, ...) \ backend, \
template decltype(meta_kernel_fn<cpp_dtype, context>) \ context, \
meta_kernel_fn<cpp_dtype, context>; \ kernel_instantiation_macro, \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_14( \ __VA_ARGS__))
meta_kernel_fn, backend, context, __VA_ARGS__)) #define _PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
cpp_dtype, \
...) \
kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
cpp_dtype, \
...) \
kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
cpp_dtype, \
...) \
kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
cpp_dtype, \
...) \
kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
cpp_dtype, \
...) \
kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
cpp_dtype, \
...) \
kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
__VA_ARGS__))
#define _PD_KERNEL_INSTANTIATION_15(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
cpp_dtype, \
...) \
kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \
PD_EXPAND(_PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, \
backend, \
context, \
kernel_instantiation_macro, \
__VA_ARGS__))
#define PD_KERNEL_REGISTRAR_INIT(reg_type, \ #define PD_KERNEL_REGISTRAR_INIT(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
...) \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT(PD_NARGS(__VA_ARGS__), \ kernel_unfold_macro, \
reg_type, \ variadic_kernel_unfold_marco, \
kernel_name, \ ...) \
backend, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT(PD_NARGS(__VA_ARGS__), \
context, \ reg_type, \
layout, \ kernel_name, \
args_def_fn, \ backend, \
meta_kernel_fn, \ context, \
layout, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
// clang-format off // clang-format off
...@@ -585,6 +737,9 @@ struct KernelRegistrar { ...@@ -585,6 +737,9 @@ struct KernelRegistrar {
layout, \ layout, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
...) \ ...) \
PD_EXPAND(PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N) ( \ PD_EXPAND(PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N) ( \
reg_type, \ reg_type, \
...@@ -595,11 +750,14 @@ struct KernelRegistrar { ...@@ -595,11 +750,14 @@ struct KernelRegistrar {
PD_ID, \ PD_ID, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
// clang-format on // clang-format on
#define _PD_KERNEL_REGISTRAR_INIT_1(reg_type, \ #define _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
...@@ -607,453 +765,564 @@ struct KernelRegistrar { ...@@ -607,453 +765,564 @@ struct KernelRegistrar {
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
cpp_dtype) \ cpp_dtype) \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ static const ::phi::KernelRegistrar PD_CONCATENATE( \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \
reg_type, \ reg_type, \
#kernel_name, \ #kernel_name, \
#backend, \ #backend, \
DATALAYOUT(layout), \ DATA_LAYOUT(layout), \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ ::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \
::phi::KernelArgsParseFunctor< \ arg_parse_functor_macro(meta_kernel_fn, cpp_dtype, context), \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ kernel_unfold_macro(meta_kernel_fn<cpp_dtype, context>), \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ variadic_kernel_unfold_marco(meta_kernel_fn<cpp_dtype, context>));
#define _PD_KERNEL_REGISTRAR_INIT_1(reg_type, \
kernel_name, \
backend, \
context, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
cpp_dtype) \
_PD_CREATE_REGISTRAR_OBJECT(reg_type, \
kernel_name, \
backend, \
context, \
layout, \
registrar_id, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
cpp_dtype) \
int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; } int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; }
#define _PD_KERNEL_REGISTRAR_INIT_2(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_2(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_3(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_3(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_4(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_4(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_5(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_5(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_6(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_6(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_7(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_7(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_8(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_8(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_9(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_9(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_10(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_10(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_11(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_11(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_12(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_12(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_13(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_13(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_14(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_14(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
#define _PD_KERNEL_REGISTRAR_INIT_15(reg_type, \ #define _PD_KERNEL_REGISTRAR_INIT_15(reg_type, \
kernel_name, \ kernel_name, \
backend, \ backend, \
context, \ context, \
layout, \ layout, \
registrar_id, \ registrar_id, \
args_def_fn, \ args_def_fn, \
meta_kernel_fn, \ meta_kernel_fn, \
cpp_dtype, \ arg_parse_functor_macro, \
...) \ kernel_unfold_macro, \
static const ::phi::KernelRegistrar PD_CONCATENATE( \ variadic_kernel_unfold_marco, \
__reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ cpp_dtype, \
reg_type, \ ...) \
#kernel_name, \ _PD_CREATE_REGISTRAR_OBJECT(reg_type, \
#backend, \ kernel_name, \
DATALAYOUT(layout), \ backend, \
::paddle::experimental::CppTypeToDataType<cpp_dtype>::Type(), \ context, \
::phi::KernelArgsParseFunctor< \ layout, \
decltype(&meta_kernel_fn<cpp_dtype, context>)>::Parse, \ registrar_id, \
args_def_fn, \ args_def_fn, \
PHI_KERNEL(meta_kernel_fn<cpp_dtype, context>), \ meta_kernel_fn, \
PHI_VARIADIC_KERNEL(meta_kernel_fn<cpp_dtype, context>)); \ arg_parse_functor_macro, \
PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(reg_type, \ kernel_unfold_macro, \
kernel_name, \ variadic_kernel_unfold_marco, \
backend, \ cpp_dtype) \
context, \ PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(reg_type, \
layout, \ kernel_name, \
PD_ID, \ backend, \
args_def_fn, \ context, \
meta_kernel_fn, \ layout, \
PD_ID, \
args_def_fn, \
meta_kernel_fn, \
arg_parse_functor_macro, \
kernel_unfold_macro, \
variadic_kernel_unfold_marco, \
__VA_ARGS__)) __VA_ARGS__))
/** PD_REGISTER_GENERAL_KERNEL /** PD_REGISTER_GENERAL_KERNEL
* *
...@@ -1085,7 +1354,7 @@ struct KernelRegistrar { ...@@ -1085,7 +1354,7 @@ struct KernelRegistrar {
reg_type, \ reg_type, \
#kernel_name, \ #kernel_name, \
#backend, \ #backend, \
DATALAYOUT(layout), \ DATA_LAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \ ::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
PHI_KERNEL(kernel_fn), \ PHI_KERNEL(kernel_fn), \
...@@ -1105,7 +1374,7 @@ struct KernelRegistrar { ...@@ -1105,7 +1374,7 @@ struct KernelRegistrar {
reg_type, \ reg_type, \
#kernel_name, \ #kernel_name, \
#backend, \ #backend, \
DATALAYOUT(layout), \ DATA_LAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \ ::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
&__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \
PHI_KERNEL(kernel_fn), \ PHI_KERNEL(kernel_fn), \
...@@ -1144,6 +1413,10 @@ struct KernelRegistrar { ...@@ -1144,6 +1413,10 @@ struct KernelRegistrar {
::phi::backend##Context, \ ::phi::backend##Context, \
layout, \ layout, \
meta_kernel_fn, \ meta_kernel_fn, \
FUNCTION_KERNEL_INSTANTIATION, \
ARG_PARSE_FUNCTOR, \
PHI_KERNEL, \
PHI_VARIADIC_KERNEL, \
__VA_ARGS__) __VA_ARGS__)
/** PD_REGISTER_PLUGIN_KERNEL /** PD_REGISTER_PLUGIN_KERNEL
...@@ -1159,6 +1432,10 @@ struct KernelRegistrar { ...@@ -1159,6 +1432,10 @@ struct KernelRegistrar {
::phi::CustomContext, \ ::phi::CustomContext, \
layout, \ layout, \
meta_kernel_fn, \ meta_kernel_fn, \
FUNCTION_KERNEL_INSTANTIATION, \
ARG_PARSE_FUNCTOR, \
PHI_KERNEL, \
PHI_VARIADIC_KERNEL, \
__VA_ARGS__) __VA_ARGS__)
} // namespace phi } // namespace phi
...@@ -35,4 +35,6 @@ KernelSignature SaveCombineOpArgumentMapping( ...@@ -35,4 +35,6 @@ KernelSignature SaveCombineOpArgumentMapping(
} // namespace phi } // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(save_combine, save_combine_tensor);
PD_REGISTER_ARG_MAPPING_FN(save_combine, phi::SaveCombineOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(save_combine, phi::SaveCombineOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册