未验证 提交 c7c1db33 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add standard kernel suffix set (#39404)

* add standard_suffix_set_and_remove_reshape_with_xshape

* revert reshape change

* polish reduce name
上级 63d2333e
...@@ -26,6 +26,11 @@ limitations under the License. */ ...@@ -26,6 +26,11 @@ limitations under the License. */
namespace pten { namespace pten {
const std::unordered_set<std::string> standard_kernel_suffixs({
"sr", // SelectedRows kernel
"raw" // fallback kernel of origfinal fluid op
});
/** /**
* Some fluid ops are no longer used under the corresponding official API * Some fluid ops are no longer used under the corresponding official API
* system of 2.0. These names need to correspond to the official API names * system of 2.0. These names need to correspond to the official API names
......
...@@ -66,10 +66,10 @@ KernelSignature ElementwiseDivOpArgumentMapping( ...@@ -66,10 +66,10 @@ KernelSignature ElementwiseDivOpArgumentMapping(
} // namespace pten } // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add_raw); PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract_raw); PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply_raw); PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide_raw); PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad); PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad); PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
......
...@@ -45,8 +45,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -45,8 +45,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
} // namespace pten } // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum_raw); PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum);
PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean_raw); PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean);
PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping); PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册