未验证 提交 36492bc5 编写于 作者: Y YuanRisheng 提交者: GitHub

rename elementwise fmax (#40810)

上级 3980e222
......@@ -259,7 +259,7 @@ PD_REGISTER_KERNEL(multiply_triple_grad,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(elementwise_fmax_grad,
PD_REGISTER_KERNEL(fmax_grad,
CPU,
ALL_LAYOUT,
phi::ElementwiseFMaxGradKernel,
......@@ -268,7 +268,7 @@ PD_REGISTER_KERNEL(elementwise_fmax_grad,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_fmin_grad,
PD_REGISTER_KERNEL(fmin_grad,
CPU,
ALL_LAYOUT,
phi::ElementwiseFMinGradKernel,
......
......@@ -87,23 +87,11 @@ using complex128 = ::phi::dtype::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::phi::dtype::bfloat16;
PD_REGISTER_KERNEL(elementwise_fmax,
CPU,
ALL_LAYOUT,
phi::ElementwiseFMaxKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(
fmax, CPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(elementwise_fmin,
CPU,
ALL_LAYOUT,
phi::ElementwiseFMinKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(
fmin, CPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(add_raw,
CPU,
......
......@@ -20,14 +20,14 @@
namespace phi {
template <typename T, typename Context>
void ElementwiseFMaxKernel(const Context& dev_ctx,
void FMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void ElementwiseFMinKernel(const Context& dev_ctx,
void FMinKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
......
......@@ -282,7 +282,7 @@ PD_REGISTER_KERNEL(multiply_triple_grad,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(elementwise_fmax_grad,
PD_REGISTER_KERNEL(fmax_grad,
GPU,
ALL_LAYOUT,
phi::ElementwiseFMaxGradKernel,
......@@ -291,7 +291,7 @@ PD_REGISTER_KERNEL(elementwise_fmax_grad,
int,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_fmin_grad,
PD_REGISTER_KERNEL(fmin_grad,
GPU,
ALL_LAYOUT,
phi::ElementwiseFMinGradKernel,
......
......@@ -57,23 +57,11 @@ using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(elementwise_fmax,
GPU,
ALL_LAYOUT,
phi::ElementwiseFMaxKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(
fmax, GPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(elementwise_fmin,
GPU,
ALL_LAYOUT,
phi::ElementwiseFMinKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(
fmin, GPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(add_raw,
GPU,
......
......@@ -23,7 +23,7 @@
namespace phi {
template <typename T, typename Context>
void ElementwiseFMaxKernel(const Context& dev_ctx,
void FMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
......@@ -34,7 +34,7 @@ void ElementwiseFMaxKernel(const Context& dev_ctx,
}
template <typename T, typename Context>
void ElementwiseFMinKernel(const Context& dev_ctx,
void FMinKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
......
......@@ -19,25 +19,19 @@ namespace phi {
KernelSignature ElementwiseAddOpArgumentMapping(
const ArgumentMappingContext& ctx) {
int axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (ctx.IsDenseTensorInput("X")) {
if (axis == -1) {
return KernelSignature("add", {"X", "Y"}, {}, {"Out"});
}
return KernelSignature("add_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ElementwiseSubOpArgumentMapping(
const ArgumentMappingContext& ctx) {
int axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (ctx.IsDenseTensorInput("X")) {
if (axis == -1) {
return KernelSignature("subtract", {"X", "Y"}, {}, {"Out"});
}
return KernelSignature("subtract_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ElementwiseMulOpArgumentMapping(
......@@ -55,24 +49,18 @@ KernelSignature ElementwiseMulOpArgumentMapping(
KernelSignature ElementwiseDivOpArgumentMapping(
const ArgumentMappingContext& ctx) {
int axis = paddle::any_cast<int>(ctx.Attr("axis"));
if (ctx.IsDenseTensorInput("X")) {
if (axis == -1) {
return KernelSignature("divide", {"X", "Y"}, {}, {"Out"});
}
return KernelSignature("divide_raw", {"X", "Y"}, {"axis"}, {"Out"});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ElementwiseAddGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) {
return KernelSignature("add_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ElementwiseAddDoubleGradOpArgumentMapping(
......@@ -91,13 +79,10 @@ KernelSignature ElementwiseAddTripleGradOpArgumentMapping(
KernelSignature ElementwiseSubGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorInput("X")) {
return KernelSignature("subtract_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
}
return KernelSignature("unregistered", {}, {}, {});
}
KernelSignature ElementwiseSubDoubleGradOpArgumentMapping(
......@@ -116,7 +101,7 @@ KernelSignature ElementwiseDivGradOpArgumentMapping(
KernelSignature ElementwiseFMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("elementwise_fmin_grad",
return KernelSignature("fmin_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
......@@ -138,9 +123,19 @@ KernelSignature ElementwiseMulGradOpArgumentMapping(
{GradVarName("X"), GradVarName("Y")});
}
KernelSignature ElementwiseFMaxOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fmax", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature ElementwiseFMinOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fmin", {"X", "Y"}, {"axis"}, {"Out"});
}
KernelSignature ElementwiseFMaxGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("elementwise_fmax_grad",
return KernelSignature("fmax_grad",
{"X", "Y", GradVarName("Out")},
{"axis"},
{GradVarName("X"), GradVarName("Y")});
......@@ -179,6 +174,10 @@ PD_REGISTER_BASE_KERNEL_NAME(elementwise_div_grad_grad, divide_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad, multiply_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_grad_grad, multiply_double_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_mul_triple_grad, multiply_triple_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax, fmax);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin, fmin);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmax_grad, fmax_grad);
PD_REGISTER_BASE_KERNEL_NAME(elementwise_fmin_grad, fmin_grad);
PD_REGISTER_ARG_MAPPING_FN(elementwise_add,
phi::ElementwiseAddOpArgumentMapping);
......@@ -208,9 +207,12 @@ PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_grad_grad,
phi::ElementwiseMulDoubleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_mul_triple_grad,
phi::ElementwiseMulTripleGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax,
phi::ElementwiseFMaxOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin,
phi::ElementwiseFMinOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmax_grad,
phi::ElementwiseFMaxGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(elementwise_fmin_grad,
phi::ElementwiseFMinGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册