diff --git a/paddle/pten/core/kernel_alias_name.h b/paddle/pten/core/kernel_alias_name.h index 0310b6e6fafb52db36d1308675478498b873238a..3b8347dec772e93cfe533f4263c3937979025878 100644 --- a/paddle/pten/core/kernel_alias_name.h +++ b/paddle/pten/core/kernel_alias_name.h @@ -27,13 +27,13 @@ const std::unordered_map kernel_alias_name_map = { {"fill_any_like", "full_like"}, {"fill_constant", "full"}, {"flatten_contiguous_range", "flatten"}, - // {"matmul_v2", "matmul"}, + {"matmul_v2", "matmul"}, {"reduce_mean", "mean"}, {"reduce_sum", "sum"}, {"reshape2", "reshape"}, // fluid kernel "mean/reshape/matmul/flatten/sum" should be deprecated {"flatten", "deprecated"}, - // {"matmul", "deprecated"}, + {"matmul", "deprecated"}, {"mean", "deprecated"}, {"reshape", "deprecated"}, {"sum", "deprecated"}}; diff --git a/paddle/pten/core/kernel_factory.h b/paddle/pten/core/kernel_factory.h index dbdf90b5bdbf4dafe92c42a3b5ad5201ab8ab8a8..4adfb703503a8d1c92b47e5fc60f7dc2fc1d0d35 100644 --- a/paddle/pten/core/kernel_factory.h +++ b/paddle/pten/core/kernel_factory.h @@ -265,12 +265,8 @@ class KernelFactory { KernelMap& kernels() { return kernels_; } - void InsertCompatibleOpType(const std::string& op_type) { - compatible_op_types_.insert(op_type); - } - bool HasCompatiblePtenKernel(const std::string& op_type) const { - return compatible_op_types_.count(TransToPtenKernelName(op_type)) > 0; + return kernels_.find(TransToPtenKernelName(op_type)) != kernels_.end(); } const Kernel& SelectKernelOrThrowError(const KernelName& kernel_name, @@ -288,9 +284,6 @@ class KernelFactory { KernelFactory() = default; KernelMap kernels_; - // Used to be compatible with the original execution system and - // quickly confirm whether the new kernel can be called - std::unordered_set compatible_op_types_; }; /** operator << overload **/ diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index be624177dfb14cf3240ee935c5446e0597e0da9f..645e77fc60f8c2533eda3ee355b86a45f3510c23 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -143,7 +143,6 @@ struct KernelRegistrar { Kernel kernel(kernel_fn); args_parse_fn(kernel_key, kernel.mutable_args_def()); args_def_fn(&kernel); - KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name()); KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; } }; diff --git a/paddle/pten/kernels/cpu/linalg.cc b/paddle/pten/kernels/cpu/linalg.cc index 7ffac0537b60c075d9667836a31c24ccf268199b..9f4f1be18259a53a5f945302112b961c757e7036 100644 --- a/paddle/pten/kernels/cpu/linalg.cc +++ b/paddle/pten/kernels/cpu/linalg.cc @@ -85,4 +85,4 @@ PT_REGISTER_KERNEL(dot, complex128) {} PT_REGISTER_KERNEL( - matmul_v2, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {} + matmul, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {} diff --git a/paddle/pten/kernels/cuda/linalg.cu b/paddle/pten/kernels/cuda/linalg.cu index b08ed8f71ee6b274e7abd54fb682220328342661..2114bbcc71c75e7edd9d3b71a606d0d394aebd59 100644 --- a/paddle/pten/kernels/cuda/linalg.cu +++ b/paddle/pten/kernels/cuda/linalg.cu @@ -69,7 +69,7 @@ PT_REGISTER_KERNEL(dot, complex64, complex128) {} -PT_REGISTER_KERNEL(matmul_v2, +PT_REGISTER_KERNEL(matmul, CUDA, ANY, pten::Matmul, diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 3d61caae002e99d3b78cfaf1338338b02a1f800d..2c47bbe4566d6be70c524150bc2f2d6cdbbf81ce 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -76,7 +76,7 @@ infer_meta : func : MatmulInferMeta kernel : - func : matmul_v2 + func : matmul - api : mean args : (const Tensor& x, const std::vector& axis, bool keep_dim)