diff --git a/paddle/pten/core/compat/op_utils.h b/paddle/pten/core/compat/op_utils.h index 93090616366f007427b6b1d5d20608545a13f13f..950b6cd039a285611d43139613e9b039f85e3467 100644 --- a/paddle/pten/core/compat/op_utils.h +++ b/paddle/pten/core/compat/op_utils.h @@ -26,6 +26,11 @@ limitations under the License. */ namespace pten { +const std::unordered_set standard_kernel_suffixs({ + "sr", // SelectedRows kernel + "raw" // fallback kernel of origfinal fluid op +}); + /** * Some fluid ops are no longer used under the corresponding official API * system of 2.0. These names need to correspond to the official API names diff --git a/paddle/pten/ops/compat/elementwise_sig.cc b/paddle/pten/ops/compat/elementwise_sig.cc index 8f145a875ad2aaccd99b05e7c4f553081cef8c58..57bd03f8a21d58f4f8f23776b60d5a297d5e9848 100644 --- a/paddle/pten/ops/compat/elementwise_sig.cc +++ b/paddle/pten/ops/compat/elementwise_sig.cc @@ -66,10 +66,10 @@ KernelSignature ElementwiseDivOpArgumentMapping( } // namespace pten -PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add_raw); -PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract_raw); -PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply_raw); -PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide_raw); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply); +PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide); PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad); PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad); diff --git a/paddle/pten/ops/compat/reduce_sig.cc b/paddle/pten/ops/compat/reduce_sig.cc index a8a2b517d3e9d37e7078302a736e7438d1d0d4c3..10f73d8122e4eea2abb143ffcc79cfc6d99807be 100644 --- a/paddle/pten/ops/compat/reduce_sig.cc +++ b/paddle/pten/ops/compat/reduce_sig.cc @@ -45,8 +45,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) { } // namespace pten -PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum_raw); -PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean_raw); +PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum); +PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean); PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping);