From 3d35aa802fea2d91d17160584e767d635026f42e Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 1 Dec 2022 21:12:29 +0800 Subject: [PATCH] Rename kernel for top_k, slogdeterminant, generate_proposals_v2 (#48594) * rename kernel for top_k, slogdeterminant, generate_proposals_v2 * fix bug --- paddle/phi/api/yaml/legacy_backward.yaml | 4 +- paddle/phi/api/yaml/legacy_ops.yaml | 6 +-- paddle/phi/core/compat/op_utils.h | 3 +- ...kernel.cc => generate_proposals_kernel.cc} | 36 ++++++++--------- .../cpu/slogdeterminant_grad_kernel.cc | 2 +- .../phi/kernels/cpu/slogdeterminant_kernel.cc | 8 +--- paddle/phi/kernels/cpu/top_k_grad_kernel.cc | 2 +- paddle/phi/kernels/cpu/top_k_kernel.cc | 2 +- .../phi/kernels/generate_proposals_kernel.h | 38 ++++++++++++++++++ .../kernels/generate_proposals_v2_kernel.h | 38 ------------------ ...kernel.cu => generate_proposals_kernel.cu} | 39 +++++++++---------- .../gpu/slogdeterminant_grad_kernel.cu | 2 +- .../phi/kernels/gpu/slogdeterminant_kernel.cu | 8 +--- paddle/phi/kernels/gpu/top_k_grad_kernel.cu | 2 +- paddle/phi/kernels/gpu/top_k_kernel.cu | 2 +- ...kernel.cc => generate_proposals_kernel.cc} | 39 +++++++++---------- paddle/phi/kernels/xpu/top_k_kernel.cc | 2 +- .../phi/ops/compat/generate_proposals_sig.cc | 19 +++++++++ paddle/phi/ops/compat/slogdeterminant_sig.cc | 5 ++- paddle/phi/ops/compat/top_k_sig.cc | 10 ++--- 20 files changed, 138 insertions(+), 129 deletions(-) rename paddle/phi/kernels/cpu/{generate_proposals_v2_kernel.cc => generate_proposals_kernel.cc} (93%) create mode 100644 paddle/phi/kernels/generate_proposals_kernel.h delete mode 100644 paddle/phi/kernels/generate_proposals_v2_kernel.h rename paddle/phi/kernels/gpu/{generate_proposals_v2_kernel.cu => generate_proposals_kernel.cu} (95%) rename paddle/phi/kernels/xpu/{generate_proposals_v2_kernel.cc => generate_proposals_kernel.cc} (93%) create mode 100644 paddle/phi/ops/compat/generate_proposals_sig.cc diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 47ba24b091d..064c6b00a88 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -1504,7 +1504,7 @@ func : UnchangedInferMeta param : [x] kernel : - func : slogdeterminant_grad + func : slogdet_grad - backward_op : softmax_grad forward : softmax (Tensor x, int axis) -> Tensor(out) @@ -1713,7 +1713,7 @@ func : UnchangedInferMeta param : [x] kernel : - func : top_k_grad + func : topk_grad - backward_op : transpose_double_grad forward : transpose_grad (Tensor grad_out, int[] perm) -> Tensor(grad_x) diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 1bc0fc7f0aa..d32a853b8c0 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -878,7 +878,7 @@ infer_meta : func : GenerateProposalsV2InferMeta kernel : - func : generate_proposals_v2 + func : generate_proposals - op : greater_equal args : (Tensor x, Tensor y) @@ -1935,7 +1935,7 @@ infer_meta : func : UnchangedInferMeta kernel : - func : slogdeterminant + func : slogdet backward : slogdet_grad - op : softmax @@ -2100,7 +2100,7 @@ infer_meta : func : TopKInferMeta kernel : - func : top_k + func : topk backward : topk_grad - op : transpose diff --git a/paddle/phi/core/compat/op_utils.h b/paddle/phi/core/compat/op_utils.h index b836359ae81..2145d73cd9f 100644 --- a/paddle/phi/core/compat/op_utils.h +++ b/paddle/phi/core/compat/op_utils.h @@ -83,7 +83,8 @@ static const std::unordered_set deprecated_op_names( "bicubic_interp", "bicubic_interp_grad", "crop", - "crop_grad"}); + "crop_grad", + "generate_proposals"}); class DefaultKernelSignatureMap { public: diff --git a/paddle/phi/kernels/cpu/generate_proposals_v2_kernel.cc b/paddle/phi/kernels/cpu/generate_proposals_kernel.cc similarity index 93% rename from paddle/phi/kernels/cpu/generate_proposals_v2_kernel.cc rename to paddle/phi/kernels/cpu/generate_proposals_kernel.cc index 22f39555449..4a9569c045c 100644 --- a/paddle/phi/kernels/cpu/generate_proposals_v2_kernel.cc +++ b/paddle/phi/kernels/cpu/generate_proposals_kernel.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/generate_proposals_v2_kernel.h" +#include "paddle/phi/kernels/generate_proposals_kernel.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/detection/nms_util.h" @@ -284,21 +284,21 @@ std::pair ProposalForOneImage( } template -void GenerateProposalsV2Kernel(const Context& ctx, - const DenseTensor& scores, - const DenseTensor& bbox_deltas, - const DenseTensor& im_shape, - const DenseTensor& anchors, - const DenseTensor& variances, - int pre_nms_top_n, - int post_nms_top_n, - float nms_thresh, - float min_size, - float eta, - bool pixel_offset, - DenseTensor* rpn_rois, - DenseTensor* rpn_roi_probs, - DenseTensor* rpn_rois_num) { +void GenerateProposalsKernel(const Context& ctx, + const DenseTensor& scores, + const DenseTensor& bbox_deltas, + const DenseTensor& im_shape, + const DenseTensor& anchors, + const DenseTensor& variances, + int pre_nms_top_n, + int post_nms_top_n, + float nms_thresh, + float min_size, + float eta, + bool pixel_offset, + DenseTensor* rpn_rois, + DenseTensor* rpn_roi_probs, + DenseTensor* rpn_rois_num) { auto& scores_dim = scores.dims(); int64_t num = scores_dim[0]; int64_t c_score = scores_dim[1]; @@ -384,9 +384,9 @@ void GenerateProposalsV2Kernel(const Context& ctx, } // namespace phi -PD_REGISTER_KERNEL(generate_proposals_v2, +PD_REGISTER_KERNEL(generate_proposals, CPU, ALL_LAYOUT, - phi::GenerateProposalsV2Kernel, + phi::GenerateProposalsKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/slogdeterminant_grad_kernel.cc b/paddle/phi/kernels/cpu/slogdeterminant_grad_kernel.cc index 0854895f0c1..5f265ab9bc8 100644 --- a/paddle/phi/kernels/cpu/slogdeterminant_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/slogdeterminant_grad_kernel.cc @@ -17,7 +17,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h" -PD_REGISTER_KERNEL(slogdeterminant_grad, +PD_REGISTER_KERNEL(slogdet_grad, CPU, ALL_LAYOUT, phi::SlogDeterminantGradKernel, diff --git a/paddle/phi/kernels/cpu/slogdeterminant_kernel.cc b/paddle/phi/kernels/cpu/slogdeterminant_kernel.cc index 6bd9f0296c6..8e96c163164 100644 --- a/paddle/phi/kernels/cpu/slogdeterminant_kernel.cc +++ b/paddle/phi/kernels/cpu/slogdeterminant_kernel.cc @@ -17,9 +17,5 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h" -PD_REGISTER_KERNEL(slogdeterminant, - CPU, - ALL_LAYOUT, - phi::SlogDeterminantKernel, - float, - double) {} +PD_REGISTER_KERNEL( + slogdet, CPU, ALL_LAYOUT, phi::SlogDeterminantKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/top_k_grad_kernel.cc b/paddle/phi/kernels/cpu/top_k_grad_kernel.cc index e44f85fb6c0..2d02b0ab523 100644 --- a/paddle/phi/kernels/cpu/top_k_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/top_k_grad_kernel.cc @@ -141,7 +141,7 @@ void TopkGradKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(top_k_grad, +PD_REGISTER_KERNEL(topk_grad, CPU, ALL_LAYOUT, phi::TopkGradKernel, diff --git a/paddle/phi/kernels/cpu/top_k_kernel.cc b/paddle/phi/kernels/cpu/top_k_kernel.cc index 4ac16667ce2..3e946803660 100644 --- a/paddle/phi/kernels/cpu/top_k_kernel.cc +++ b/paddle/phi/kernels/cpu/top_k_kernel.cc @@ -227,4 +227,4 @@ void TopkKernel(const Context& dev_ctx, } // namespace phi PD_REGISTER_KERNEL( - top_k, CPU, ALL_LAYOUT, phi::TopkKernel, float, double, int32_t, int64_t) {} + topk, CPU, ALL_LAYOUT, phi::TopkKernel, float, double, int32_t, int64_t) {} diff --git a/paddle/phi/kernels/generate_proposals_kernel.h b/paddle/phi/kernels/generate_proposals_kernel.h new file mode 100644 index 00000000000..e14b250a7d0 --- /dev/null +++ b/paddle/phi/kernels/generate_proposals_kernel.h @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GenerateProposalsKernel(const Context& ctx, + const DenseTensor& scores, + const DenseTensor& bbox_deltas, + const DenseTensor& im_shape, + const DenseTensor& anchors, + const DenseTensor& variances, + int pre_nms_top_n, + int post_nms_top_n, + float nms_thresh, + float min_size, + float eta, + bool pixel_offset, + DenseTensor* rpn_rois, + DenseTensor* rpn_roi_probs, + DenseTensor* rpn_rois_num); + +} // namespace phi diff --git a/paddle/phi/kernels/generate_proposals_v2_kernel.h b/paddle/phi/kernels/generate_proposals_v2_kernel.h deleted file mode 100644 index c2fc2677039..00000000000 --- a/paddle/phi/kernels/generate_proposals_v2_kernel.h +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/phi/core/dense_tensor.h" - -namespace phi { - -template -void GenerateProposalsV2Kernel(const Context& ctx, - const DenseTensor& scores, - const DenseTensor& bbox_deltas, - const DenseTensor& im_shape, - const DenseTensor& anchors, - const DenseTensor& variances, - int pre_nms_top_n, - int post_nms_top_n, - float nms_thresh, - float min_size, - float eta, - bool pixel_offset, - DenseTensor* rpn_rois, - DenseTensor* rpn_roi_probs, - DenseTensor* rpn_rois_num); - -} // namespace phi diff --git a/paddle/phi/kernels/gpu/generate_proposals_v2_kernel.cu b/paddle/phi/kernels/gpu/generate_proposals_kernel.cu similarity index 95% rename from paddle/phi/kernels/gpu/generate_proposals_v2_kernel.cu rename to paddle/phi/kernels/gpu/generate_proposals_kernel.cu index 91abb290dd8..f750bd5fe7e 100644 --- a/paddle/phi/kernels/gpu/generate_proposals_v2_kernel.cu +++ b/paddle/phi/kernels/gpu/generate_proposals_kernel.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/generate_proposals_v2_kernel.h" +#include "paddle/phi/kernels/generate_proposals_kernel.h" #include #include @@ -458,21 +458,21 @@ static std::pair ProposalForOneImage( } template -void GenerateProposalsV2Kernel(const Context &ctx, - const DenseTensor &scores, - const DenseTensor &bbox_deltas, - const DenseTensor &im_shape, - const DenseTensor &anchors, - const DenseTensor &variances, - int pre_nms_top_n, - int post_nms_top_n, - float nms_thresh, - float min_size, - float eta, - bool pixel_offset, - DenseTensor *rpn_rois, - DenseTensor *rpn_roi_probs, - DenseTensor *rpn_rois_num) { +void GenerateProposalsKernel(const Context &ctx, + const DenseTensor &scores, + const DenseTensor &bbox_deltas, + const DenseTensor &im_shape, + const DenseTensor &anchors, + const DenseTensor &variances, + int pre_nms_top_n, + int post_nms_top_n, + float nms_thresh, + float min_size, + float eta, + bool pixel_offset, + DenseTensor *rpn_rois, + DenseTensor *rpn_roi_probs, + DenseTensor *rpn_rois_num) { PADDLE_ENFORCE_GE( eta, 1., @@ -584,8 +584,5 @@ void GenerateProposalsV2Kernel(const Context &ctx, } // namespace phi -PD_REGISTER_KERNEL(generate_proposals_v2, - GPU, - ALL_LAYOUT, - phi::GenerateProposalsV2Kernel, - float) {} +PD_REGISTER_KERNEL( + generate_proposals, GPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {} diff --git a/paddle/phi/kernels/gpu/slogdeterminant_grad_kernel.cu b/paddle/phi/kernels/gpu/slogdeterminant_grad_kernel.cu index 153a97fa7a5..f9f9055f57a 100644 --- a/paddle/phi/kernels/gpu/slogdeterminant_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/slogdeterminant_grad_kernel.cu @@ -17,7 +17,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/slogdeterminant_grad_kernel_impl.h" -PD_REGISTER_KERNEL(slogdeterminant_grad, +PD_REGISTER_KERNEL(slogdet_grad, GPU, ALL_LAYOUT, phi::SlogDeterminantGradKernel, diff --git a/paddle/phi/kernels/gpu/slogdeterminant_kernel.cu b/paddle/phi/kernels/gpu/slogdeterminant_kernel.cu index e94dc117fb9..14a9b7e387b 100644 --- a/paddle/phi/kernels/gpu/slogdeterminant_kernel.cu +++ b/paddle/phi/kernels/gpu/slogdeterminant_kernel.cu @@ -17,9 +17,5 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/slogdeterminant_kernel_impl.h" -PD_REGISTER_KERNEL(slogdeterminant, - GPU, - ALL_LAYOUT, - phi::SlogDeterminantKernel, - float, - double) {} +PD_REGISTER_KERNEL( + slogdet, GPU, ALL_LAYOUT, phi::SlogDeterminantKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/top_k_grad_kernel.cu b/paddle/phi/kernels/gpu/top_k_grad_kernel.cu index ae95923e7f6..e20fec80687 100644 --- a/paddle/phi/kernels/gpu/top_k_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/top_k_grad_kernel.cu @@ -76,7 +76,7 @@ void TopkGradKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(top_k_grad, +PD_REGISTER_KERNEL(topk_grad, GPU, ALL_LAYOUT, phi::TopkGradKernel, diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu index c9ea86472f3..a455d9305d9 100644 --- a/paddle/phi/kernels/gpu/top_k_kernel.cu +++ b/paddle/phi/kernels/gpu/top_k_kernel.cu @@ -332,7 +332,7 @@ void TopkKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(top_k, +PD_REGISTER_KERNEL(topk, GPU, ALL_LAYOUT, phi::TopkKernel, diff --git a/paddle/phi/kernels/xpu/generate_proposals_v2_kernel.cc b/paddle/phi/kernels/xpu/generate_proposals_kernel.cc similarity index 93% rename from paddle/phi/kernels/xpu/generate_proposals_v2_kernel.cc rename to paddle/phi/kernels/xpu/generate_proposals_kernel.cc index 5a91f5ad9d5..bf7f3e90bfd 100644 --- a/paddle/phi/kernels/xpu/generate_proposals_v2_kernel.cc +++ b/paddle/phi/kernels/xpu/generate_proposals_kernel.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/kernels/generate_proposals_v2_kernel.h" +#include "paddle/phi/kernels/generate_proposals_kernel.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/backends/xpu/xpu_context.h" @@ -272,21 +272,21 @@ std::pair ProposalForOneImage( } template -void GenerateProposalsV2Kernel(const Context& dev_ctx, - const DenseTensor& scores, - const DenseTensor& bbox_deltas, - const DenseTensor& im_shape, - const DenseTensor& anchors, - const DenseTensor& variances, - int pre_nms_top_n, - int post_nms_top_n, - float nms_thresh, - float min_size, - float eta, - bool pixel_offset, - DenseTensor* rpn_rois, - DenseTensor* rpn_roi_probs, - DenseTensor* rpn_rois_num) { +void GenerateProposalsKernel(const Context& dev_ctx, + const DenseTensor& scores, + const DenseTensor& bbox_deltas, + const DenseTensor& im_shape, + const DenseTensor& anchors, + const DenseTensor& variances, + int pre_nms_top_n, + int post_nms_top_n, + float nms_thresh, + float min_size, + float eta, + bool pixel_offset, + DenseTensor* rpn_rois, + DenseTensor* rpn_roi_probs, + DenseTensor* rpn_rois_num) { PADDLE_ENFORCE_GE(eta, 1., phi::errors::InvalidArgument( @@ -408,8 +408,5 @@ void GenerateProposalsV2Kernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL(generate_proposals_v2, - XPU, - ALL_LAYOUT, - phi::GenerateProposalsV2Kernel, - float) {} +PD_REGISTER_KERNEL( + generate_proposals, XPU, ALL_LAYOUT, phi::GenerateProposalsKernel, float) {} diff --git a/paddle/phi/kernels/xpu/top_k_kernel.cc b/paddle/phi/kernels/xpu/top_k_kernel.cc index f2592f9501e..0fdb66c4129 100644 --- a/paddle/phi/kernels/xpu/top_k_kernel.cc +++ b/paddle/phi/kernels/xpu/top_k_kernel.cc @@ -173,4 +173,4 @@ void TopkKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(top_k, XPU, ALL_LAYOUT, phi::TopkKernel, float) {} +PD_REGISTER_KERNEL(topk, XPU, ALL_LAYOUT, phi::TopkKernel, float) {} diff --git a/paddle/phi/ops/compat/generate_proposals_sig.cc b/paddle/phi/ops/compat/generate_proposals_sig.cc new file mode 100644 index 00000000000..fc6696d4a27 --- /dev/null +++ b/paddle/phi/ops/compat/generate_proposals_sig.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +PD_REGISTER_BASE_KERNEL_NAME(generate_proposals_v2, generate_proposals); +PD_REGISTER_BASE_KERNEL_NAME(generate_proposals_v2_grad, + generate_proposals_grad); diff --git a/paddle/phi/ops/compat/slogdeterminant_sig.cc b/paddle/phi/ops/compat/slogdeterminant_sig.cc index e4eeca05152..2e63a90d929 100644 --- a/paddle/phi/ops/compat/slogdeterminant_sig.cc +++ b/paddle/phi/ops/compat/slogdeterminant_sig.cc @@ -19,10 +19,13 @@ namespace phi { KernelSignature SlogDeterminantGradOpArgumentMapping( const ArgumentMappingContext& ctx) { return KernelSignature( - "slogdeterminant_grad", {"Input", "Out", "Out@GRAD"}, {}, {"Input@GRAD"}); + "slogdet_grad", {"Input", "Out", "Out@GRAD"}, {}, {"Input@GRAD"}); } } // namespace phi +PD_REGISTER_BASE_KERNEL_NAME(slogdeterminant, slogdet); +PD_REGISTER_BASE_KERNEL_NAME(slogdeterminant_grad, slogdet_grad); + PD_REGISTER_ARG_MAPPING_FN(slogdeterminant_grad, phi::SlogDeterminantGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/top_k_sig.cc b/paddle/phi/ops/compat/top_k_sig.cc index c1073f9efdc..0f3a5c1c0b5 100644 --- a/paddle/phi/ops/compat/top_k_sig.cc +++ b/paddle/phi/ops/compat/top_k_sig.cc @@ -19,16 +19,16 @@ namespace phi { KernelSignature TopkOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.HasInput("K")) { return KernelSignature( - "top_k", {"X"}, {"K", "axis", "largest", "sorted"}, {"Out", "Indices"}); + "topk", {"X"}, {"K", "axis", "largest", "sorted"}, {"Out", "Indices"}); } else { return KernelSignature( - "top_k", {"X"}, {"k", "axis", "largest", "sorted"}, {"Out", "Indices"}); + "topk", {"X"}, {"k", "axis", "largest", "sorted"}, {"Out", "Indices"}); } } KernelSignature TopkGradOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("top_k_grad", + return KernelSignature("topk_grad", {"X", "Indices", "Out@GRAD"}, {"k", "axis", "largest", "sorted"}, {"X@GRAD"}); @@ -36,7 +36,7 @@ KernelSignature TopkGradOpArgumentMapping(const ArgumentMappingContext& ctx) { } // namespace phi -PD_REGISTER_BASE_KERNEL_NAME(top_k_v2, top_k); -PD_REGISTER_BASE_KERNEL_NAME(top_k_v2_grad, top_k_grad); +PD_REGISTER_BASE_KERNEL_NAME(top_k_v2, topk); +PD_REGISTER_BASE_KERNEL_NAME(top_k_v2_grad, topk_grad); PD_REGISTER_ARG_MAPPING_FN(top_k_v2, phi::TopkOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(top_k_v2_grad, phi::TopkGradOpArgumentMapping); -- GitLab