Unverified commit 36492bc5, authored by YuanRisheng, committed by GitHub

rename elementwise fmax (#40810)

Parent 3980e222
@@ -259,7 +259,7 @@ PD_REGISTER_KERNEL(multiply_triple_grad,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
-PD_REGISTER_KERNEL(elementwise_fmax_grad,
+PD_REGISTER_KERNEL(fmax_grad,
                    CPU,
                    ALL_LAYOUT,
                    phi::ElementwiseFMaxGradKernel,
@@ -268,7 +268,7 @@ PD_REGISTER_KERNEL(elementwise_fmax_grad,
                    int,
                    int64_t) {}
-PD_REGISTER_KERNEL(elementwise_fmin_grad,
+PD_REGISTER_KERNEL(fmin_grad,
                    CPU,
                    ALL_LAYOUT,
                    phi::ElementwiseFMinGradKernel,
...
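Note: the first argument to PD_REGISTER_KERNEL is the string key the dispatcher uses to look the kernel up at run time, so it has to stay in sync with the name returned by the corresponding argument mapping (see the mapping hunks further down). A minimal sketch of such a name-keyed registry, with hypothetical names and none of the real phi machinery:

```cpp
#include <functional>
#include <map>
#include <stdexcept>
#include <string>

using KernelFn = std::function<void()>;

// Hypothetical registry keyed by kernel name, standing in for the real
// phi kernel factory.
std::map<std::string, KernelFn>& Registry() {
  static std::map<std::string, KernelFn> registry;
  return registry;
}

void Dispatch(const std::string& name) {
  auto it = Registry().find(name);
  if (it == Registry().end()) {
    // A registration/signature name mismatch surfaces here, at run time.
    throw std::runtime_error("kernel not found: " + name);
  }
  it->second();
}
```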
@@ -87,23 +87,11 @@ using complex128 = ::phi::dtype::complex<double>;
 // NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
 // using bfloat16 = ::phi::dtype::bfloat16;
 
-PD_REGISTER_KERNEL(elementwise_fmax,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::ElementwiseFMaxKernel,
-                   float,
-                   double,
-                   int,
-                   int64_t) {}
+PD_REGISTER_KERNEL(
+    fmax, CPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
 
-PD_REGISTER_KERNEL(elementwise_fmin,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::ElementwiseFMinKernel,
-                   float,
-                   double,
-                   int,
-                   int64_t) {}
+PD_REGISTER_KERNEL(
+    fmin, CPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
 
 PD_REGISTER_KERNEL(add_raw,
                    CPU,
...
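The f-prefixed names follow the C library's fmax/fmin convention. Assuming funcs::FMaxFunctor and funcs::FMinFunctor mirror std::fmax/std::fmin (the diff does not show their bodies), the distinguishing behavior is NaN handling: a single NaN operand is ignored rather than propagated.

```cpp
#include <cmath>
#include <cstdio>

int main() {
  double nan = std::nan("");
  // std::fmax/std::fmin return the non-NaN operand when exactly one
  // input is NaN; only fmax(NaN, NaN) yields NaN.
  std::printf("fmax(NaN, 2.0) = %g\n", std::fmax(nan, 2.0));  // 2
  std::printf("fmin(1.0, NaN) = %g\n", std::fmin(1.0, nan));  // 1
  return 0;
}
```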
@@ -20,18 +20,18 @@
 namespace phi {
 
 template <typename T, typename Context>
-void ElementwiseFMaxKernel(const Context& dev_ctx,
-                           const DenseTensor& x,
-                           const DenseTensor& y,
-                           int axis,
-                           DenseTensor* out);
+void FMaxKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                int axis,
+                DenseTensor* out);
 
 template <typename T, typename Context>
-void ElementwiseFMinKernel(const Context& dev_ctx,
-                           const DenseTensor& x,
-                           const DenseTensor& y,
-                           int axis,
-                           DenseTensor* out);
+void FMinKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                int axis,
+                DenseTensor* out);
 
 template <typename T, typename Context>
 void AddRawKernel(const Context& dev_ctx,
...
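These declarations follow phi's device-agnostic kernel pattern: one free function template over (T, Context), instantiated once per device type. A minimal standalone analogue of that pattern, where every name is an illustrative stand-in rather than a Paddle type:

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

struct FakeCPUContext {};  // stand-in for a device context such as phi::CPUContext

// Same-shape sketch of the declared interface; the real kernels also
// handle broadcasting via `axis`.
template <typename T, typename Context>
void FMaxKernelSketch(const Context& /*dev_ctx*/,
                      const std::vector<T>& x,
                      const std::vector<T>& y,
                      std::vector<T>* out) {
  out->resize(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    (*out)[i] = std::fmax(x[i], y[i]);
  }
}

int main() {
  FakeCPUContext ctx;
  std::vector<double> x{1.0, 4.0}, y{3.0, 2.0}, out;
  FMaxKernelSketch<double>(ctx, x, y, &out);  // out == {3.0, 4.0}
  return 0;
}
```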
@@ -282,7 +282,7 @@ PD_REGISTER_KERNEL(multiply_triple_grad,
                    phi::dtype::bfloat16,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>) {}
-PD_REGISTER_KERNEL(elementwise_fmax_grad,
+PD_REGISTER_KERNEL(fmax_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::ElementwiseFMaxGradKernel,
@@ -291,7 +291,7 @@ PD_REGISTER_KERNEL(elementwise_fmax_grad,
                    int,
                    int64_t) {}
-PD_REGISTER_KERNEL(elementwise_fmin_grad,
+PD_REGISTER_KERNEL(fmin_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::ElementwiseFMinGradKernel,
...
@@ -57,23 +57,11 @@ using bfloat16 = phi::dtype::bfloat16;
 using complex64 = ::phi::dtype::complex<float>;
 using complex128 = ::phi::dtype::complex<double>;
 
-PD_REGISTER_KERNEL(elementwise_fmax,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::ElementwiseFMaxKernel,
-                   float,
-                   double,
-                   int,
-                   int64_t) {}
+PD_REGISTER_KERNEL(
+    fmax, GPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
 
-PD_REGISTER_KERNEL(elementwise_fmin,
-                   GPU,
-                   ALL_LAYOUT,
-                   phi::ElementwiseFMinKernel,
-                   float,
-                   double,
-                   int,
-                   int64_t) {}
+PD_REGISTER_KERNEL(
+    fmin, GPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
 
 PD_REGISTER_KERNEL(add_raw,
                    GPU,
...
@@ -23,22 +23,22 @@
 namespace phi {
 
 template <typename T, typename Context>
-void ElementwiseFMaxKernel(const Context& dev_ctx,
-                           const DenseTensor& x,
-                           const DenseTensor& y,
-                           int axis,
-                           DenseTensor* out) {
+void FMaxKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                int axis,
+                DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
   funcs::ElementwiseCompute<funcs::FMaxFunctor<T>, T, T>(
       dev_ctx, x, y, axis, funcs::FMaxFunctor<T>(), out);
 }
 
 template <typename T, typename Context>
-void ElementwiseFMinKernel(const Context& dev_ctx,
-                           const DenseTensor& x,
-                           const DenseTensor& y,
-                           int axis,
-                           DenseTensor* out) {
+void FMinKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                int axis,
+                DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
   funcs::ElementwiseCompute<funcs::FMinFunctor<T>, T, T>(
       dev_ctx, x, y, axis, funcs::FMinFunctor<T>(), out);
...
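The bodies above follow a two-step pattern: allocate the output through the device context, then hand a stateless functor to a generic element-wise driver. A simplified analogue of that functor/driver split (same-shape only; the real funcs::ElementwiseCompute also broadcasts along `axis`):

```cpp
#include <cmath>
#include <cstddef>

// Stand-in for funcs::FMaxFunctor<T>: the scalar operation only.
template <typename T>
struct FMaxFunctorSketch {
  T operator()(const T& a, const T& b) const { return std::fmax(a, b); }
};

// Stand-in for the driver: the iteration strategy lives here, so one
// driver serves every binary element-wise kernel.
template <typename Functor, typename T>
void ElementwiseComputeSketch(
    const T* x, const T* y, T* out, std::size_t n, Functor functor) {
  for (std::size_t i = 0; i < n; ++i) {
    out[i] = functor(x[i], y[i]);
  }
}
```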
@@ -19,25 +19,19 @@ namespace phi {
 KernelSignature ElementwiseAddOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   int axis = paddle::any_cast<int>(ctx.Attr("axis"));
-  if (ctx.IsDenseTensorInput("X")) {
-    if (axis == -1) {
-      return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
-    }
-    return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
+  if (axis == -1) {
+    return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
   }
-  return KernelSignature("unregistered", {}, {}, {});
+  return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
 }
 
 KernelSignature ElementwiseSubOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   int axis = paddle::any_cast<int>(ctx.Attr("axis"));
-  if (ctx.IsDenseTensorInput("X")) {
-    if (axis == -1) {
-      return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
-    }
-    return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
+  if (axis == -1) {
+    return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
   }
-  return KernelSignature("unregistered", {}, {}, {});
+  return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
 }
 
 KernelSignature ElementwiseMulOpArgumentMapping(
@@ -55,24 +49,18 @@ KernelSignature ElementwiseMulOpArgumentMapping(
 KernelSignature ElementwiseDivOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   int axis = paddle::any_cast<int>(ctx.Attr("axis"));
-  if (ctx.IsDenseTensorInput("X")) {
-    if (axis == -1) {
-      return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
-    }
-    return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
+  if (axis == -1) {
+    return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
   }
-  return KernelSignature("unregistered", {}, {}, {});
+  return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
 }
 
 KernelSignature ElementwiseAddGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  if (ctx.IsDenseTensorInput("X")) {
-    return KernelSignature("add_grad",
-                           {"X", "Y", GradVarName("Out")},
-                           {"axis"},
-                           {GradVarName("X"), GradVarName("Y")});
-  }
-  return KernelSignature("unregistered", {}, {}, {});
+  return KernelSignature("add_grad",
+                         {"X", "Y", GradVarName("Out")},
+                         {"axis"},
+                         {GradVarName("X"), GradVarName("Y")});
 }
 
 KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
@@ -91,13 +79,10 @@ KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
 KernelSignature ElementwiseSubGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  if (ctx.IsDenseTensorInput("X")) {
-    return KernelSignature("subtract_grad",
-                           {"X", "Y", GradVarName("Out")},
-                           {"axis"},
-                           {GradVarName("X"), GradVarName("Y")});
-  }
-  return KernelSignature("unregistered", {}, {}, {});
+  return KernelSignature("subtract_grad",
+                         {"X", "Y", GradVarName("Out")},
+                         {"axis"},
+                         {GradVarName("X"), GradVarName("Y")});
 }
 
 KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
@@ -116,7 +101,7 @@ KernelSignature ElementwiseDivGradOpArgumentMapping(
 KernelSignature ElementwiseFMinGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature("elementwise_fmin_grad",
+  return KernelSignature("fmin_grad",
                          {"X", "Y", GradVarName("Out")},
                          {"axis"},
                          {GradVarName("X"), GradVarName("Y")});
@@ -138,9 +123,19 @@ KernelSignature ElementwiseMulGradOpArgumentMapping(
                          {GradVarName("X"), GradVarName("Y")});
 }
 
+KernelSignature ElementwiseFMaxOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("fmax", {"X", "Y"}, {"axis"}, {"Out"});
+}
+
+KernelSignature ElementwiseFMinOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("fmin", {"X", "Y"}, {"axis"}, {"Out"});
+}
+
 KernelSignature ElementwiseFMaxGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature("elementwise_fmax_grad",
+  return KernelSignature("fmax_grad",
                          {"X", "Y", GradVarName("Out")},
                          {"axis"},
                          {GradVarName("X"), GradVarName("Y")});
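For orientation, a KernelSignature bundles the kernel name with the ordered input, attribute, and output slot names. Sketched as a plain struct here (illustrative only; the real phi type differs):

```cpp
#include <string>
#include <vector>

struct KernelSignatureSketch {
  std::string name;                  // e.g. "fmax" or "fmax_grad"
  std::vector<std::string> inputs;   // e.g. {"X", "Y"}
  std::vector<std::string> attrs;    // e.g. {"axis"}
  std::vector<std::string> outputs;  // e.g. {"Out"}
};

// The new ElementwiseFMaxOpArgumentMapping above conceptually yields:
// KernelSignatureSketch{"fmax", {"X", "Y"}, {"axis"}, {"Out"}}
```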
...
@@ -179,6 +174,10 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);
 PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad);
 PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
 PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);
+PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
+PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
+PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax_grad, fmax_grad);
+PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);
 
 PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
                            phi::ElementwiseAddOpArgumentMapping);
@@ -208,9 +207,12 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad,
                            phi::ElementwiseMulDoubleGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
                            phi::ElementwiseMulTripleGradOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
+                           phi::ElementwiseFMaxOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
+                           phi::ElementwiseFMinOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax_grad,
                            phi::ElementwiseFMaxGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
                            phi::ElementwiseFMinGradOpArgumentMapping);
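Taken together, the PD_REGISTER_BASE_KERNEL_NAME entries let the framework translate the legacy operator names into the new kernel names before lookup, so existing programs keep working after the rename. A hedged sketch of that one-step translation as a plain map (the real implementation may differ):

```cpp
#include <iostream>
#include <map>
#include <string>

int main() {
  // Conceptual effect of the four PD_REGISTER_BASE_KERNEL_NAME lines above.
  const std::map<std::string, std::string> base_name = {
      {"elementwise_fmax", "fmax"},
      {"elementwise_fmin", "fmin"},
      {"elementwise_fmax_grad", "fmax_grad"},
      {"elementwise_fmin_grad", "fmin_grad"},
  };
  const std::string legacy = "elementwise_fmax_grad";
  const auto it = base_name.find(legacy);
  // The legacy name resolves to the renamed kernel; unknown names pass through.
  std::cout << legacy << " -> "
            << (it != base_name.end() ? it->second : legacy) << "\n";
  return 0;
}
```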