From e92e3aabe81533f57215552a62ea5fedf2ab445c Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Wed, 8 Feb 2023 10:32:45 +0800 Subject: [PATCH] [PHI]Unify Fluid and PHI kernel (#49328) * unify_kernel * fix compile bugs * modify macro name * perfect code according comment * fix compile bugs * fix compile bugs * fix ci bugs * fix ci bug * fix ci bugs * fix ci bugs * modify code according comment * rm conv_fusion_op --- cmake/operators.cmake | 6 +- .../interpreter/interpreter_util.cc | 10 +- paddle/fluid/framework/op_registry.h | 32 + paddle/fluid/framework/operator.cc | 37 +- paddle/fluid/framework/operator.h | 3 +- paddle/fluid/imperative/prepared_operator.cc | 48 +- paddle/fluid/imperative/tracer.cc | 8 +- .../custom_device_common_op_registry.cc | 2 +- .../fluid/operators/fused/conv_fusion_op.cu | 575 ------- paddle/fluid/operators/rank_loss_op.cc | 21 +- paddle/fluid/operators/rank_loss_op.h | 4 +- paddle/phi/core/compat/arg_map_context.h | 2 + paddle/phi/core/kernel_factory.cc | 15 + paddle/phi/core/kernel_factory.h | 17 +- paddle/phi/core/kernel_registry.h | 1371 ++++++++++------- paddle/phi/ops/compat/save_combine_sig.cc | 2 + 16 files changed, 982 insertions(+), 1171 deletions(-) delete mode 100644 paddle/fluid/operators/fused/conv_fusion_op.cu diff --git a/cmake/operators.cmake b/cmake/operators.cmake index fcdaa8da1f..0b5e886922 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -28,11 +28,7 @@ endfunction() function(find_phi_register FILENAME ADD_PATH PATTERN) # set op_name to OUTPUT - set(options "") - set(oneValueArgs "") - set(multiValueArgs "") file(READ ${FILENAME} CONTENT) - string( REGEX MATCH @@ -402,6 +398,7 @@ function(op_library TARGET) set(op_name "") # Add PHI Kernel Registry Message find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_KERNEL") + find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL") find_phi_register(${cc_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL") find_register(${cc_src} "REGISTER_OPERATOR" op_name) if(NOT ${op_name} EQUAL "") @@ -453,6 +450,7 @@ function(op_library TARGET) set(op_name "") # Add PHI Kernel Registry Message find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_KERNEL") + find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_STRUCT_KERNEL") find_phi_register(${cu_src} ${pybind_file} "PD_REGISTER_GENERAL_KERNEL") find_register(${cu_src} "REGISTER_OP_CUDA_KERNEL" op_name) if(NOT ${op_name} EQUAL "") diff --git a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc index 507f0302c5..79c763da55 100644 --- a/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc +++ b/paddle/fluid/framework/new_executor/interpreter/interpreter_util.cc @@ -827,7 +827,9 @@ bool BuildOpFuncList(const platform::Place& place, } // step 5. run kernel - if (run_phi_kernel) { + if (run_phi_kernel && + op_func_node.phi_kernel_->GetKernelRegisteredType() == + phi::KernelRegisteredType::FUNCTION) { phi::KernelContext phi_kernel_context; op_with_kernel->BuildPhiKernelContext( runtime_context, dev_ctx, &phi_kernel_context); @@ -838,6 +840,12 @@ bool BuildOpFuncList(const platform::Place& place, op_with_kernel->PhiKernelSignature(), &phi_kernel_context); } + } else if (run_phi_kernel && + op_func_node.phi_kernel_->GetKernelRegisteredType() == + phi::KernelRegisteredType::STRUCTURE) { + ExecutionContext execution_context( + *op_with_kernel, *runtime_scope, *dev_ctx, runtime_context); + (*op_func_node.phi_kernel_)(&execution_context); } else { // the place of exec_ctx maybe has changed. if (!skip_run) { diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index cfb9892e55..0fe1c2abea 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -34,6 +34,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/shape_inference.h" +#include "paddle/phi/core/kernel_registry.h" namespace paddle { namespace framework { @@ -484,5 +485,36 @@ struct OpKernelRegistrarFunctorEx +struct StructKernelImpl { + static void Compute(phi::KernelContext* ctx) { + auto exe_ctx = static_cast(ctx); + StructureKernel().Compute(*exe_ctx); + } +}; + +#define PHI_STRUCTURE_KERNEL(...) \ + ::paddle::framework::StructKernelImpl<__VA_ARGS__>::Compute +#define PHI_STRUCTURE_VARIADIC_KERNEL(...) nullptr +#define STRUCTURE_ARG_PARSE_FUNCTOR(...) nullptr + +#define STRUCTURE_KERNEL_INSTANTIATION( \ + meta_kernel_structure, cpp_dtype, context) \ + template class meta_kernel_structure; + +#define PD_REGISTER_STRUCT_KERNEL( \ + kernel_name, backend, layout, meta_kernel_structure, ...) \ + _PD_REGISTER_KERNEL(::phi::RegType::INNER, \ + kernel_name, \ + backend, \ + ::phi::backend##Context, \ + layout, \ + meta_kernel_structure, \ + STRUCTURE_KERNEL_INSTANTIATION, \ + STRUCTURE_ARG_PARSE_FUNCTOR, \ + PHI_STRUCTURE_KERNEL, \ + PHI_STRUCTURE_VARIADIC_KERNEL, \ + __VA_ARGS__) + } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index fe863381b5..76fe54dc62 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1689,15 +1689,18 @@ void OperatorWithKernel::RunImpl(const Scope& scope, std::string phi_kernel_name; if (phi::KernelFactory::Instance().HasCompatiblePhiKernel(type_)) { if (kernel_signature_ == nullptr || phi_kernel_ == nullptr) { - kernel_signature_.reset(new phi::KernelSignature( - std::move(GetExpectedPhiKernelArgs(exe_ctx)))); - VLOG(6) << *kernel_signature_.get(); + if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) { + kernel_signature_.reset(new phi::KernelSignature(type_.c_str())); + } else { + kernel_signature_.reset(new phi::KernelSignature( + std::move(GetExpectedPhiKernelArgs(exe_ctx)))); + } + VLOG(6) << *kernel_signature_.get(); + phi_kernel_name = kernel_signature_->name; kernel_type_.reset( new OpKernelType(std::move(InnerGetExpectedKernelType(exe_ctx)))); dev_ctx = pool.Get(kernel_type_->place_); - - phi_kernel_name = kernel_signature_->name; // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // But the default library_type is Plain, so we need to modify the // library_type here, otherwise it can't work. @@ -1753,7 +1756,6 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } } else { phi_kernel_name = kernel_signature_->name; - // NOTE(jiahongyu): The registered MKLDNN kernel have library_type = // LibraryType::kMKLDNN and data_layout_ = DataLayout::ONEDNN. But the default // values are kPlain, so we need to modify the library_type and data_layout_ @@ -1939,7 +1941,8 @@ void OperatorWithKernel::RunImpl(const Scope& scope, platform::TracerEventType::OperatorInner, 1, platform::EventRole::kInnerOp); - if (run_phi_kernel_) { + if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() == + phi::KernelRegisteredType::FUNCTION) { phi::KernelContext phi_kernel_context; if (enable_cache_runtime_context_ && !need_prepare_phi_data_ && !need_prepare_data_) { @@ -1977,6 +1980,11 @@ void OperatorWithKernel::RunImpl(const Scope& scope, BuildPhiKernelContext(*runtime_ctx, dev_ctx, &phi_kernel_context); (*phi_kernel_)(&phi_kernel_context); } + } else if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() == + phi::KernelRegisteredType::STRUCTURE) { + ExecutionContext execution_context( + *this, exec_scope, *dev_ctx, *runtime_ctx); + (*phi_kernel_)(&execution_context); } else { (*kernel_func_)( ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx)); @@ -2147,14 +2155,18 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( phi::KernelKey OperatorWithKernel::ChoosePhiKernel( const ExecutionContext& ctx) const { - kernel_signature_.reset( - new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx)))); + std::string phi_kernel_name; + if (phi::KernelFactory::Instance().HasStructuredKernel(type_)) { + kernel_signature_.reset(new phi::KernelSignature(type_.c_str())); + } else { + kernel_signature_.reset( + new phi::KernelSignature(std::move(GetExpectedPhiKernelArgs(ctx)))); + } VLOG(6) << *kernel_signature_.get(); - + phi_kernel_name = kernel_signature_->name; kernel_type_.reset( new OpKernelType(std::move(InnerGetExpectedKernelType(ctx)))); - auto phi_kernel_name = kernel_signature_->name; auto phi_kernel_key = TransOpKernelTypeToPhiKernelKey(*kernel_type_.get()); phi_kernel_.reset(new phi::Kernel(phi::KernelFactory::Instance().SelectKernel( phi_kernel_name, phi_kernel_key))); @@ -2616,7 +2628,8 @@ Scope* OperatorWithKernel::PrepareData( } }; - if (run_phi_kernel_) { + if (run_phi_kernel_ && phi_kernel_->GetKernelRegisteredType() == + phi::KernelRegisteredType::FUNCTION) { const auto& input_names = kernel_signature_->input_names; const auto& input_defs = phi_kernel_->args_def().input_defs(); PADDLE_ENFORCE_EQ(input_names.size(), diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 955f30f340..f2de07db96 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -41,6 +41,7 @@ limitations under the License. */ #include "paddle/phi/core/compat/arg_map_context.h" #include "paddle/phi/core/compat/op_utils.h" +#include "paddle/phi/core/kernel_context.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/utils/flat_hash_map.h" @@ -290,7 +291,7 @@ class OperatorBase { const platform::Place& place) const = 0; }; -class ExecutionContext { +class ExecutionContext : public phi::KernelContext { public: ExecutionContext(const OperatorBase& op, const Scope& scope, diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 32a4515624..df315ba97e 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -273,17 +273,23 @@ PreparedOp PrepareImpl( kernel_signature = (*arg_map_fn)( framework::ExecutionArgumentMappingContext(dygraph_exe_ctx)); } else { - default_kernel_signature = - default_phi_kernel_sig_map.GetNullable(op.Type()); - if (default_kernel_signature) { + if (phi::KernelFactory::Instance().HasStructuredKernel(op.Type())) { has_phi_kernel = true; - kernel_signature = *default_kernel_signature; + kernel_signature = phi::KernelSignature(op.Type().c_str()); + } else { + default_kernel_signature = + default_phi_kernel_sig_map.GetNullable(op.Type()); + if (default_kernel_signature) { + has_phi_kernel = true; + kernel_signature = *default_kernel_signature; + } } } if (has_phi_kernel) { VLOG(6) << kernel_signature; phi_kernel_name = kernel_signature.name; + // NOTE(Liu-xiandong): The register kernel used KP have library_type[KP], // But the default library_type is Plain, so we need to modify the // library_type here, otherwise it can't work. @@ -648,6 +654,7 @@ static void PreparedOpRunPtImpl( const phi::KernelSignature* default_kernel_signature, const phi::KernelSignature& kernel_signature, const phi::Kernel& phi_kernel, + const framework::RuntimeContext& ctx, platform::DeviceContext* dev_ctx, const NameVarMap& ins, const NameVarMap& outs, @@ -678,19 +685,25 @@ static void PreparedOpRunPtImpl( 1, platform::EventRole::kInnerOp); - PreparePhiData(phi_kernel, kernel_signature, ins); - - phi::KernelContext phi_kernel_context; - BuildDygraphPhiKernelContext(kernel_signature, - phi_kernel, - ins, - outs, - attrs, - default_attrs, - dev_ctx, - &phi_kernel_context); + if (phi_kernel.GetKernelRegisteredType() == + phi::KernelRegisteredType::FUNCTION) { + PreparePhiData(phi_kernel, kernel_signature, ins); + phi::KernelContext phi_kernel_context; + BuildDygraphPhiKernelContext(kernel_signature, + phi_kernel, + ins, + outs, + attrs, + default_attrs, + dev_ctx, + &phi_kernel_context); - phi_kernel(&phi_kernel_context); + phi_kernel(&phi_kernel_context); + } else { + DygraphExecutionContext exe_ctx( + op, empty_scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs); + phi_kernel(&exe_ctx); + } } if (FLAGS_check_nan_inf) { @@ -722,6 +735,7 @@ void PreparedOp::Run(const NameVarMap& ins, default_kernel_signature_, kernel_signature_, phi_kernel_, + ctx_, dev_ctx_, ins, outs, @@ -753,6 +767,7 @@ void PreparedOp::Run(const NameVarMap& ins, default_kernel_signature_, kernel_signature_, phi_kernel_, + ctx_, dev_ctx_, ins, outs, @@ -784,6 +799,7 @@ void PreparedOp::Run(const NameVarMap& ins, default_kernel_signature_, kernel_signature_, phi_kernel_, + ctx_, dev_ctx_, ins, outs, diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 08f73c51fe..bfe6cea6e0 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -530,8 +530,12 @@ phi::KernelSignature Tracer::GetExpectedKernelSignature( "This op type:`%s` is not a OperatorWithKernel, only " "OperatorWithKernel can get KernelSignature", type)); - return phi::KernelSignature( - std::move(opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx))); + if (phi::KernelFactory::Instance().HasStructuredKernel(type)) { + return phi::KernelSignature(op->Type().c_str()); + } else { + return phi::KernelSignature(std::move( + opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx))); + } } } // namespace imperative diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc index bbb75d4183..b921fd8216 100644 --- a/paddle/fluid/operators/custom_device_common_op_registry.cc +++ b/paddle/fluid/operators/custom_device_common_op_registry.cc @@ -34,7 +34,7 @@ limitations under the License. */ phi::RegType::INNER, \ #kernel_name, \ dev_type, \ - DATALAYOUT(layout), \ + DATA_LAYOUT(layout), \ ::phi::KernelArgsParseFunctor::Parse, \ [](const phi::KernelKey& kernel_key, phi::Kernel* kernel) {}, \ PHI_KERNEL(kernel_fn), \ diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu deleted file mode 100644 index dee0c1837a..0000000000 --- a/paddle/fluid/operators/fused/conv_fusion_op.cu +++ /dev/null @@ -1,575 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/conv_search_cache.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/conv_op.h" -#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" -#include "paddle/phi/kernels/funcs/padding.h" -#include "paddle/phi/kernels/gpudnn/conv_gpudnn_info.h" - -DECLARE_int64(cudnn_exhaustive_search_times); - -namespace paddle { -namespace operators { - -#if PADDLE_WITH_HIP || CUDNN_VERSION >= 7100 -using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; -using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; -using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; -using ScopedActivationDescriptor = platform::ScopedActivationDescriptor; -using DataLayout = platform::DataLayout; -using framework::AlgorithmsCache; -using framework::ConvSearchCache; -using framework::SearchFuseResult; - -template -using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; - -template -class CUDNNConvFusionOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto& dev_ctx = ctx.template device_context(); - auto* input = ctx.Input("Input"); - auto* filter = ctx.Input("Filter"); - auto* bias = ctx.Input("Bias"); - auto* residual = ctx.Input("ResidualData"); - auto* output = ctx.Output("Output"); - dev_ctx.template Alloc(output, output->numel() * sizeof(T)); - - std::vector strides = ctx.Attr>("strides"); - std::vector paddings = ctx.Attr>("paddings"); - std::vector dilations = ctx.Attr>("dilations"); - const std::string activation = ctx.Attr("activation"); - std::string data_format = ctx.Attr("data_format"); - PADDLE_ENFORCE_NE( - data_format, - "NHWC", - platform::errors::PermissionDenied( - "Operator(Conv2DFusion) in cuDNN only supports data format of " - "channel first (NCHW) now. But received: data_format = '%s'.", - data_format)); - - int groups = ctx.Attr("groups"); - int64_t user_workspace_size = - static_cast(ctx.Attr("workspace_size_MB")); - bool exhaustive_search = - FLAGS_cudnn_exhaustive_search || ctx.Attr("exhaustive_search"); - - const T* filter_data = filter->data(); - const T* bias_data = bias->data(); - - const std::string padding_algorithm = - ctx.Attr("padding_algorithm"); - - phi::DenseTensor transformed_input_channel(input->dtype()); - phi::DenseTensor transformed_output(output->dtype()); - transformed_input_channel = *input; - transformed_output = *output; - T* output_data = transformed_output.data(); - - const T* residual_data = residual ? residual->data() : output_data; - - // update padding and dilation - auto in_dims = transformed_input_channel.dims(); - auto filter_dims = filter->dims(); - framework::DDim in_data_dims = phi::slice_ddim(in_dims, 2, in_dims.size()); - - framework::DDim filter_data_dims = - phi::slice_ddim(filter_dims, 2, filter_dims.size()); - std::vector ksize = phi::vectorize(filter_data_dims); - UpdatePaddingAndDilation( - &paddings, &dilations, padding_algorithm, in_data_dims, strides, ksize); - - int data_dim = strides.size(); // 2d or 3d - bool is_sys_pad = phi::funcs::IsSymmetricPadding(paddings, data_dim); - - phi::DenseTensor transformed_input; - std::vector padding_common(data_dim, 0); - if (!is_sys_pad) { - std::vector padding_diff(data_dim); - std::vector new_input_shape_vec(data_dim + 2); - new_input_shape_vec[0] = transformed_input_channel.dims()[0]; - new_input_shape_vec[1] = transformed_input_channel.dims()[1]; - - std::vector input_pad(transformed_input_channel.dims().size() * 2, - 0); - for (size_t i = 0; i < data_dim; ++i) { - padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]); - padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]); - new_input_shape_vec[i + 2] = - transformed_input_channel.dims()[i + 2] + padding_diff[i]; - input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i]; - input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i]; - } - framework::DDim new_input_shape(phi::make_ddim(new_input_shape_vec)); - transformed_input.Resize(new_input_shape); - auto& dev_ctx = ctx.template device_context(); - - transformed_input = - ctx.AllocateTmpTensor(new_input_shape, dev_ctx); - const int rank = transformed_input_channel.dims().size(); - T pad_value(0.0); - switch (rank) { - case 4: { - phi::funcs::PadFunction( - dev_ctx, - input_pad, - transformed_input_channel, - pad_value, - &transformed_input); - } break; - case 5: { - phi::funcs::PadFunction( - dev_ctx, - input_pad, - transformed_input_channel, - pad_value, - &transformed_input); - } break; - default: - PADDLE_THROW(platform::errors::PermissionDenied( - "Operator Conv2DFusion expects Input to be a 4-D or 5-D " - "phi::DenseTensor. " - "But received the actual dimension = %d, shape = [%s].", - rank, - transformed_input_channel.dims())); - } - - } else { - transformed_input = transformed_input_channel; - if (paddings.size() == data_dim) { - for (size_t i = 0; i < data_dim; ++i) { - padding_common[i] = paddings[i]; - } - } else { - for (size_t i = 0; i < data_dim; ++i) { - padding_common[i] = paddings[2 * i]; - } - } - } - - const T* input_data = transformed_input.data(); - - // ------------------- cudnn descriptors --------------------- - ScopedTensorDescriptor input_desc; - ScopedTensorDescriptor output_desc; - ScopedFilterDescriptor filter_desc; - ScopedTensorDescriptor bias_desc; - ScopedConvolutionDescriptor conv_desc; - ScopedActivationDescriptor act_desc; - DataLayout layout = DataLayout::kNCHW; - if (input->dims().size() == 5) { - layout = DataLayout::kNCDHW; - } -#ifdef PADDLE_WITH_HIP - miopenConvolutionDescriptor_t cudnn_conv_desc = - conv_desc.descriptor(padding_common, strides, dilations); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::miopenSetConvolutionGroupCount(cudnn_conv_desc, - groups)); - // Now only support NCHW - std::vector bias_dim = { - 1, static_cast(transformed_output.dims()[1]), 1, 1}; - miopenTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( - layout, phi::vectorize(transformed_input.dims())); - miopenTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( - layout, phi::vectorize(transformed_output.dims())); - miopenTensorDescriptor_t cudnn_filter_desc = - filter_desc.descriptor(layout, phi::vectorize(filter->dims())); - miopenTensorDescriptor_t cudnn_bias_desc = - bias_desc.descriptor(layout, bias_dim); - miopenActivationDescriptor_t cudnn_act_desc = - act_desc.descriptor(activation); - - miopenConvFwdAlgorithm_t algo; - auto handle = dev_ctx.cudnn_handle(); - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - - auto x_dims = phi::vectorize(transformed_input.dims()); - auto f_dims = phi::vectorize(filter->dims()); - - size_t workspace_size = 0; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::miopenConvolutionForwardGetWorkSpaceSize( - handle, - cudnn_filter_desc, - cudnn_input_desc, - cudnn_conv_desc, - cudnn_output_desc, - &workspace_size)); - int find_count; - miopenConvAlgoPerf_t find_result; - auto cudnn_find_func = [&](void* cudnn_workspace_ptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::miopenFindConvolutionForwardAlgorithm( - handle, - cudnn_input_desc, - input_data, - cudnn_filter_desc, - filter_data, - cudnn_conv_desc, - cudnn_output_desc, - output_data, - phi::kNUM_CUDNN_FWD_ALGS, - &find_count, - &find_result, - cudnn_workspace_ptr, - workspace_size, - false)); - }; - workspace_handle.RunFuncSync(cudnn_find_func, workspace_size); - algo = find_result.fwd_algo; - VLOG(3) << "cuDNN forward algo " << algo; - - { - ScalingParamType alpha = 1.0f, beta = 0.0f; - auto cudnn_func = [&](void* cudnn_workspace) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::miopenConvolutionForward(handle, - &alpha, - cudnn_input_desc, - input_data, - cudnn_filter_desc, - filter_data, - cudnn_conv_desc, - algo, - &beta, - cudnn_output_desc, - output_data, - cudnn_workspace, - workspace_size)); - }; - workspace_handle.RunFunc(cudnn_func, workspace_size); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::miopenConvolutionForwardBias(handle, - &alpha, - cudnn_bias_desc, - bias_data, - &beta, - cudnn_output_desc, - output_data)); - if (activation != "identity") { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::miopenActivationForward(handle, - cudnn_act_desc, - &alpha, - cudnn_output_desc, - output_data, - &beta, - cudnn_output_desc, - output_data)); - } - if (residual) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::miopenOpTensor(handle, - miopenTensorOpAdd, - &alpha, - cudnn_output_desc, - output_data, - &alpha, - cudnn_output_desc, - residual_data, - &beta, - cudnn_output_desc, - output_data)); - } - } -#else // PADDLE_WITH_HIP - cudnnConvolutionDescriptor_t cudnn_conv_desc = - conv_desc.descriptor(padding_common, strides, dilations); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionGroupCount( - cudnn_conv_desc, groups)); - - cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor( - layout, phi::vectorize(transformed_input.dims())); - cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor( - layout, phi::vectorize(transformed_output.dims())); - cudnnFilterDescriptor_t cudnn_filter_desc = - filter_desc.descriptor(layout, phi::vectorize(filter->dims())); - // Now only support NCHW - std::vector bias_dim = { - 1, static_cast(transformed_output.dims()[1]), 1, 1}; - cudnnTensorDescriptor_t cudnn_bias_desc = - bias_desc.descriptor(layout, bias_dim); - cudnnActivationDescriptor_t cudnn_act_desc = - act_desc.descriptor(activation); - - // ------------------- cudnn conv workspace --------------------- - size_t workspace_size_in_bytes; // final workspace to allocate. - size_t workspace_size_limit = 0; - if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) { - int64_t max_user_size = - std::min(static_cast(FLAGS_conv_workspace_size_limit), - user_workspace_size); - workspace_size_limit = max_user_size * 1024 * 1024; - } - - // ------------------- cudnn conv algorithm --------------------- - cudnnConvolutionFwdAlgo_t algo; - auto handle = dev_ctx.cudnn_handle(); - auto workspace_handle = dev_ctx.cudnn_workspace_handle(); - auto dtype = platform::CudnnDataType::type; - - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( - cudnn_conv_desc, CUDNN_DEFAULT_MATH)); - if (dtype == CUDNN_DATA_HALF) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( - cudnn_conv_desc, CUDNN_TENSOR_OP_MATH)); - } -#if CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000 - if (!platform::allow_tf32_cudnn) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetConvolutionMathType( - cudnn_conv_desc, CUDNN_FMA_MATH)); - } -#endif // CUDA_VERSION >= 11000 && CUDNN_VERSION >= 8000 - - auto x_dims = phi::vectorize(transformed_input.dims()); - auto f_dims = phi::vectorize(filter->dims()); - if (!exhaustive_search) { -#if CUDNN_VERSION >= 8000 - int perf_count; - int best_algo_idx = 0; - size_t tmp_size = 0; - std::unique_ptr perf_results( - new cudnnConvolutionFwdAlgoPerf_t[phi::kNUM_CUDNN_FWD_ALGS]); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionForwardAlgorithm_v7( - handle, - cudnn_input_desc, - cudnn_filter_desc, - cudnn_conv_desc, - cudnn_output_desc, - phi::kNUM_CUDNN_FWD_ALGS, - &perf_count, - perf_results.get())); - algo = (perf_results.get())[best_algo_idx].algo; -#else - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionForwardAlgorithm( - handle, - cudnn_input_desc, - cudnn_filter_desc, - cudnn_conv_desc, - cudnn_output_desc, - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_size_limit, - &algo)); -#endif - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( - handle, - cudnn_input_desc, - cudnn_filter_desc, - cudnn_conv_desc, - cudnn_output_desc, - algo, - &workspace_size_in_bytes)); - if (workspace_size_in_bytes > workspace_size_limit) - workspace_size_limit = workspace_size_in_bytes; - VLOG(3) << "cuDNN forward algo " << algo; - } else { - std::function()> search_func = - [&]() -> SearchFuseResult { - int returned_algo_count; - SearchFuseResult fwd_result; - std::array - fwd_perf_stat; - auto cudnn_find_func = [&](void* cudnn_workspace) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( - handle, - cudnn_input_desc, - input_data, - cudnn_filter_desc, - filter_data, - cudnn_conv_desc, - cudnn_output_desc, - output_data, - phi::kNUM_CUDNN_FWD_ALGS, - &returned_algo_count, - fwd_perf_stat.data(), - cudnn_workspace, - workspace_size_limit)); - }; - workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); - VLOG(3) << "Perf result: (algo: stat, time, memory)"; - for (int i = 0; i < returned_algo_count; ++i) { - const auto& stat = fwd_perf_stat[i]; - VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time << " " - << stat.memory; - } - - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( - handle, - cudnn_input_desc, - cudnn_filter_desc, - cudnn_conv_desc, - cudnn_output_desc, - fwd_perf_stat[0].algo, - &workspace_size_in_bytes)); - // PADDLE_ENFORCE_LE( - // workspace_size_in_bytes, - // workspace_size_limit, - // platform::errors::InvalidArgument( - // "The actual workspace size to be allocated for cuDNN is - // expected " "to be less than the limit. But received: the - // actual workspace " "size = %d, limit = %d.", - // workspace_size_in_bytes, - // workspace_size_limit)); - - fwd_result.algo = fwd_perf_stat[0].algo; - fwd_result.workspace_size = workspace_size_in_bytes; - return fwd_result; - }; - AlgorithmsCache>& algo_cache = - *(framework::ConvSearchCache::Instance().GetConvFusion()); - int search_times = ctx.Attr("search_times"); - SearchFuseResult algo_result; - search_times = std::max( - static_cast(FLAGS_cudnn_exhaustive_search_times), search_times); - // TODO(dangqingqing): Unify this if-else. - if (search_times > 0) { - // The searched algo will be cached by `search_times` times for - // different input dimension. For other dimensions, select the algo - // of closest area. - algo_result = algo_cache.GetAlgorithm( - x_dims[2] * x_dims[3], search_times, 0, search_func); - algo = algo_result.algo; - workspace_size_in_bytes = algo_result.workspace_size; - } else { - algo_result = algo_cache.GetAlgorithm(x_dims, - f_dims, - strides, - paddings, - dilations, - 0, - dtype, - search_func); - algo = algo_result.algo; - workspace_size_in_bytes = algo_result.workspace_size; - } - VLOG(3) << "choose algo " << algo; - } - if ((activation == "identity") && (!residual)) { - // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is - // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib. - // But test in some case, the speed is slower, change to use - // cudnnConvolutionForward and cudnnAddTensor - // ------------- cudnn conv forward and bias add --------------------- - ScalingParamType alpha = 1.0f, beta = 0.0f; - auto cudnn_func = [&](void* cudnn_workspace) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnConvolutionForward(handle, - &alpha, - cudnn_input_desc, - input_data, - cudnn_filter_desc, - filter_data, - cudnn_conv_desc, - algo, - cudnn_workspace, - workspace_size_in_bytes, - &beta, - cudnn_output_desc, - output_data)); - }; - workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnAddTensor(handle, - &alpha, - cudnn_bias_desc, - bias_data, - &alpha, - cudnn_output_desc, - output_data)); - } else { - if (activation == "identity") { - algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; - } - // ------------------- cudnn conv+bias+act forward -------------------- - ScalingParamType alpha1 = 1.0f; - ScalingParamType alpha2 = residual ? 1.0f : 0.0f; - auto cudnn_func = [&](void* cudnn_workspace) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cudnnConvolutionBiasActivationForward( - handle, - &alpha1, - cudnn_input_desc, - input_data, - cudnn_filter_desc, - filter_data, - cudnn_conv_desc, - algo, - cudnn_workspace, - workspace_size_in_bytes, - &alpha2, - cudnn_output_desc, - residual_data, - cudnn_bias_desc, - bias_data, - cudnn_act_desc, - cudnn_output_desc, - output_data)); - }; - workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); - } -#endif - std::vector channels = ctx.Attr>("split_channels"); - if (channels.size()) { - auto outs = ctx.MultiOutput("Outputs"); - if (x_dims[0] == 1) { - // share data with Output - phi::DenseTensor t; - t.ShareDataWith(*output); - auto y_dims = output->dims(); - t.Resize({y_dims[1], y_dims[2], y_dims[3]}); - int s = 0; - for (size_t i = 0; i < channels.size(); ++i) { - int e = s + channels[i]; - outs[i]->ShareDataWith(t.Slice(s, e)); - outs[i]->Resize({x_dims[0], channels[i], y_dims[2], y_dims[3]}); - s = e; - } - } else { - // TODO(qingiqng): do copy when batch size large than 1 - PADDLE_THROW(platform::errors::Unimplemented( - "Input with batch size greater than 1 is unsupported. The received " - "batch size is %d, Input's shape is [%s].", - x_dims[0], - phi::make_ddim(x_dims))); - } - } - } -}; -#endif - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -#if CUDNN_VERSION >= 7100 -REGISTER_OP_CUDA_KERNEL( - conv2d_fusion, - ops::CUDNNConvFusionOpKernel, - ops::CUDNNConvFusionOpKernel, - ops::CUDNNConvFusionOpKernel); -#endif -#ifdef PADDLE_WITH_HIP -REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel); -#endif diff --git a/paddle/fluid/operators/rank_loss_op.cc b/paddle/fluid/operators/rank_loss_op.cc index 2daf8c5d6b..ebdddfd41b 100644 --- a/paddle/fluid/operators/rank_loss_op.cc +++ b/paddle/fluid/operators/rank_loss_op.cc @@ -240,12 +240,15 @@ REGISTER_OPERATOR(rank_loss, ops::RankLossGradMaker, ops::RankLossGradMaker); REGISTER_OPERATOR(rank_loss_grad, ops::RankLossGradOp); -REGISTER_OP_CPU_KERNEL(rank_loss, ops::RankLossKernel); -REGISTER_OP_CPU_KERNEL(rank_loss_grad, - ops::RankLossGradKernel); - -REGISTER_OP_CUDA_KERNEL( - rank_loss, paddle::operators::RankLossKernel); -REGISTER_OP_CUDA_KERNEL( - rank_loss_grad, - paddle::operators::RankLossGradKernel); + +PD_REGISTER_STRUCT_KERNEL( + rank_loss, CPU, ALL_LAYOUT, ops::RankLossKernel, float) {} +PD_REGISTER_STRUCT_KERNEL( + rank_loss_grad, CPU, ALL_LAYOUT, ops::RankLossGradKernel, float) {} + +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +PD_REGISTER_STRUCT_KERNEL( + rank_loss, GPU, ALL_LAYOUT, ops::RankLossKernel, float) {} +PD_REGISTER_STRUCT_KERNEL( + rank_loss_grad, GPU, ALL_LAYOUT, ops::RankLossGradKernel, float) {} +#endif diff --git a/paddle/fluid/operators/rank_loss_op.h b/paddle/fluid/operators/rank_loss_op.h index 4c81129c0e..03e0a09455 100644 --- a/paddle/fluid/operators/rank_loss_op.h +++ b/paddle/fluid/operators/rank_loss_op.h @@ -21,7 +21,7 @@ limitations under the License. */ namespace paddle { namespace operators { -template +template class RankLossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { @@ -42,7 +42,7 @@ class RankLossKernel : public framework::OpKernel { } }; -template +template class RankLossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { diff --git a/paddle/phi/core/compat/arg_map_context.h b/paddle/phi/core/compat/arg_map_context.h index 2dd3ba8b76..1c92a043d8 100644 --- a/paddle/phi/core/compat/arg_map_context.h +++ b/paddle/phi/core/compat/arg_map_context.h @@ -56,6 +56,8 @@ struct KernelSignature { attr_names(attrs), output_names(outputs) {} + explicit KernelSignature(const char* kernel_name) : name(kernel_name) {} + // TODO(chenweihang): add assign constructor to solve windows compile // problem, remove it later KernelSignature(const KernelSignature& other) diff --git a/paddle/phi/core/kernel_factory.cc b/paddle/phi/core/kernel_factory.cc index 7c15d60414..f44bfe6a2e 100644 --- a/paddle/phi/core/kernel_factory.cc +++ b/paddle/phi/core/kernel_factory.cc @@ -62,6 +62,21 @@ bool KernelFactory::HasCompatiblePhiKernel(const std::string& op_type) const { return false; } +bool KernelFactory::HasStructuredKernel(const std::string& op_type) const { + auto phi_kernel_name = phi::OpUtilsMap::Instance().GetBaseKernelName(op_type); + auto kernel_iter = kernels_.find(phi_kernel_name); + if (deprecated_op_names.find(op_type) == deprecated_op_names.end() && + kernel_iter != kernels_.end()) { + return std::any_of(kernel_iter->second.begin(), + kernel_iter->second.end(), + [](phi::KernelKeyMap::const_reference kernel_pair) { + return kernel_pair.second.GetKernelRegisteredType() == + KernelRegisteredType::STRUCTURE; + }); + } + return false; +} + const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name, const KernelKey& kernel_key) const { auto iter = kernels_.find(kernel_name); diff --git a/paddle/phi/core/kernel_factory.h b/paddle/phi/core/kernel_factory.h index 8b8eb8fd0d..0c85f8d49c 100644 --- a/paddle/phi/core/kernel_factory.h +++ b/paddle/phi/core/kernel_factory.h @@ -238,13 +238,21 @@ class KernelArgsDef { {}}; }; +enum class KernelRegisteredType { FUNCTION, STRUCTURE }; + class Kernel { public: // for map element construct Kernel() = default; explicit Kernel(KernelFn fn, void* variadic_fn) - : fn_(fn), variadic_fn_(variadic_fn) {} + : fn_(fn), variadic_fn_(variadic_fn) { + if (variadic_fn == nullptr) { + kernel_registered_type_ = KernelRegisteredType::STRUCTURE; + } else { + kernel_registered_type_ = KernelRegisteredType::FUNCTION; + } + } void operator()(KernelContext* ctx) const { fn_(ctx); } @@ -272,10 +280,15 @@ class Kernel { bool IsValid() const { return fn_ != nullptr; } + KernelRegisteredType GetKernelRegisteredType() const { + return kernel_registered_type_; + } + private: KernelFn fn_{nullptr}; void* variadic_fn_ = nullptr; KernelArgsDef args_def_; + KernelRegisteredType kernel_registered_type_ = KernelRegisteredType::FUNCTION; }; using KernelKeyMap = paddle::flat_hash_map; @@ -304,6 +317,8 @@ class KernelFactory { bool HasCompatiblePhiKernel(const std::string& op_type) const; + bool HasStructuredKernel(const std::string& op_type) const; + KernelResult SelectKernelOrThrowError(const std::string& kernel_name, const KernelKey& kernel_key) const; diff --git a/paddle/phi/core/kernel_registry.h b/paddle/phi/core/kernel_registry.h index dc6f657fee..bacfce613f 100644 --- a/paddle/phi/core/kernel_registry.h +++ b/paddle/phi/core/kernel_registry.h @@ -32,7 +32,7 @@ namespace phi { #define BACKEND(arg__) phi::Backend::arg__ -#define DATALAYOUT(arg__) phi::DataLayout::arg__ +#define DATA_LAYOUT(arg__) phi::DataLayout::arg__ #define DATATYPE(arg__) phi::DataType::arg__ template @@ -348,7 +348,9 @@ struct KernelRegistrar { KernelKey kernel_key( paddle::experimental::StringToBackend(backend_cstr), layout, dtype); Kernel kernel(kernel_fn, variadic_kernel_fn); - args_parse_fn(kernel_key, kernel.mutable_args_def()); + if (kernel.GetKernelRegisteredType() == KernelRegisteredType::FUNCTION) { + args_parse_fn(kernel_key, kernel.mutable_args_def()); + } args_def_fn(kernel_key, &kernel); if (reg_type == RegType::INNER) { KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; @@ -380,6 +382,16 @@ struct KernelRegistrar { #define _PD_ARG_N(args) _PD_ARG_N_EXPAND args #define _PD_RESQ_N() 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 +// The macro for passing KernelArgsParseFunctor's function +#define ARG_PARSE_FUNCTOR(meta_kernel_fn, cpp_dtype, context) \ + ::phi::KernelArgsParseFunctor< \ + decltype(&meta_kernel_fn)>::Parse + +// The macro for instantiating function kernel +#define FUNCTION_KERNEL_INSTANTIATION(meta_kernel_fn, cpp_dtype, context) \ + template decltype(meta_kernel_fn) \ + meta_kernel_fn; + /** PD_REGISTER_KERNEL * * The most frequently used kernel registration macro, used for kernel @@ -396,10 +408,23 @@ struct KernelRegistrar { ::phi::backend##Context, \ layout, \ meta_kernel_fn, \ + FUNCTION_KERNEL_INSTANTIATION, \ + ARG_PARSE_FUNCTOR, \ + PHI_KERNEL, \ + PHI_VARIADIC_KERNEL, \ __VA_ARGS__) -#define _PD_REGISTER_KERNEL( \ - reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...) \ +#define _PD_REGISTER_KERNEL(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + meta_kernel_fn, \ + kernel_instantiation_macro, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + ...) \ PD_STATIC_ASSERT_GLOBAL_NAMESPACE( \ PD_REGISTER_tp_kernel_ns_check_##kernel_name##_##backend##_##layout, \ "PD_REGISTER_KERNEL must be called in global namespace."); \ @@ -409,12 +434,29 @@ struct KernelRegistrar { context, \ layout, \ meta_kernel_fn, \ + kernel_instantiation_macro, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) #ifndef _WIN32 -#define _PD_REGISTER_2TA_KERNEL( \ - reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...) \ - PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, context, __VA_ARGS__); \ +#define _PD_REGISTER_2TA_KERNEL(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + meta_kernel_fn, \ + kernel_instantiation_macro, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + ...) \ + PD_KERNEL_INSTANTIATION(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__); \ static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \ PD_KERNEL_REGISTRAR_INIT( \ @@ -425,6 +467,9 @@ struct KernelRegistrar { layout, \ &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__); \ void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel) @@ -441,8 +486,17 @@ struct KernelRegistrar { * * And msvc can work without template instantiation */ -#define _PD_REGISTER_2TA_KERNEL( \ - reg_type, kernel_name, backend, context, layout, meta_kernel_fn, ...) \ +#define _PD_REGISTER_2TA_KERNEL(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + meta_kernel_fn, \ + kernel_instantiation_macro, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + ...) \ static void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel); \ PD_EXPAND(PD_KERNEL_REGISTRAR_INIT( \ @@ -453,124 +507,222 @@ struct KernelRegistrar { layout, \ &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)); \ void __PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout( \ const ::phi::KernelKey& kernel_key, ::phi::Kernel* kernel) #endif -#define PD_KERNEL_INSTANTIATION(meta_kernel_fn, backend, context, ...) \ - _PD_KERNEL_INSTANTIATION( \ - PD_NARGS(__VA_ARGS__), meta_kernel_fn, backend, context, __VA_ARGS__) +#define PD_KERNEL_INSTANTIATION( \ + meta_kernel_fn, backend, context, kernel_instantiation_macro, ...) \ + _PD_KERNEL_INSTANTIATION(PD_NARGS(__VA_ARGS__), \ + meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__) -#define _PD_KERNEL_INSTANTIATION(N, meta_kernel_fn, backend, context, ...) \ - PD_CONCATENATE(_PD_KERNEL_INSTANTIATION_, N) \ - (meta_kernel_fn, backend, context, __VA_ARGS__) +#define _PD_KERNEL_INSTANTIATION( \ + N, meta_kernel_fn, backend, context, kernel_instantiation_macro, ...) \ + PD_CONCATENATE(_PD_KERNEL_INSTANTIATION_, N) \ + (meta_kernel_fn, backend, context, kernel_instantiation_macro, __VA_ARGS__) -#define _PD_KERNEL_INSTANTIATION_1( \ - meta_kernel_fn, backend, context, cpp_dtype) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn -#define _PD_KERNEL_INSTANTIATION_2( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_1( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_3( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_2( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_4( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_3( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_5( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_4( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_6( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_5( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_7( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_6( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_8( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_7( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_9( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_8( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_10( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_9( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_11( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_10( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_12( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_11( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_13( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_12( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_14( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_13( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) -#define _PD_KERNEL_INSTANTIATION_15( \ - meta_kernel_fn, backend, context, cpp_dtype, ...) \ - template decltype(meta_kernel_fn) \ - meta_kernel_fn; \ - PD_EXPAND(_PD_KERNEL_INSTANTIATION_14( \ - meta_kernel_fn, backend, context, __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_1( \ + meta_kernel_fn, backend, context, kernel_instantiation_macro, cpp_dtype) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) +#define _PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_1(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_2(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_3(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_4(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_5(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_6(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_7(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_8(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_9(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_10(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_11(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_12(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_13(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) +#define _PD_KERNEL_INSTANTIATION_15(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + cpp_dtype, \ + ...) \ + kernel_instantiation_macro(meta_kernel_fn, cpp_dtype, context) \ + PD_EXPAND(_PD_KERNEL_INSTANTIATION_14(meta_kernel_fn, \ + backend, \ + context, \ + kernel_instantiation_macro, \ + __VA_ARGS__)) -#define PD_KERNEL_REGISTRAR_INIT(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ - ...) \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT(PD_NARGS(__VA_ARGS__), \ - reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - args_def_fn, \ - meta_kernel_fn, \ +#define PD_KERNEL_REGISTRAR_INIT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + ...) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT(PD_NARGS(__VA_ARGS__), \ + reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) // clang-format off @@ -585,6 +737,9 @@ struct KernelRegistrar { layout, \ args_def_fn, \ meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ ...) \ PD_EXPAND(PD_CONCATENATE(_PD_KERNEL_REGISTRAR_INIT_, N) ( \ reg_type, \ @@ -595,11 +750,14 @@ struct KernelRegistrar { PD_ID, \ args_def_fn, \ meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) // clang-format on -#define _PD_KERNEL_REGISTRAR_INIT_1(reg_type, \ +#define _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ kernel_name, \ backend, \ context, \ @@ -607,453 +765,564 @@ struct KernelRegistrar { registrar_id, \ args_def_fn, \ meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ cpp_dtype) \ static const ::phi::KernelRegistrar PD_CONCATENATE( \ __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ reg_type, \ #kernel_name, \ #backend, \ - DATALAYOUT(layout), \ + DATA_LAYOUT(layout), \ ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ + arg_parse_functor_macro(meta_kernel_fn, cpp_dtype, context), \ args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ + kernel_unfold_macro(meta_kernel_fn), \ + variadic_kernel_unfold_marco(meta_kernel_fn)); + +#define _PD_KERNEL_REGISTRAR_INIT_1(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ int TouchKernelSymbolFor_##kernel_name##_##backend##_##layout() { return 0; } -#define _PD_KERNEL_REGISTRAR_INIT_2(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_2(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_1(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_3(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_3(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_2(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_4(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_4(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_3(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_5(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_5(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_4(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_6(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_6(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_5(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_7(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_7(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_6(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_8(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_8(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_7(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_9(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_9(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_8(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_10(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_10(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_9(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_11(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_11(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_10(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_12(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_12(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_11(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_13(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_13(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_12(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_14(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_14(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_13(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) -#define _PD_KERNEL_REGISTRAR_INIT_15(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - registrar_id, \ - args_def_fn, \ - meta_kernel_fn, \ - cpp_dtype, \ - ...) \ - static const ::phi::KernelRegistrar PD_CONCATENATE( \ - __reg_phi_kernel_##kernel_name##_##backend##_##layout##_, registrar_id)( \ - reg_type, \ - #kernel_name, \ - #backend, \ - DATALAYOUT(layout), \ - ::paddle::experimental::CppTypeToDataType::Type(), \ - ::phi::KernelArgsParseFunctor< \ - decltype(&meta_kernel_fn)>::Parse, \ - args_def_fn, \ - PHI_KERNEL(meta_kernel_fn), \ - PHI_VARIADIC_KERNEL(meta_kernel_fn)); \ - PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(reg_type, \ - kernel_name, \ - backend, \ - context, \ - layout, \ - PD_ID, \ - args_def_fn, \ - meta_kernel_fn, \ +#define _PD_KERNEL_REGISTRAR_INIT_15(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype, \ + ...) \ + _PD_CREATE_REGISTRAR_OBJECT(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + registrar_id, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ + cpp_dtype) \ + PD_EXPAND(_PD_KERNEL_REGISTRAR_INIT_14(reg_type, \ + kernel_name, \ + backend, \ + context, \ + layout, \ + PD_ID, \ + args_def_fn, \ + meta_kernel_fn, \ + arg_parse_functor_macro, \ + kernel_unfold_macro, \ + variadic_kernel_unfold_marco, \ __VA_ARGS__)) /** PD_REGISTER_GENERAL_KERNEL * @@ -1085,7 +1354,7 @@ struct KernelRegistrar { reg_type, \ #kernel_name, \ #backend, \ - DATALAYOUT(layout), \ + DATA_LAYOUT(layout), \ ::phi::KernelArgsParseFunctor::Parse, \ &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ PHI_KERNEL(kernel_fn), \ @@ -1105,7 +1374,7 @@ struct KernelRegistrar { reg_type, \ #kernel_name, \ #backend, \ - DATALAYOUT(layout), \ + DATA_LAYOUT(layout), \ ::phi::KernelArgsParseFunctor::Parse, \ &__PD_KERNEL_args_def_FN_##kernel_name##_##backend##_##layout, \ PHI_KERNEL(kernel_fn), \ @@ -1144,6 +1413,10 @@ struct KernelRegistrar { ::phi::backend##Context, \ layout, \ meta_kernel_fn, \ + FUNCTION_KERNEL_INSTANTIATION, \ + ARG_PARSE_FUNCTOR, \ + PHI_KERNEL, \ + PHI_VARIADIC_KERNEL, \ __VA_ARGS__) /** PD_REGISTER_PLUGIN_KERNEL @@ -1159,6 +1432,10 @@ struct KernelRegistrar { ::phi::CustomContext, \ layout, \ meta_kernel_fn, \ + FUNCTION_KERNEL_INSTANTIATION, \ + ARG_PARSE_FUNCTOR, \ + PHI_KERNEL, \ + PHI_VARIADIC_KERNEL, \ __VA_ARGS__) } // namespace phi diff --git a/paddle/phi/ops/compat/save_combine_sig.cc b/paddle/phi/ops/compat/save_combine_sig.cc index 8c9760410b..164236cb99 100644 --- a/paddle/phi/ops/compat/save_combine_sig.cc +++ b/paddle/phi/ops/compat/save_combine_sig.cc @@ -35,4 +35,6 @@ KernelSignature SaveCombineOpArgumentMapping( } // namespace phi +PD_REGISTER_BASE_KERNEL_NAME(save_combine, save_combine_tensor); + PD_REGISTER_ARG_MAPPING_FN(save_combine, phi::SaveCombineOpArgumentMapping); -- GitLab