未验证 提交 c5a7da4b 编写于 作者: Y YuanRisheng 提交者: GitHub

[PTen]Add alias name for matmul and remove redundant member in kernel factory (#38011)

* add alias kernel name

* modify code as suggestions

* add alias name for matmul and remove redundant member in kernel factory
上级 ae40370d
......@@ -27,13 +27,13 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"fill_any_like", "full_like"},
{"fill_constant", "full"},
{"flatten_contiguous_range", "flatten"},
// {"matmul_v2", "matmul"},
{"matmul_v2", "matmul"},
{"reduce_mean", "mean"},
{"reduce_sum", "sum"},
{"reshape2", "reshape"},
// fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated
{"flatten", "deprecated"},
// {"matmul", "deprecated"},
{"matmul", "deprecated"},
{"mean", "deprecated"},
{"reshape", "deprecated"},
{"sum", "deprecated"}};
......
......@@ -265,12 +265,8 @@ class KernelFactory {
KernelMap& kernels() { return kernels_; }
void InsertCompatibleOpType(const std::string& op_type) {
compatible_op_types_.insert(op_type);
}
bool HasCompatiblePtenKernel(const std::string& op_type) const {
return compatible_op_types_.count(TransToPtenKernelName(op_type)) > 0;
return kernels_.find(TransToPtenKernelName(op_type)) != kernels_.end();
}
const Kernel& SelectKernelOrThrowError(const KernelName& kernel_name,
......@@ -288,9 +284,6 @@ class KernelFactory {
KernelFactory() = default;
KernelMap kernels_;
// Used to be compatible with the original execution system and
// quickly confirm whether the new kernel can be called
std::unordered_set<std::string> compatible_op_types_;
};
/** operator << overload **/
......
......@@ -143,7 +143,6 @@ struct KernelRegistrar {
Kernel kernel(kernel_fn);
args_parse_fn(kernel_key, kernel.mutable_args_def());
args_def_fn(&kernel);
KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name());
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
}
};
......
......@@ -85,4 +85,4 @@ PT_REGISTER_KERNEL(dot,
complex128) {}
PT_REGISTER_KERNEL(
matmul_v2, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {}
matmul, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {}
......@@ -69,7 +69,7 @@ PT_REGISTER_KERNEL(dot,
complex64,
complex128) {}
PT_REGISTER_KERNEL(matmul_v2,
PT_REGISTER_KERNEL(matmul,
CUDA,
ANY,
pten::Matmul,
......
......@@ -76,7 +76,7 @@
infer_meta :
func : MatmulInferMeta
kernel :
func : matmul_v2
func : matmul
- api : mean
args : (const Tensor& x, const std::vector<int64_t>& axis, bool keep_dim)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册