未验证 提交 c5a7da4b 编写于 作者: Y YuanRisheng 提交者: GitHub

[PTen]Add alias name for matmul and remove redundant member in kernel factory (#38011)

* add alias kernel name

* modify code as suggestions

* add alias name for matmul and remove redundant member in kernel factory
上级 ae40370d
...@@ -27,13 +27,13 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = { ...@@ -27,13 +27,13 @@ const std::unordered_map<std::string, std::string> kernel_alias_name_map = {
{"fill_any_like", "full_like"}, {"fill_any_like", "full_like"},
{"fill_constant", "full"}, {"fill_constant", "full"},
{"flatten_contiguous_range", "flatten"}, {"flatten_contiguous_range", "flatten"},
// {"matmul_v2", "matmul"}, {"matmul_v2", "matmul"},
{"reduce_mean", "mean"}, {"reduce_mean", "mean"},
{"reduce_sum", "sum"}, {"reduce_sum", "sum"},
{"reshape2", "reshape"}, {"reshape2", "reshape"},
// fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated // fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated
{"flatten", "deprecated"}, {"flatten", "deprecated"},
// {"matmul", "deprecated"}, {"matmul", "deprecated"},
{"mean", "deprecated"}, {"mean", "deprecated"},
{"reshape", "deprecated"}, {"reshape", "deprecated"},
{"sum", "deprecated"}}; {"sum", "deprecated"}};
......
...@@ -265,12 +265,8 @@ class KernelFactory { ...@@ -265,12 +265,8 @@ class KernelFactory {
KernelMap& kernels() { return kernels_; } KernelMap& kernels() { return kernels_; }
void InsertCompatibleOpType(const std::string& op_type) {
compatible_op_types_.insert(op_type);
}
bool HasCompatiblePtenKernel(const std::string& op_type) const { bool HasCompatiblePtenKernel(const std::string& op_type) const {
return compatible_op_types_.count(TransToPtenKernelName(op_type)) > 0; return kernels_.find(TransToPtenKernelName(op_type)) != kernels_.end();
} }
const Kernel& SelectKernelOrThrowError(const KernelName& kernel_name, const Kernel& SelectKernelOrThrowError(const KernelName& kernel_name,
...@@ -288,9 +284,6 @@ class KernelFactory { ...@@ -288,9 +284,6 @@ class KernelFactory {
KernelFactory() = default; KernelFactory() = default;
KernelMap kernels_; KernelMap kernels_;
// Used to be compatible with the original execution system and
// quickly confirm whether the new kernel can be called
std::unordered_set<std::string> compatible_op_types_;
}; };
/** operator << overload **/ /** operator << overload **/
......
...@@ -143,7 +143,6 @@ struct KernelRegistrar { ...@@ -143,7 +143,6 @@ struct KernelRegistrar {
Kernel kernel(kernel_fn); Kernel kernel(kernel_fn);
args_parse_fn(kernel_key, kernel.mutable_args_def()); args_parse_fn(kernel_key, kernel.mutable_args_def());
args_def_fn(&kernel); args_def_fn(&kernel);
KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name());
KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel;
} }
}; };
......
...@@ -85,4 +85,4 @@ PT_REGISTER_KERNEL(dot, ...@@ -85,4 +85,4 @@ PT_REGISTER_KERNEL(dot,
complex128) {} complex128) {}
PT_REGISTER_KERNEL( PT_REGISTER_KERNEL(
matmul_v2, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {} matmul, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {}
...@@ -69,7 +69,7 @@ PT_REGISTER_KERNEL(dot, ...@@ -69,7 +69,7 @@ PT_REGISTER_KERNEL(dot,
complex64, complex64,
complex128) {} complex128) {}
PT_REGISTER_KERNEL(matmul_v2, PT_REGISTER_KERNEL(matmul,
CUDA, CUDA,
ANY, ANY,
pten::Matmul, pten::Matmul,
......
...@@ -76,7 +76,7 @@ ...@@ -76,7 +76,7 @@
infer_meta : infer_meta :
func : MatmulInferMeta func : MatmulInferMeta
kernel : kernel :
func : matmul_v2 func : matmul
- api : mean - api : mean
args : (const Tensor& x, const std::vector<int64_t>& axis, bool keep_dim) args : (const Tensor& x, const std::vector<int64_t>& axis, bool keep_dim)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册