未验证 提交 c7c1db33 编写于 作者: C Chen Weihang 提交者: GitHub

[PTen] Add standard kernel suffix set (#39404)

* add standard_suffix_set_and_remove_reshape_with_xshape

* revert reshape change

* polish reduce name
上级 63d2333e
......@@ -26,6 +26,11 @@ limitations under the License. */
namespace pten {
const std::unordered_set<std::string> standard_kernel_suffixs({
"sr", // SelectedRows kernel
"raw" // fallback kernel of origfinal fluid op
});
/**
* Some fluid ops are no longer used under the corresponding official API
* system of 2.0. These names need to correspond to the official API names
......
......@@ -66,10 +66,10 @@ KernelSignature ElementwiseDivOpArgumentMapping(
} // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);
......
......@@ -45,8 +45,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {
} // namespace pten
PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum_raw);
PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean_raw);
PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum);
PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean);
PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册