From c5a7da4bc62d8e4a4d0d4ae0a155f361a83a9f6c Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 10 Dec 2021 15:15:27 +0800 Subject: [PATCH] [PTen]Add alias name for matmul and remove redundant member in kernel factory (#38011) * add alias kernel name * modify code as suggestions * add alias name for matmul and remove redundant member in kernel factory --- paddle/pten/core/kernel_alias_name.h | 4 ++-- paddle/pten/core/kernel_factory.h | 9 +-------- paddle/pten/core/kernel_registry.h | 1 - paddle/pten/kernels/cpu/linalg.cc | 2 +- paddle/pten/kernels/cuda/linalg.cu | 2 +- python/paddle/utils/code_gen/api.yaml | 2 +- 6 files changed, 6 insertions(+), 14 deletions(-) diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h index 0310b6e6faf..3b8347dec77 100644 --- a/paddle/pten/core/kernel_alias_name.h +++ b/paddle/pten/core/kernel_alias_name.h @@ -27,13 +27,13 @@ const std::unordered_map kernel_alias_name_map = { {"fill_any_like", "full_like"}, {"fill_constant", "full"}, {"flatten_contiguous_range", "flatten"}, - // {"matmul_v2", "matmul"}, + {"matmul_v2", "matmul"}, {"reduce_mean", "mean"}, {"reduce_sum", "sum"}, {"reshape2", "reshape"}, // fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated {"flatten", "deprecated"}, - // {"matmul", "deprecated"}, + {"matmul", "deprecated"}, {"mean", "deprecated"}, {"reshape", "deprecated"}, {"sum", "deprecated"}}; diff --git a/paddle/pten/core/kernel_factory.h b/paddle/pten/core/kernel_factory.h index dbdf90b5bdb..4adfb703503 100644 --- a/paddle/pten/core/kernel_factory.h +++ b/paddle/pten/core/kernel_factory.h @@ -265,12 +265,8 @@ class KernelFactory { KernelMap& kernels() { return kernels_; } - void InsertCompatibleOpType(const std::string& op_type) { - compatible_op_types_.insert(op_type); - } - bool HasCompatiblePtenKernel(const std::string& op_type) const { - return compatible_op_types_.count(TransToPtenKernelName(op_type)) > 0; + return kernels_.find(TransToPtenKernelName(op_type)) != kernels_.end(); } const Kernel& SelectKernelOrThrowError(const KernelName& kernel_name, @@ -288,9 +284,6 @@ class KernelFactory { KernelFactory() = default; KernelMap kernels_; - // Used to be compatible with the original execution system and - // quickly confirm whether the new kernel can be called - std::unordered_set compatible_op_types_; }; /** operator << overload **/ diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index be624177dfb..645e77fc60f 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -143,7 +143,6 @@ struct KernelRegistrar { Kernel kernel(kernel_fn); args_parse_fn(kernel_key, kernel.mutable_args_def()); args_def_fn(&kernel); - KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name()); KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; } }; diff --git a/paddle/pten/kernels/cpu/linalg.cc b/paddle/pten/kernels/cpu/linalg.cc index 7ffac0537b6..9f4f1be1825 100644 --- a/paddle/pten/kernels/cpu/linalg.cc +++ b/paddle/pten/kernels/cpu/linalg.cc @@ -85,4 +85,4 @@ PT_REGISTER_KERNEL(dot, complex128) {} PT_REGISTER_KERNEL( - matmul_v2, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {} + matmul, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {} diff --git a/paddle/pten/kernels/cuda/linalg.cu b/paddle/pten/kernels/cuda/linalg.cu index b08ed8f71ee..2114bbcc71c 100644 --- a/paddle/pten/kernels/cuda/linalg.cu +++ b/paddle/pten/kernels/cuda/linalg.cu @@ -69,7 +69,7 @@ PT_REGISTER_KERNEL(dot, complex64, complex128) {} -PT_REGISTER_KERNEL(matmul_v2, +PT_REGISTER_KERNEL(matmul, CUDA, ANY, pten::Matmul, diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 3d61caae002..2c47bbe4566 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -76,7 +76,7 @@ infer_meta : func : MatmulInferMeta kernel : - func : matmul_v2 + func : matmul - api : mean args : (const Tensor& x, const std::vector& axis, bool keep_dim) -- GitLab