Unverified commit 7d138402, authored by zhangyuqin1998, committed by GitHub

delete axis of fmax (#51264)

Parent commit: f759cf0f
......@@ -499,7 +499,7 @@
- backward_op : fmax_grad
forward : fmax(Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1)
args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
......
......@@ -121,14 +121,8 @@ using complex128 = ::phi::dtype::complex<double>;
// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::phi::dtype::bfloat16;
// CPU registrations for the elementwise fmax/fmin kernels.
// NOTE(review): the former `fmax_raw` registration was removed here — this
// commit deletes FMaxRawKernel (the `axis` attribute is gone), so keeping a
// registration that names phi::FMaxRawKernel would no longer link.
PD_REGISTER_KERNEL(
    fmax, CPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(
    fmin, CPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
......
......@@ -24,7 +24,6 @@ void ElementwiseFMaxGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad,
DenseTensor* y_grad);
......
......@@ -101,21 +101,10 @@ void SubtractKernel(const Context& dev_ctx,
SubtractRawKernel<T>(dev_ctx, x, y, axis, out);
}
template <typename T, typename Context>
void FMaxKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const DenseTensor& y,
                DenseTensor* out) {
  // Thin convenience wrapper: forward to the raw kernel with axis fixed to
  // -1 (presumably the "align trailing dimensions" default used by Paddle's
  // elementwise ops — confirm against FMaxRawKernel's contract).
  constexpr int kDefaultAxis = -1;
  FMaxRawKernel<T, Context>(dev_ctx, x, y, kDefaultAxis, out);
}
} // namespace phi
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(
fmax, CPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(maximum,
CPU,
ALL_LAYOUT,
......@@ -204,16 +193,6 @@ PD_REGISTER_KERNEL(divide,
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(fmax,
KPS,
ALL_LAYOUT,
phi::FMaxKernel,
float,
double,
int,
phi::dtype::float16,
int64_t) {}
PD_REGISTER_KERNEL(maximum,
KPS,
ALL_LAYOUT,
......
......@@ -19,13 +19,6 @@
namespace phi {
template <typename T, typename Context>
void FMaxRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context>
void FMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
......
......@@ -265,7 +265,6 @@ void ElementwiseFMaxGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad,
DenseTensor* y_grad) {
funcs::ElementwiseGradPreProcess(out_grad, x_grad);
......@@ -273,6 +272,7 @@ void ElementwiseFMaxGradKernel(const Context& dev_ctx,
auto out = out_grad; // Fake out, not used
auto x_dim = x.dims();
auto y_dim = y.dims();
int axis = -1;
if (x.dims() == y.dims()) {
funcs::ElemwiseGradComputeNoBroadcast<Context,
T,
......
......@@ -67,14 +67,13 @@ namespace phi {
}
template <typename T, typename Context>
void FMaxRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out) {
void FMaxKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::FMaxFunctor<T>, T, T>(
dev_ctx, x, y, axis, funcs::FMaxFunctor<T>(), out);
dev_ctx, x, y, -1, funcs::FMaxFunctor<T>(), out);
}
template <typename T, typename Context>
......
......@@ -109,10 +109,10 @@ using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL(fmax_raw,
PD_REGISTER_KERNEL(fmax,
KPS,
ALL_LAYOUT,
phi::FMaxRawKernel,
phi::FMaxKernel,
float,
double,
int,
......
......@@ -176,7 +176,7 @@ KernelSignature ElementwiseMulGradOpArgumentMapping(
// Maps the legacy fluid elementwise_fmax op onto the phi "fmax" kernel.
// The flattened diff left two consecutive return statements (the old
// "fmax_raw" mapping forwarding the "axis" attribute, then the new one);
// the second was unreachable dead code. Keep only the post-commit mapping:
// the kernel no longer takes axis, so no attributes are forwarded.
KernelSignature ElementwiseFMaxOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("fmax", {"X", "Y"}, {}, {"Out"});
}
KernelSignature ElementwiseFMinOpArgumentMapping(
......@@ -187,7 +187,7 @@ KernelSignature ElementwiseFMinOpArgumentMapping(
// Maps the legacy fluid elementwise_fmax_grad op onto the phi "fmax_grad"
// kernel. The flattened diff left two alternative attribute-argument lines
// inside one KernelSignature call (old: {"axis"}; new: {}), which does not
// compile. Keep only the post-commit form: axis was removed from the
// gradient kernel, so the attribute list is empty.
KernelSignature ElementwiseFMaxGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "fmax_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
}
KernelSignature ElementwiseMulDoubleGradOpArgumentMapping(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.