From 30f5e39b6c6031da2116488b03d7a5e23f04a4f7 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Thu, 12 Jan 2023 10:50:49 +0800 Subject: [PATCH] [PHI]Rename some PHI Kernel (#49470) * rename kernel * delete sig * modify code according comment * fix ci bugs --- paddle/fluid/operators/reshape_op.cc | 24 +++++----- paddle/phi/api/yaml/legacy_ops.yaml | 8 ++-- paddle/phi/api/yaml/ops.yaml | 4 +- paddle/phi/backends/xpu/xpu2_op_list.cc | 8 ++-- paddle/phi/kernels/cpu/einsum_kernel.cc | 8 ++-- paddle/phi/kernels/cpu/prod_kernel.cc | 22 ++++----- paddle/phi/kernels/einsum_kernel.h | 18 +++---- paddle/phi/kernels/flatten_kernel.cc | 48 +++++++++---------- paddle/phi/kernels/flatten_kernel.h | 20 ++++---- paddle/phi/kernels/gpu/einsum_kernel.cu | 8 ++-- paddle/phi/kernels/impl/einsum_grad_impl.h | 5 +- paddle/phi/kernels/impl/einsum_impl.h | 20 ++++---- paddle/phi/kernels/impl/solve_kernel_impl.h | 2 +- paddle/phi/kernels/kps/prod_kernel.cu | 24 ++++------ paddle/phi/kernels/onednn/reshape_kernel.cc | 28 +++++------ paddle/phi/kernels/onednn/squeeze_kernel.cc | 30 ++++++------ paddle/phi/kernels/prod_kernel.cc | 36 +++++++++----- paddle/phi/kernels/prod_kernel.h | 16 +++---- paddle/phi/kernels/reshape_kernel.cc | 52 ++++++++++----------- paddle/phi/kernels/reshape_kernel.h | 18 +++---- paddle/phi/kernels/squeeze_kernel.cc | 44 ++++++++--------- paddle/phi/kernels/squeeze_kernel.h | 16 +++---- paddle/phi/kernels/unsqueeze_kernel.cc | 44 ++++++++--------- paddle/phi/kernels/unsqueeze_kernel.h | 18 +++---- paddle/phi/kernels/xpu/prod_kernel.cc | 14 +++--- paddle/phi/ops/compat/einsum_sig.cc | 8 +--- paddle/phi/ops/compat/flatten_sig.cc | 8 ++-- paddle/phi/ops/compat/reduce_sig.cc | 4 +- paddle/phi/ops/compat/reshape_sig.cc | 14 +++--- paddle/phi/tests/ops/test_op_signature.cc | 6 +-- 30 files changed, 284 insertions(+), 291 deletions(-) mode change 100755 => 100644 paddle/phi/kernels/onednn/squeeze_kernel.cc diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index dc14b3ab0df..ae4db70fd52 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -424,27 +424,27 @@ class ReshapeKernel { } if (platform::is_cpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context(); - phi::ReshapeKernel(static_cast(dev_ctx), - *in, - pt_scalar_shape, - out); + phi::ReshapeInferKernel(static_cast(dev_ctx), + *in, + pt_scalar_shape, + out); } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (platform::is_gpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context(); - phi::ReshapeKernel(static_cast(dev_ctx), - *in, - pt_scalar_shape, - out); + phi::ReshapeInferKernel(static_cast(dev_ctx), + *in, + pt_scalar_shape, + out); } #endif #ifdef PADDLE_WITH_XPU if (platform::is_xpu_place(ctx.GetPlace())) { auto &dev_ctx = ctx.device_context(); - phi::ReshapeKernel(static_cast(dev_ctx), - *in, - pt_scalar_shape, - out); + phi::ReshapeInferKernel(static_cast(dev_ctx), + *in, + pt_scalar_shape, + out); } #endif } diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 3b320dace1a..07ab8ca455e 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -561,7 +561,7 @@ func : EinsumRawInferMeta param : [x, equation] kernel : - func : einsum_raw + func : einsum backward : einsum_grad - op : elementwise_pow @@ -677,7 +677,7 @@ infer_meta : func : FlattenWithXShapeInferMeta kernel : - func : flatten_with_xshape + func : flatten backend : x inplace : (x -> out) view : (x -> out) @@ -1391,7 +1391,7 @@ infer_meta : func : ReduceIntArrayAxisInferMetaBase kernel : - func : prod_raw + func : prod backward : prod_grad - op : psroi_pool @@ -1473,7 +1473,7 @@ infer_meta : func : ReshapeWithXShapeInferMeta kernel : - func : reshape_with_xshape + func : reshape inplace : (x -> out) view: (x -> out) intermediate : xshape diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index c1a0b6da6c2..a6786ea729b 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1150,7 +1150,7 @@ infer_meta : func : SqueezeWithXShapeInferMeta kernel : - func : squeeze_with_xshape + func : squeeze data_type : x inplace : (x -> out) view: (x -> out) @@ -1258,7 +1258,7 @@ infer_meta : func : UnsqueezeWithXShapeInferMeta kernel : - func : unsqueeze_with_xshape + func : unsqueeze data_type : x inplace : (x -> out) view: (x -> out) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 8b0d0851ef3..9c42a8b5508 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -270,7 +270,7 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, - {"flatten_with_xshape", + {"flatten", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, phi::DataType::INT8, @@ -450,7 +450,7 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT32, phi::DataType::BOOL, phi::DataType::FLOAT32})}, - {"reshape_with_xshape", + {"reshape", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::FLOAT16, phi::DataType::INT64, @@ -541,7 +541,7 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT8, phi::DataType::UINT8, phi::DataType::FLOAT32})}, - {"squeeze_with_xshape", + {"squeeze", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, phi::DataType::INT32, @@ -655,7 +655,7 @@ XPUOpMap& get_kl2_ops() { phi::DataType::UINT8, phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, - {"unsqueeze_with_xshape", + {"unsqueeze", XPUKernelSet({phi::DataType::FLOAT64, phi::DataType::INT64, phi::DataType::INT32, diff --git a/paddle/phi/kernels/cpu/einsum_kernel.cc b/paddle/phi/kernels/cpu/einsum_kernel.cc index 7ef85a942e4..b02e6d497f5 100644 --- a/paddle/phi/kernels/cpu/einsum_kernel.cc +++ b/paddle/phi/kernels/cpu/einsum_kernel.cc @@ -18,19 +18,19 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_impl.h" -PD_REGISTER_KERNEL(einsum_raw, +PD_REGISTER_KERNEL(einsum, CPU, ALL_LAYOUT, - phi::EinsumKernelRaw, + phi::EinsumKernel, float, double, phi::dtype::complex, phi::dtype::complex) {} -PD_REGISTER_KERNEL(einsum, +PD_REGISTER_KERNEL(einsum_infer, CPU, ALL_LAYOUT, - phi::EinsumKernel, + phi::EinsumInferKernel, float, double, phi::dtype::complex, diff --git a/paddle/phi/kernels/cpu/prod_kernel.cc b/paddle/phi/kernels/cpu/prod_kernel.cc index d5a07c0057d..09082e3a319 100644 --- a/paddle/phi/kernels/cpu/prod_kernel.cc +++ b/paddle/phi/kernels/cpu/prod_kernel.cc @@ -22,12 +22,12 @@ namespace phi { template -void ProdRawKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& dims, - bool keep_dim, - bool reduce_all, - DenseTensor* out) { +void ProdKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { reduce_all = recompute_reduce_all(x, dims, reduce_all); auto out_dtype = x.dtype(); phi::Reduce( @@ -36,11 +36,5 @@ void ProdRawKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(prod_raw, - CPU, - ALL_LAYOUT, - phi::ProdRawKernel, - float, - double, - int, - int64_t) {} +PD_REGISTER_KERNEL( + prod, CPU, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {} diff --git a/paddle/phi/kernels/einsum_kernel.h b/paddle/phi/kernels/einsum_kernel.h index 569cf7a55af..7a3dd455b46 100644 --- a/paddle/phi/kernels/einsum_kernel.h +++ b/paddle/phi/kernels/einsum_kernel.h @@ -18,18 +18,18 @@ namespace phi { +template +void EinsumInferKernel(const Context& dev_ctx, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out); + template void EinsumKernel(const Context& dev_ctx, const std::vector& inputs, const std::string& equation, - DenseTensor* out); - -template -void EinsumKernelRaw(const Context& dev_ctx, - const std::vector& inputs, - const std::string& equation, - DenseTensor* out, - std::vector inner_cache, - std::vector xshape); + DenseTensor* out, + std::vector inner_cache, + std::vector xshape); } // namespace phi diff --git a/paddle/phi/kernels/flatten_kernel.cc b/paddle/phi/kernels/flatten_kernel.cc index 58ba3d70a34..1706778237f 100644 --- a/paddle/phi/kernels/flatten_kernel.cc +++ b/paddle/phi/kernels/flatten_kernel.cc @@ -23,11 +23,11 @@ namespace phi { template -void FlattenKernel(const Context& dev_ctx, - const DenseTensor& x, - int start_axis, - int stop_axis, - DenseTensor* out) { +void FlattenInferKernel(const Context& dev_ctx, + const DenseTensor& x, + int start_axis, + int stop_axis, + DenseTensor* out) { dev_ctx.Alloc(out, x.dtype()); auto out_dims = out->dims(); phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); @@ -38,21 +38,21 @@ void FlattenKernel(const Context& dev_ctx, // Output Tensor, // is there a more flexible way to deal with this case? template -void FlattenWithXShape(const Context& dev_ctx, - const DenseTensor& x, - int start_axis, - int stop_axis, - DenseTensor* out, - DenseTensor* xshape) { - FlattenKernel(dev_ctx, x, start_axis, stop_axis, out); +void FlattenKernel(const Context& dev_ctx, + const DenseTensor& x, + int start_axis, + int stop_axis, + DenseTensor* out, + DenseTensor* xshape) { + FlattenInferKernel(dev_ctx, x, start_axis, stop_axis, out); } } // namespace phi -PD_REGISTER_KERNEL(flatten, +PD_REGISTER_KERNEL(flatten_infer, CPU, ALL_LAYOUT, - phi::FlattenKernel, + phi::FlattenInferKernel, float, phi::dtype::bfloat16, double, @@ -62,10 +62,10 @@ PD_REGISTER_KERNEL(flatten, int, int64_t) {} -PD_REGISTER_KERNEL(flatten_with_xshape, +PD_REGISTER_KERNEL(flatten, CPU, ALL_LAYOUT, - phi::FlattenWithXShape, + phi::FlattenKernel, float, phi::dtype::bfloat16, double, @@ -76,10 +76,10 @@ PD_REGISTER_KERNEL(flatten_with_xshape, int64_t) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL(flatten, +PD_REGISTER_KERNEL(flatten_infer, GPU, ALL_LAYOUT, - phi::FlattenKernel, + phi::FlattenInferKernel, float, phi::dtype::float16, phi::dtype::bfloat16, @@ -90,10 +90,10 @@ PD_REGISTER_KERNEL(flatten, int, int64_t) {} -PD_REGISTER_KERNEL(flatten_with_xshape, +PD_REGISTER_KERNEL(flatten, GPU, ALL_LAYOUT, - phi::FlattenWithXShape, + phi::FlattenKernel, float, phi::dtype::float16, phi::dtype::bfloat16, @@ -106,10 +106,10 @@ PD_REGISTER_KERNEL(flatten_with_xshape, #endif #ifdef PADDLE_WITH_XPU -PD_REGISTER_KERNEL(flatten, +PD_REGISTER_KERNEL(flatten_infer, XPU, ALL_LAYOUT, - phi::FlattenKernel, + phi::FlattenInferKernel, float, phi::dtype::float16, int8_t, @@ -117,10 +117,10 @@ PD_REGISTER_KERNEL(flatten, int, int64_t) {} -PD_REGISTER_KERNEL(flatten_with_xshape, +PD_REGISTER_KERNEL(flatten, XPU, ALL_LAYOUT, - phi::FlattenWithXShape, + phi::FlattenKernel, float, phi::dtype::float16, int8_t, diff --git a/paddle/phi/kernels/flatten_kernel.h b/paddle/phi/kernels/flatten_kernel.h index 808af7d9b7b..f6b40770f20 100644 --- a/paddle/phi/kernels/flatten_kernel.h +++ b/paddle/phi/kernels/flatten_kernel.h @@ -20,20 +20,20 @@ limitations under the License. */ namespace phi { +template +void FlattenInferKernel(const Context& dev_ctx, + const DenseTensor& x, + int start_axis, + int stop_axis, + DenseTensor* out); + template void FlattenKernel(const Context& dev_ctx, const DenseTensor& x, int start_axis, int stop_axis, - DenseTensor* out); - -template -void FlattenWithXShape(const Context& dev_ctx, - const DenseTensor& x, - int start_axis, - int stop_axis, - DenseTensor* out, - DenseTensor* xshape); + DenseTensor* out, + DenseTensor* xshape); template DenseTensor Flatten(const Context& dev_ctx, @@ -43,7 +43,7 @@ DenseTensor Flatten(const Context& dev_ctx, DenseTensor dense_out; MetaTensor meta_out(&dense_out); FlattenInferMeta(x, start_axis, stop_axis, &meta_out); - FlattenKernel(dev_ctx, x, start_axis, stop_axis, &dense_out); + FlattenInferKernel(dev_ctx, x, start_axis, stop_axis, &dense_out); return dense_out; } diff --git a/paddle/phi/kernels/gpu/einsum_kernel.cu b/paddle/phi/kernels/gpu/einsum_kernel.cu index 99a9c58995c..b05d26b7fc6 100644 --- a/paddle/phi/kernels/gpu/einsum_kernel.cu +++ b/paddle/phi/kernels/gpu/einsum_kernel.cu @@ -18,10 +18,10 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/einsum_impl.h" -PD_REGISTER_KERNEL(einsum_raw, +PD_REGISTER_KERNEL(einsum, GPU, ALL_LAYOUT, - phi::EinsumKernelRaw, + phi::EinsumKernel, float, double, phi::dtype::float16, @@ -29,10 +29,10 @@ PD_REGISTER_KERNEL(einsum_raw, phi::dtype::complex, phi::dtype::complex) {} -PD_REGISTER_KERNEL(einsum, +PD_REGISTER_KERNEL(einsum_infer, GPU, ALL_LAYOUT, - phi::EinsumKernel, + phi::EinsumInferKernel, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/impl/einsum_grad_impl.h b/paddle/phi/kernels/impl/einsum_grad_impl.h index 816badcd79e..3481a442f26 100644 --- a/paddle/phi/kernels/impl/einsum_grad_impl.h +++ b/paddle/phi/kernels/impl/einsum_grad_impl.h @@ -103,7 +103,7 @@ DenseTensor PerformTileAndReduction(const Context& dev_ctx, // undiagonalize by einsum equation. only contain undiagonal operations. DenseTensor out; VLOG(5) << "Undiagonal by einsum with args: " << op_label + "->" + equ; - EinsumKernel(dev_ctx, {&ret}, op_label + "->" + equ, &out); + EinsumInferKernel(dev_ctx, {&ret}, op_label + "->" + equ, &out); return out; } @@ -157,7 +157,8 @@ void EinsumGradKernel(const Context& dev_ctx, new_operands.push_back(&out_grad); DenseTensor before_tile; VLOG(5) << "new_equation is " << new_equation; - EinsumKernel(dev_ctx, new_operands, new_equation, &before_tile); + EinsumInferKernel( + dev_ctx, new_operands, new_equation, &before_tile); *(x_grad[0]) = PerformTileAndReduction(dev_ctx, labeltype, labelshape, diff --git a/paddle/phi/kernels/impl/einsum_impl.h b/paddle/phi/kernels/impl/einsum_impl.h index 392949e065b..238c89b1701 100644 --- a/paddle/phi/kernels/impl/einsum_impl.h +++ b/paddle/phi/kernels/impl/einsum_impl.h @@ -746,12 +746,12 @@ void EinsumKernelImpl(const Context& dev_ctx, } template -void EinsumKernelRaw(const Context& dev_ctx, - const std::vector& inputs, - const std::string& equation, - DenseTensor* out, - std::vector cache, - std::vector xshape) { +void EinsumKernel(const Context& dev_ctx, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out, + std::vector cache, + std::vector xshape) { std::vector tmp; // for the sake of compatibility, we may load and run v2.3 EinsumOp. Output // may have nullptr and the cache.size() is not equal to inputs.size(). refer @@ -765,10 +765,10 @@ void EinsumKernelRaw(const Context& dev_ctx, } template -void EinsumKernel(const Context& dev_ctx, - const std::vector& inputs, - const std::string& equation, - DenseTensor* out) { +void EinsumInferKernel(const Context& dev_ctx, + const std::vector& inputs, + const std::string& equation, + DenseTensor* out) { std::vector place_holder; std::vector cache_tensor( inputs.size()); // set empty; TA, TB, TdC diff --git a/paddle/phi/kernels/impl/solve_kernel_impl.h b/paddle/phi/kernels/impl/solve_kernel_impl.h index 4120823a9d2..b0e6b2b6cc0 100644 --- a/paddle/phi/kernels/impl/solve_kernel_impl.h +++ b/paddle/phi/kernels/impl/solve_kernel_impl.h @@ -169,7 +169,7 @@ static void linalg_solve(const Context& dev_ctx, out_tmp.Resize(out->dims()); out_tmp = *out; - phi::SqueezeKernel(dev_ctx, out_tmp, {-1}, out); + phi::SqueezeInferKernel(dev_ctx, out_tmp, {-1}, out); } else { PADDLE_ENFORCE_EQ( x_dim[x_dim_size - 1], diff --git a/paddle/phi/kernels/kps/prod_kernel.cu b/paddle/phi/kernels/kps/prod_kernel.cu index 79dc76f81c0..a584af357cf 100644 --- a/paddle/phi/kernels/kps/prod_kernel.cu +++ b/paddle/phi/kernels/kps/prod_kernel.cu @@ -19,12 +19,12 @@ namespace phi { template -void ProdRawKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& dims, - bool keep_dim, - bool reduce_all, - DenseTensor* out) { +void ProdKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { reduce_all = recompute_reduce_all(x, dims, reduce_all); auto out_dtype = x.dtype(); phi::Reduce( @@ -33,14 +33,8 @@ void ProdRawKernel(const Context& dev_ctx, } // namespace phi #ifdef PADDLE_WITH_XPU_KP -PD_REGISTER_KERNEL(prod_raw, KPS, ALL_LAYOUT, phi::ProdRawKernel, float) {} +PD_REGISTER_KERNEL(prod, KPS, ALL_LAYOUT, phi::ProdKernel, float) {} #else -PD_REGISTER_KERNEL(prod_raw, - KPS, - ALL_LAYOUT, - phi::ProdRawKernel, - float, - double, - int, - int64_t) {} +PD_REGISTER_KERNEL( + prod, KPS, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {} #endif diff --git a/paddle/phi/kernels/onednn/reshape_kernel.cc b/paddle/phi/kernels/onednn/reshape_kernel.cc index 4d8adc4b9a6..95d27b94a84 100644 --- a/paddle/phi/kernels/onednn/reshape_kernel.cc +++ b/paddle/phi/kernels/onednn/reshape_kernel.cc @@ -148,32 +148,32 @@ void ExecuteReshape(const Context& dev_ctx, } template -void ReshapeKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& shape, - DenseTensor* out) { +void ReshapeInferKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& shape, + DenseTensor* out) { auto x_dims = x.dims(); ExecuteReshape(dev_ctx, x, shape, x_dims, out); } template -void ReshapeWithXShape(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& shape, - DenseTensor* out, - DenseTensor* xshape) { +void ReshapeKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& shape, + DenseTensor* out, + DenseTensor* xshape) { auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size()); ExecuteReshape(dev_ctx, x, shape, x_dims, out); } } // namespace phi -PD_REGISTER_KERNEL( - reshape, OneDNN, ONEDNN, phi::ReshapeKernel, float, phi::dtype::bfloat16) {} - -PD_REGISTER_KERNEL(reshape_with_xshape, +PD_REGISTER_KERNEL(reshape_infer, OneDNN, ONEDNN, - phi::ReshapeWithXShape, + phi::ReshapeInferKernel, float, phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL( + reshape, OneDNN, ONEDNN, phi::ReshapeKernel, float, phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/onednn/squeeze_kernel.cc b/paddle/phi/kernels/onednn/squeeze_kernel.cc old mode 100755 new mode 100644 index 9f2b9a8a442..0ad82bfedda --- a/paddle/phi/kernels/onednn/squeeze_kernel.cc +++ b/paddle/phi/kernels/onednn/squeeze_kernel.cc @@ -52,10 +52,10 @@ void ExecuteSqueeze(const Context& dev_ctx, } template -void SqueezeKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& axes, - DenseTensor* out) { +void SqueezeInferKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out) { auto x_dims = x.dims(); std::vector tmp(axes.GetData().begin(), axes.GetData().end()); auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true); @@ -63,13 +63,13 @@ void SqueezeKernel(const Context& dev_ctx, } template -void SqueezeWithXShapeKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& axes, - DenseTensor* out, - DenseTensor* xshape) { +void SqueezeKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out, + DenseTensor* xshape) { if (xshape == nullptr) { - SqueezeKernel(dev_ctx, x, axes, out); + SqueezeInferKernel(dev_ctx, x, axes, out); } else { auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size()); auto out_dims = out->dims(); @@ -78,12 +78,12 @@ void SqueezeWithXShapeKernel(const Context& dev_ctx, } } // namespace phi -PD_REGISTER_KERNEL( - squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {} - -PD_REGISTER_KERNEL(squeeze_with_xshape, +PD_REGISTER_KERNEL(squeeze_infer, OneDNN, ONEDNN, - phi::SqueezeWithXShapeKernel, + phi::SqueezeInferKernel, float, phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL( + squeeze, OneDNN, ONEDNN, phi::SqueezeKernel, float, phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/prod_kernel.cc b/paddle/phi/kernels/prod_kernel.cc index 61ed575a198..12b55e12030 100644 --- a/paddle/phi/kernels/prod_kernel.cc +++ b/paddle/phi/kernels/prod_kernel.cc @@ -20,29 +20,41 @@ namespace phi { template -void ProdKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& dims, - bool keep_dim, - DenseTensor* out) { +void ProdInferKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& dims, + bool keep_dim, + DenseTensor* out) { bool reduce_all = recompute_reduce_all(x, dims); - ProdRawKernel(dev_ctx, x, dims, keep_dim, reduce_all, out); + ProdKernel(dev_ctx, x, dims, keep_dim, reduce_all, out); } } // namespace phi -PD_REGISTER_KERNEL( - prod, CPU, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {} +PD_REGISTER_KERNEL(prod_infer, + CPU, + ALL_LAYOUT, + phi::ProdInferKernel, + float, + double, + int, + int64_t) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL( - prod, GPU, ALL_LAYOUT, phi::ProdKernel, float, double, int, int64_t) {} +PD_REGISTER_KERNEL(prod_infer, + GPU, + ALL_LAYOUT, + phi::ProdInferKernel, + float, + double, + int, + int64_t) {} #endif #if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU) -PD_REGISTER_KERNEL(prod, KPS, ALL_LAYOUT, phi::ProdKernel, float) {} +PD_REGISTER_KERNEL(prod_infer, KPS, ALL_LAYOUT, phi::ProdInferKernel, float) {} #endif #if defined(PADDLE_WITH_XPU) -PD_REGISTER_KERNEL(prod, XPU, ALL_LAYOUT, phi::ProdKernel, float) {} +PD_REGISTER_KERNEL(prod_infer, XPU, ALL_LAYOUT, phi::ProdInferKernel, float) {} #endif diff --git a/paddle/phi/kernels/prod_kernel.h b/paddle/phi/kernels/prod_kernel.h index 91de087ccbc..834ef765694 100644 --- a/paddle/phi/kernels/prod_kernel.h +++ b/paddle/phi/kernels/prod_kernel.h @@ -18,19 +18,19 @@ #include "paddle/phi/core/dense_tensor.h" namespace phi { -template -void ProdRawKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& dims, - bool keep_dim, - bool reduce_all, - DenseTensor* out); - template void ProdKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& dims, bool keep_dim, + bool reduce_all, DenseTensor* out); +template +void ProdInferKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& dims, + bool keep_dim, + DenseTensor* out); + } // namespace phi diff --git a/paddle/phi/kernels/reshape_kernel.cc b/paddle/phi/kernels/reshape_kernel.cc index a792322a440..d13007dc555 100644 --- a/paddle/phi/kernels/reshape_kernel.cc +++ b/paddle/phi/kernels/reshape_kernel.cc @@ -26,10 +26,10 @@ namespace phi { template -void ReshapeKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& shape, - DenseTensor* out) { +void ReshapeInferKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& shape, + DenseTensor* out) { MetaTensor meta_out(out); InferMetaFromVecValue(x, shape.GetData(), &meta_out); if (x.initialized() && x.Holder() == out->Holder()) { @@ -47,10 +47,10 @@ void ReshapeKernel(const Context& dev_ctx, #ifdef PADDLE_WITH_XPU template <> -void ReshapeKernel(const XPUContext& dev_ctx, - const DenseTensor& x, - const IntArray& shape, - DenseTensor* out) { +void ReshapeInferKernel(const XPUContext& dev_ctx, + const DenseTensor& x, + const IntArray& shape, + DenseTensor* out) { MetaTensor meta_out(out); InferMetaFromVecValue(x, shape.GetData(), &meta_out); if (x.initialized() && x.Holder() == out->Holder()) { @@ -73,40 +73,40 @@ void ReshapeKernel(const XPUContext& dev_ctx, #endif template -void ReshapeWithXShape(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& shape, - DenseTensor* out, - DenseTensor* xshape) { - ReshapeKernel(dev_ctx, x, shape, out); +void ReshapeKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& shape, + DenseTensor* out, + DenseTensor* xshape) { + ReshapeInferKernel(dev_ctx, x, shape, out); } } // namespace phi -PD_REGISTER_GENERAL_KERNEL( - reshape, CPU, ALL_LAYOUT, phi::ReshapeKernel, ALL_DTYPE) {} -PD_REGISTER_GENERAL_KERNEL(reshape_with_xshape, +PD_REGISTER_GENERAL_KERNEL(reshape_infer, CPU, ALL_LAYOUT, - phi::ReshapeWithXShape, + phi::ReshapeInferKernel, ALL_DTYPE) {} +PD_REGISTER_GENERAL_KERNEL( + reshape, CPU, ALL_LAYOUT, phi::ReshapeKernel, ALL_DTYPE) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_GENERAL_KERNEL( - reshape, GPU, ALL_LAYOUT, phi::ReshapeKernel, ALL_DTYPE) {} -PD_REGISTER_GENERAL_KERNEL(reshape_with_xshape, +PD_REGISTER_GENERAL_KERNEL(reshape_infer, GPU, ALL_LAYOUT, - phi::ReshapeWithXShape, + phi::ReshapeInferKernel, ALL_DTYPE) {} +PD_REGISTER_GENERAL_KERNEL( + reshape, GPU, ALL_LAYOUT, phi::ReshapeKernel, ALL_DTYPE) {} #endif #ifdef PADDLE_WITH_XPU -PD_REGISTER_GENERAL_KERNEL( - reshape, XPU, ALL_LAYOUT, phi::ReshapeKernel, ALL_DTYPE) {} -PD_REGISTER_GENERAL_KERNEL(reshape_with_xshape, +PD_REGISTER_GENERAL_KERNEL(reshape_infer, XPU, ALL_LAYOUT, - phi::ReshapeWithXShape, + phi::ReshapeInferKernel, ALL_DTYPE) {} +PD_REGISTER_GENERAL_KERNEL( + reshape, XPU, ALL_LAYOUT, phi::ReshapeKernel, ALL_DTYPE) {} #endif diff --git a/paddle/phi/kernels/reshape_kernel.h b/paddle/phi/kernels/reshape_kernel.h index 88b1bd95871..20037c1d67b 100644 --- a/paddle/phi/kernels/reshape_kernel.h +++ b/paddle/phi/kernels/reshape_kernel.h @@ -21,18 +21,18 @@ limitations under the License. */ namespace phi { +template +void ReshapeInferKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& shape, + DenseTensor* out); + template void ReshapeKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& shape, - DenseTensor* out); - -template -void ReshapeWithXShape(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& shape, - DenseTensor* out, - DenseTensor* xshape); + DenseTensor* out, + DenseTensor* xshape); template DenseTensor Reshape(const Context& dev_ctx, @@ -41,7 +41,7 @@ DenseTensor Reshape(const Context& dev_ctx, DenseTensor dense_out; MetaTensor meta_out(&dense_out); InferMetaFromVecValue(x, shape, &meta_out); - ReshapeKernel(dev_ctx, x, IntArray(shape), &dense_out); + ReshapeInferKernel(dev_ctx, x, IntArray(shape), &dense_out); return dense_out; } diff --git a/paddle/phi/kernels/squeeze_kernel.cc b/paddle/phi/kernels/squeeze_kernel.cc index 46cbbb174b8..a95a8cc9a2f 100644 --- a/paddle/phi/kernels/squeeze_kernel.cc +++ b/paddle/phi/kernels/squeeze_kernel.cc @@ -21,10 +21,10 @@ namespace phi { template -void SqueezeKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& axes, - DenseTensor* out) { +void SqueezeInferKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out) { auto x_dims = x.dims(); std::vector tmp(axes.GetData().begin(), axes.GetData().end()); auto out_dims = funcs::GetOutputSqueezeShape(tmp, x_dims, true); @@ -36,20 +36,20 @@ void SqueezeKernel(const Context& dev_ctx, } template -void SqueezeWithXShapeKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& axes, - DenseTensor* out, - DenseTensor* xshape) { - SqueezeKernel(dev_ctx, x, axes, out); +void SqueezeKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out, + DenseTensor* xshape) { + SqueezeInferKernel(dev_ctx, x, axes, out); } } // namespace phi -PD_REGISTER_KERNEL(squeeze, +PD_REGISTER_KERNEL(squeeze_infer, CPU, ALL_LAYOUT, - phi::SqueezeKernel, + phi::SqueezeInferKernel, float, double, phi::dtype::bfloat16, @@ -61,10 +61,10 @@ PD_REGISTER_KERNEL(squeeze, phi::dtype::complex, phi::dtype::complex) {} -PD_REGISTER_KERNEL(squeeze_with_xshape, +PD_REGISTER_KERNEL(squeeze, CPU, ALL_LAYOUT, - phi::SqueezeWithXShapeKernel, + phi::SqueezeKernel, float, double, phi::dtype::bfloat16, @@ -76,10 +76,10 @@ PD_REGISTER_KERNEL(squeeze_with_xshape, phi::dtype::complex, phi::dtype::complex) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL(squeeze, +PD_REGISTER_KERNEL(squeeze_infer, GPU, ALL_LAYOUT, - phi::SqueezeKernel, + phi::SqueezeInferKernel, float, double, phi::dtype::float16, @@ -92,10 +92,10 @@ PD_REGISTER_KERNEL(squeeze, phi::dtype::complex, phi::dtype::complex) {} -PD_REGISTER_KERNEL(squeeze_with_xshape, +PD_REGISTER_KERNEL(squeeze, GPU, ALL_LAYOUT, - phi::SqueezeWithXShapeKernel, + phi::SqueezeKernel, float, double, phi::dtype::float16, @@ -110,10 +110,10 @@ PD_REGISTER_KERNEL(squeeze_with_xshape, #endif #ifdef PADDLE_WITH_XPU -PD_REGISTER_KERNEL(squeeze, +PD_REGISTER_KERNEL(squeeze_infer, XPU, ALL_LAYOUT, - phi::SqueezeKernel, + phi::SqueezeInferKernel, float, double, phi::dtype::float16, @@ -123,10 +123,10 @@ PD_REGISTER_KERNEL(squeeze, int8_t, int64_t) {} -PD_REGISTER_KERNEL(squeeze_with_xshape, +PD_REGISTER_KERNEL(squeeze, XPU, ALL_LAYOUT, - phi::SqueezeWithXShapeKernel, + phi::SqueezeKernel, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/squeeze_kernel.h b/paddle/phi/kernels/squeeze_kernel.h index 7e5a1b0775a..8114969ea7d 100644 --- a/paddle/phi/kernels/squeeze_kernel.h +++ b/paddle/phi/kernels/squeeze_kernel.h @@ -20,17 +20,17 @@ namespace phi { +template +void SqueezeInferKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out); + template void SqueezeKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& axes, - DenseTensor* out); - -template -void SqueezeWithXShapeKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& axes, - DenseTensor* out, - DenseTensor* xshape); + DenseTensor* out, + DenseTensor* xshape); } // namespace phi diff --git a/paddle/phi/kernels/unsqueeze_kernel.cc b/paddle/phi/kernels/unsqueeze_kernel.cc index 1887c5abf7a..159e7a4ce17 100644 --- a/paddle/phi/kernels/unsqueeze_kernel.cc +++ b/paddle/phi/kernels/unsqueeze_kernel.cc @@ -21,10 +21,10 @@ namespace phi { template -void UnsqueezeKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& axes, - DenseTensor* out) { +void UnsqueezeInferKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out) { auto x_dims = x.dims(); auto out_dims = out->dims(); if (axes.FromTensor()) { @@ -42,19 +42,19 @@ void UnsqueezeKernel(const Context& dev_ctx, } template -void UnsqueezeWithXShapeKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& axes, - DenseTensor* out, - DenseTensor* xshape) { - UnsqueezeKernel(dev_ctx, x, axes, out); +void UnsqueezeKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out, + DenseTensor* xshape) { + UnsqueezeInferKernel(dev_ctx, x, axes, out); } } // namespace phi -PD_REGISTER_KERNEL(unsqueeze, +PD_REGISTER_KERNEL(unsqueeze_infer, CPU, ALL_LAYOUT, - phi::UnsqueezeKernel, + phi::UnsqueezeInferKernel, float, double, phi::dtype::bfloat16, @@ -67,10 +67,10 @@ PD_REGISTER_KERNEL(unsqueeze, phi::dtype::complex, phi::dtype::complex) {} -PD_REGISTER_KERNEL(unsqueeze_with_xshape, +PD_REGISTER_KERNEL(unsqueeze, CPU, ALL_LAYOUT, - phi::UnsqueezeWithXShapeKernel, + phi::UnsqueezeKernel, float, double, phi::dtype::bfloat16, @@ -83,10 +83,10 @@ PD_REGISTER_KERNEL(unsqueeze_with_xshape, phi::dtype::complex, phi::dtype::complex) {} #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -PD_REGISTER_KERNEL(unsqueeze, +PD_REGISTER_KERNEL(unsqueeze_infer, GPU, ALL_LAYOUT, - phi::UnsqueezeKernel, + phi::UnsqueezeInferKernel, float, double, phi::dtype::float16, @@ -100,10 +100,10 @@ PD_REGISTER_KERNEL(unsqueeze, phi::dtype::complex, phi::dtype::complex) {} -PD_REGISTER_KERNEL(unsqueeze_with_xshape, +PD_REGISTER_KERNEL(unsqueeze, GPU, ALL_LAYOUT, - phi::UnsqueezeWithXShapeKernel, + phi::UnsqueezeKernel, float, double, phi::dtype::float16, @@ -119,10 +119,10 @@ PD_REGISTER_KERNEL(unsqueeze_with_xshape, #endif #ifdef PADDLE_WITH_XPU -PD_REGISTER_KERNEL(unsqueeze, +PD_REGISTER_KERNEL(unsqueeze_infer, XPU, ALL_LAYOUT, - phi::UnsqueezeKernel, + phi::UnsqueezeInferKernel, float, double, phi::dtype::float16, @@ -132,10 +132,10 @@ PD_REGISTER_KERNEL(unsqueeze, int8_t, int64_t) {} -PD_REGISTER_KERNEL(unsqueeze_with_xshape, +PD_REGISTER_KERNEL(unsqueeze, XPU, ALL_LAYOUT, - phi::UnsqueezeWithXShapeKernel, + phi::UnsqueezeKernel, float, double, phi::dtype::float16, diff --git a/paddle/phi/kernels/unsqueeze_kernel.h b/paddle/phi/kernels/unsqueeze_kernel.h index 35a0515c92d..bb190e14179 100644 --- a/paddle/phi/kernels/unsqueeze_kernel.h +++ b/paddle/phi/kernels/unsqueeze_kernel.h @@ -21,18 +21,18 @@ namespace phi { +template +void UnsqueezeInferKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& axes, + DenseTensor* out); + template void UnsqueezeKernel(const Context& dev_ctx, const DenseTensor& x, const IntArray& axes, - DenseTensor* out); - -template -void UnsqueezeWithXShapeKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& axes, - DenseTensor* out, - DenseTensor* xshape); + DenseTensor* out, + DenseTensor* xshape); template void Unsqueeze(const Context& dev_ctx, @@ -42,7 +42,7 @@ void Unsqueeze(const Context& dev_ctx, DenseTensor* xshape) { MetaTensor meta_out(out); UnsqueezeInferMeta(x, axes, &meta_out); - UnsqueezeKernel(dev_ctx, x, axes, out); + UnsqueezeInferKernel(dev_ctx, x, axes, out); } } // namespace phi diff --git a/paddle/phi/kernels/xpu/prod_kernel.cc b/paddle/phi/kernels/xpu/prod_kernel.cc index ebc9abc049c..ce006f095b0 100644 --- a/paddle/phi/kernels/xpu/prod_kernel.cc +++ b/paddle/phi/kernels/xpu/prod_kernel.cc @@ -22,12 +22,12 @@ namespace phi { template -void ProdRawKernel(const Context& dev_ctx, - const DenseTensor& x, - const IntArray& dims, - bool keep_dim, - bool reduce_all, - DenseTensor* out) { +void ProdKernel(const Context& dev_ctx, + const DenseTensor& x, + const IntArray& dims, + bool keep_dim, + bool reduce_all, + DenseTensor* out) { reduce_all = recompute_reduce_all(x, dims, reduce_all); using XPUType = typename XPUTypeTrait::Type; @@ -46,4 +46,4 @@ void ProdRawKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL(prod_raw, XPU, ALL_LAYOUT, phi::ProdRawKernel, float) {} +PD_REGISTER_KERNEL(prod, XPU, ALL_LAYOUT, phi::ProdKernel, float) {} diff --git a/paddle/phi/ops/compat/einsum_sig.cc b/paddle/phi/ops/compat/einsum_sig.cc index c145b8f4fa5..4fd31c1a2d8 100644 --- a/paddle/phi/ops/compat/einsum_sig.cc +++ b/paddle/phi/ops/compat/einsum_sig.cc @@ -17,10 +17,8 @@ limitations under the License. */ namespace phi { KernelSignature EinsumOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature("einsum_raw", - {"Operands"}, - {"equation"}, - {"Out", "InnerCache", "XShape"}); + return KernelSignature( + "einsum", {"Operands"}, {"equation"}, {"Out", "InnerCache", "XShape"}); } KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { @@ -31,7 +29,5 @@ KernelSignature EinsumGradOpArgumentMapping(const ArgumentMappingContext& ctx) { } } // namespace phi -PD_REGISTER_BASE_KERNEL_NAME(einsum, einsum_raw); - PD_REGISTER_ARG_MAPPING_FN(einsum, phi::EinsumOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(einsum_grad, phi::EinsumGradOpArgumentMapping); diff --git a/paddle/phi/ops/compat/flatten_sig.cc b/paddle/phi/ops/compat/flatten_sig.cc index 122e0efa22b..b225dc62524 100644 --- a/paddle/phi/ops/compat/flatten_sig.cc +++ b/paddle/phi/ops/compat/flatten_sig.cc @@ -18,13 +18,11 @@ namespace phi { KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.HasOutput("XShape")) { - return KernelSignature("flatten_with_xshape", - {"X"}, - {"start_axis", "stop_axis"}, - {"Out", "XShape"}); + return KernelSignature( + "flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"}); } else { return KernelSignature( - "flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out"}); + "flatten_infer", {"X"}, {"start_axis", "stop_axis"}, {"Out"}); } } diff --git a/paddle/phi/ops/compat/reduce_sig.cc b/paddle/phi/ops/compat/reduce_sig.cc index e796307c0c9..dc00897a559 100644 --- a/paddle/phi/ops/compat/reduce_sig.cc +++ b/paddle/phi/ops/compat/reduce_sig.cc @@ -60,9 +60,9 @@ KernelSignature ReduceProdOpArgumentMapping(const ArgumentMappingContext& ctx) { // the "max_raw" KernelSignature if (ctx.IsForInferShape() || reduce_all) { return KernelSignature( - "prod_raw", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); + "prod", {"X"}, {"dim", "keep_dim", "reduce_all"}, {"Out"}); } - return KernelSignature("prod", {"X"}, {"dim", "keep_dim"}, {"Out"}); + return KernelSignature("prod_infer", {"X"}, {"dim", "keep_dim"}, {"Out"}); } return KernelSignature("unregistered", {}, {}, {}); } diff --git a/paddle/phi/ops/compat/reshape_sig.cc b/paddle/phi/ops/compat/reshape_sig.cc index a01f2a98c9b..0c6937199c9 100644 --- a/paddle/phi/ops/compat/reshape_sig.cc +++ b/paddle/phi/ops/compat/reshape_sig.cc @@ -20,21 +20,19 @@ KernelSignature ReshapeOpArgumentMapping(const ArgumentMappingContext& ctx) { if (ctx.HasOutput("XShape")) { if (ctx.InputSize("ShapeTensor") > 0) { return KernelSignature( - "reshape_with_xshape", {"X"}, {"ShapeTensor"}, {"Out", "XShape"}); + "reshape", {"X"}, {"ShapeTensor"}, {"Out", "XShape"}); } else if (ctx.HasInput("Shape")) { - return KernelSignature( - "reshape_with_xshape", {"X"}, {"Shape"}, {"Out", "XShape"}); + return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out", "XShape"}); } else { - return KernelSignature( - "reshape_with_xshape", {"X"}, {"shape"}, {"Out", "XShape"}); + return KernelSignature("reshape", {"X"}, {"shape"}, {"Out", "XShape"}); } } else { if (ctx.InputSize("ShapeTensor") > 0) { - return KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"}); + return KernelSignature("reshape_infer", {"X"}, {"ShapeTensor"}, {"Out"}); } else if (ctx.HasInput("Shape")) { - return KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"}); + return KernelSignature("reshape_infer", {"X"}, {"Shape"}, {"Out"}); } else { - return KernelSignature("reshape", {"X"}, {"shape"}, {"Out"}); + return KernelSignature("reshape_infer", {"X"}, {"shape"}, {"Out"}); } } } diff --git a/paddle/phi/tests/ops/test_op_signature.cc b/paddle/phi/tests/ops/test_op_signature.cc index de97634ed61..99c529f5a90 100644 --- a/paddle/phi/tests/ops/test_op_signature.cc +++ b/paddle/phi/tests/ops/test_op_signature.cc @@ -618,18 +618,18 @@ TEST(ARG_MAP, reshape) { TestArgumentMappingContext arg_case1({"X", "ShapeTensor"}, {}, {}, {"Out"}); auto signature1 = (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case1); - EXPECT_STREQ(signature1.name, "reshape"); + EXPECT_STREQ(signature1.name, "reshape_infer"); TestArgumentMappingContext arg_case2({"X", "Shape"}, {}, {}, {"Out"}); auto signature2 = (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case2); - EXPECT_STREQ(signature2.name, "reshape"); + EXPECT_STREQ(signature2.name, "reshape_infer"); TestArgumentMappingContext arg_case3( {"X"}, {}, {{"shape", paddle::any(std::vector({1, 2}))}}, {"Out"}); auto signature3 = (*OpUtilsMap::Instance().GetArgumentMappingFn("reshape2"))(arg_case3); - EXPECT_STREQ(signature3.name, "reshape"); + EXPECT_STREQ(signature3.name, "reshape_infer"); } } // namespace tests -- GitLab