diff --git a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
index bf6ec012b24443e877b235e17488725dc0d14151..d5b78909e9287ee0c6cf93164a19b49733a2d76d 100644
--- a/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
@@ -259,7 +259,7 @@ PD_REGISTER_KERNEL(multiply_triple_grad,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
-PD_REGISTER_KERNEL(elementwise_fmax_grad,
+PD_REGISTER_KERNEL(fmax_grad,
                    CPU,
                    ALL_LAYOUT,
                    phi::ElementwiseFMaxGradKernel,
@@ -268,7 +268,7 @@ PD_REGISTER_KERNEL(elementwise_fmax_grad,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(elementwise_fmin_grad,
+PD_REGISTER_KERNEL(fmin_grad,
                    CPU,
                    ALL_LAYOUT,
                    phi::ElementwiseFMinGradKernel,
diff --git a/paddle/phi/kernels/cpu/elementwise_kernel.cc b/paddle/phi/kernels/cpu/elementwise_kernel.cc
index 095d11720ce26622c31e517286d6f656869e62ff..004f40ddedadf5e2609868478c7b0d4169b73a63 100644
--- a/paddle/phi/kernels/cpu/elementwise_kernel.cc
+++ b/paddle/phi/kernels/cpu/elementwise_kernel.cc
@@ -87,23 +87,11 @@ using complex128 = ::phi::dtype::complex<double>;
 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
 // using bfloat16 = ::phi::dtype::bfloat16;
 
-PD_REGISTER_KERNEL(elementwise_fmax,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::ElementwiseFMaxKernel,
-                   float,
-                   double,
-                   int,
-                   int64_t) {}
+PD_REGISTER_KERNEL(
+    fmax, CPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
 
-PD_REGISTER_KERNEL(elementwise_fmin,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::ElementwiseFMinKernel,
-                   float,
-                   double,
-                   int,
-                   int64_t) {}
+PD_REGISTER_KERNEL(
+    fmin, CPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
 
 PD_REGISTER_KERNEL(add_raw,
                    CPU,
diff --git a/paddle/phi/kernels/elementwise_kernel.h b/paddle/phi/kernels/elementwise_kernel.h
index b064ecc454c592df49670205163e73d2d3b249b3..a6ba7bdac5829f88c153496c908a6e7ac14f91d2 100644
--- a/paddle/phi/kernels/elementwise_kernel.h
+++ b/paddle/phi/kernels/elementwise_kernel.h
@@ -20,18 +20,18 @@
 namespace phi {
 
 template <typename T, typename Context>
-void ElementwiseFMaxKernel(const Context& dev_ctx,
-                           const DenseTensor& x,
-                           const DenseTensor& y,
-                           int axis,
-                           DenseTensor* out);
+void FMaxKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                int axis,
+                DenseTensor* out);
 
 template <typename T, typename Context>
-void ElementwiseFMinKernel(const Context& dev_ctx,
-                           const DenseTensor& x,
-                           const DenseTensor& y,
-                           int axis,
-                           DenseTensor* out);
+void FMinKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                int axis,
+                DenseTensor* out);
 
 template <typename T, typename Context>
 void AddRawKernel(const Context& dev_ctx,
diff --git a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
index c4481bf6ce3c33ea260d774d0ac240a166856388..3392a3cec4ecad08b0442a54c3c3dbc652ebd0b6 100644
--- a/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
@@ -282,7 +282,7 @@ PD_REGISTER_KERNEL(multiply_triple_grad,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
-PD_REGISTER_KERNEL(elementwise_fmax_grad,
+PD_REGISTER_KERNEL(fmax_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::ElementwiseFMaxGradKernel,
@@ -291,7 +291,7 @@ PD_REGISTER_KERNEL(elementwise_fmax_grad,
                    int,
                    int64_t) {}
 
-PD_REGISTER_KERNEL(elementwise_fmin_grad,
+PD_REGISTER_KERNEL(fmin_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::ElementwiseFMinGradKernel,
diff --git a/paddle/phi/kernels/gpu/elementwise_kernel.cu b/paddle/phi/kernels/gpu/elementwise_kernel.cu
index a57d89013f921e3adb5587c70b7bbb12c383de61..8de55e8a412d36c615ed923984c1a3fadc073d0b 100644
--- a/paddle/phi/kernels/gpu/elementwise_kernel.cu
+++ b/paddle/phi/kernels/gpu/elementwise_kernel.cu
@@ -57,23 +57,11 @@ using bfloat16 = phi::dtype::bfloat16;
 using complex64 = ::phi::dtype::complex<float>;
 using complex128 = ::phi::dtype::complex<double>;
 
-PD_REGISTER_KERNEL(elementwise_fmax,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::ElementwiseFMaxKernel,
-                   float,
-                   double,
-                   int,
-                   int64_t) {}
+PD_REGISTER_KERNEL(
+    fmax, GPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
 
-PD_REGISTER_KERNEL(elementwise_fmin,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::ElementwiseFMinKernel,
-                   float,
-                   double,
-                   int,
-                   int64_t) {}
+PD_REGISTER_KERNEL(
+    fmin, GPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
 
 PD_REGISTER_KERNEL(add_raw,
                    GPU,
diff --git a/paddle/phi/kernels/impl/elementwise_kernel_impl.h b/paddle/phi/kernels/impl/elementwise_kernel_impl.h
index 775a91bf026d298a61315a7e2d7ebfbe92efb0b5..0e69d00110eadf1a3845a2bbb56be917153f654e 100644
--- a/paddle/phi/kernels/impl/elementwise_kernel_impl.h
+++ b/paddle/phi/kernels/impl/elementwise_kernel_impl.h
@@ -23,22 +23,22 @@
 namespace phi {
 
 template <typename T, typename Context>
-void ElementwiseFMaxKernel(const Context& dev_ctx,
-                           const DenseTensor& x,
-                           const DenseTensor& y,
-                           int axis,
-                           DenseTensor* out) {
+void FMaxKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                int axis,
+                DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
   funcs::ElementwiseCompute<funcs::FMaxFunctor<T>, T, T>(
       dev_ctx, x, y, axis, funcs::FMaxFunctor<T>(), out);
 }
 
 template <typename T, typename Context>
-void ElementwiseFMinKernel(const Context& dev_ctx,
-                           const DenseTensor& x,
-                           const DenseTensor& y,
-                           int axis,
-                           DenseTensor* out) {
+void FMinKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                int axis,
+                DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
   funcs::ElementwiseCompute<funcs::FMinFunctor<T>, T, T>(
       dev_ctx, x, y, axis, funcs::FMinFunctor<T>(), out);
diff --git a/paddle/phi/ops/compat/elementwise_sig.cc b/paddle/phi/ops/compat/elementwise_sig.cc
index 1d2aaa04f05d205483dbda5c738c7499ad068881..bb05689dee1d31e2a81bfa15793ee6de52f63120 100644
--- a/paddle/phi/ops/compat/elementwise_sig.cc
+++ b/paddle/phi/ops/compat/elementwise_sig.cc
@@ -19,25 +19,19 @@ namespace phi {
 KernelSignature ElementwiseAddOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   int axis = paddle::any_cast<int>(ctx.Attr("axis"));
-  if (ctx.IsDenseTensorInput("X")) {
-    if (axis == -1) {
-      return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
-    }
-    return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
+  if (axis == -1) {
+    return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
   }
-  return KernelSignature("unregistered", {}, {}, {});
+  return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
 }
 
 KernelSignature ElementwiseSubOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   int axis = paddle::any_cast<int>(ctx.Attr("axis"));
-  if (ctx.IsDenseTensorInput("X")) {
-    if (axis == -1) {
-      return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
-    }
-    return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
+  if (axis == -1) {
+    return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
   }
-  return KernelSignature("unregistered", {}, {}, {});
+  return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
 }
 
 KernelSignature ElementwiseMulOpArgumentMapping(
@@ -55,24 +49,18 @@
 KernelSignature ElementwiseDivOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   int axis = paddle::any_cast<int>(ctx.Attr("axis"));
-  if (ctx.IsDenseTensorInput("X")) {
-    if (axis == -1) {
-      return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
-    }
-    return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
+  if (axis == -1) {
+    return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
   }
-  return KernelSignature("unregistered", {}, {}, {});
+  return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
 }
 
 KernelSignature ElementwiseAddGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  if (ctx.IsDenseTensorInput("X")) {
-    return KernelSignature("add_grad",
-                           {"X", "Y", GradVarName("Out")},
-                           {"axis"},
-                           {GradVarName("X"), GradVarName("Y")});
-  }
-  return KernelSignature("unregistered", {}, {}, {});
+  return KernelSignature("add_grad",
+                         {"X", "Y", GradVarName("Out")},
+                         {"axis"},
+                         {GradVarName("X"), GradVarName("Y")});
 }
 
 KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
@@ -91,13 +79,10 @@ KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
 
 KernelSignature ElementwiseSubGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  if (ctx.IsDenseTensorInput("X")) {
-    return KernelSignature("subtract_grad",
-                           {"X", "Y", GradVarName("Out")},
-                           {"axis"},
-                           {GradVarName("X"), GradVarName("Y")});
-  }
-  return KernelSignature("unregistered", {}, {}, {});
+  return KernelSignature("subtract_grad",
+                         {"X", "Y", GradVarName("Out")},
+                         {"axis"},
+                         {GradVarName("X"), GradVarName("Y")});
 }
 
 KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
@@ -116,7 +101,7 @@ KernelSignature ElementwiseDivGradOpArgumentMapping(
 
 KernelSignature ElementwiseFMinGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature("elementwise_fmin_grad",
+  return KernelSignature("fmin_grad",
                          {"X", "Y", GradVarName("Out")},
                          {"axis"},
                          {GradVarName("X"), GradVarName("Y")});
@@ -138,9 +123,19 @@ KernelSignature ElementwiseMulGradOpArgumentMapping(
                          {GradVarName("X"), GradVarName("Y")});
 }
 
+KernelSignature ElementwiseFMaxOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("fmax", {"X", "Y"}, {"axis"}, {"Out"});
+}
+
+KernelSignature ElementwiseFMinOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("fmin", {"X", "Y"}, {"axis"}, {"Out"});
+}
+
 KernelSignature ElementwiseFMaxGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature("elementwise_fmax_grad",
+  return KernelSignature("fmax_grad",
                          {"X", "Y", GradVarName("Out")},
                          {"axis"},
                          {GradVarName("X"), GradVarName("Y")});
@@ -179,6 +174,10 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);
 PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad);
 PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
 PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);
+PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
+PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
+PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax_grad, fmax_grad);
+PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);
 
 PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
                            phi::ElementwiseAddOpArgumentMapping);
@@ -208,9 +207,12 @@
                            phi::ElementwiseMulDoubleGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
                            phi::ElementwiseMulTripleGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
+                           phi::ElementwiseFMaxOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
+                           phi::ElementwiseFMinOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax_grad,
                            phi::ElementwiseFMaxGradOpArgumentMapping);
-
 PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
                            phi::ElementwiseFMinGradOpArgumentMapping);
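
Reviewer note (illustration only, not part of the patch): the PD_REGISTER_BASE_KERNEL_NAME lines above keep the legacy elementwise_fmax / elementwise_fmin op names resolvable by recording an old-name to new-name alias, so existing fluid programs still find the renamed phi kernels. Below is a minimal, self-contained C++ sketch of that alias-lookup pattern; BaseNameMap and ResolveKernelName are hypothetical names for illustration, not Paddle APIs.

#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the base-name registry: maps a legacy op name to
// the renamed kernel name, mirroring what lines such as
// PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax) record in the patch.
std::unordered_map<std::string, std::string>& BaseNameMap() {
  static std::unordered_map<std::string, std::string> map;
  return map;
}

// Resolve a kernel name: return the alias target if one is registered,
// otherwise return the name unchanged (it is already a base name).
std::string ResolveKernelName(const std::string& name) {
  auto it = BaseNameMap().find(name);
  return it == BaseNameMap().end() ? name : it->second;
}

int main() {
  BaseNameMap()["elementwise_fmax"] = "fmax";
  BaseNameMap()["elementwise_fmax_grad"] = "fmax_grad";

  std::cout << ResolveKernelName("elementwise_fmax") << "\n";       // fmax
  std::cout << ResolveKernelName("elementwise_fmax_grad") << "\n";  // fmax_grad
  std::cout << ResolveKernelName("fmin") << "\n";                   // fmin (no alias)
  return 0;
}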