From c7c1db33a4491dfb1833ea83379fef1d134e195f Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Thu, 10 Feb 2022 11:29:55 +0800 Subject: [PATCH] [PTen] Add standard kernel suffix set (#39404) * add standard_suffix_set_and_remove_reshape_with_xshape * revert reshape change * polish reduce name --- paddle/pten/core/compat/op_utils.h | 5 +++++ paddle/pten/ops/compat/elementwise_sig.cc | 8 ++++---- paddle/pten/ops/compat/reduce_sig.cc | 4 ++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/paddle/pten/core/compat/op_utils.h b/paddle/pten/core/compat/op_utils.h index 93090616366..950b6cd039a 100644 --- a/paddle/pten/core/compat/op_utils.h +++ b/paddle/pten/core/compat/op_utils.h @@ -26,6 +26,11 @@ limitations under the License. */ namespace pten { +const std::unordered_set standard_kernel_suffixs({ + "sr", // SelectedRows kernel + "raw" // fallback kernel of origfinal fluid op +}); + /** * Some fluid ops are no longer used under the corresponding official API * system of 2.0. These names need to correspond to the official API names diff --git a/paddle/pten/ops/compat/elementwise_sig.cc b/paddle/pten/ops/compat/elementwise_sig.cc index 8f145a875ad..57bd03f8a21 100644 --- a/paddle/pten/ops/compat/elementwise_sig.cc +++ b/paddle/pten/ops/compat/elementwise_sig.cc @@ -66,10 +66,10 @@ KernelSignature ElementwiseDivOpArgumentMapping( } // namespace pten -PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add_raw); -PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract_raw); -PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply_raw); -PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide_raw); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide); PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad); PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad); diff --git a/paddle/pten/ops/compat/reduce_sig.cc b/paddle/pten/ops/compat/reduce_sig.cc index a8a2b517d3e..10f73d8122e 100644 --- a/paddle/pten/ops/compat/reduce_sig.cc +++ b/paddle/pten/ops/compat/reduce_sig.cc @@ -45,8 +45,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) { } // namespace pten -PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum_raw); -PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean_raw); +PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum); +PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean); PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping); -- GitLab