Unverified commit 8df8cb10, authored by zyfncg, committed by GitHub

Delete axis of fmin kernel (#50358)

* delete axis of fmin

* fix bug
Parent 615d9f53
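Background note: fmin is an element-wise binary op, and with standard trailing-dimension broadcasting the legacy axis attribute inherited from the old elementwise op family carries no information, which is why this change removes it end to end. The snippet below is a stand-alone illustration of the NaN-aware minimum the op computes (plain C++ around std::fmin, not Paddle's kernel code).

```cpp
// Illustrative only: the NaN-aware element-wise minimum behind fmin.
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> fmin_elementwise(const std::vector<float>& x,
                                    const std::vector<float>& y) {
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    // std::fmin returns the non-NaN operand when exactly one input is NaN.
    out[i] = std::fmin(x[i], y[i]);
  }
  return out;
}

int main() {
  std::vector<float> x = {1.0f, NAN, 3.0f};
  std::vector<float> y = {2.0f, 5.0f, NAN};
  for (float v : fmin_elementwise(x, y)) std::printf("%g ", v);  // prints: 1 5 3
  std::printf("\n");
  return 0;
}
```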
@@ -515,7 +515,7 @@
 - backward_op : fmin_grad
   forward : fmin(Tensor x, Tensor y) -> Tensor(out)
-  args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1)
+  args : (Tensor x, Tensor y, Tensor out_grad)
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta
......
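Dropping the attribute from fmin_grad's args does not touch the gradient rule itself. A rough scalar sketch of the usual convention follows, assuming ties (x == y) route the gradient to x; Paddle's exact tie and NaN handling lives in its FMin gradient functors and may differ in detail.

```cpp
// Minimal scalar sketch of an fmin subgradient; illustrative, not Paddle code.
#include <cmath>
#include <cstdio>

struct FminGrads {
  float dx;
  float dy;
};

// The upstream gradient flows to whichever operand fmin selected.
FminGrads fmin_grad(float x, float y, float dout) {
  bool x_selected = (x <= y) || std::isnan(y);  // assumed tie-breaking: ties go to x
  return {x_selected ? dout : 0.0f, x_selected ? 0.0f : dout};
}

int main() {
  FminGrads g = fmin_grad(2.0f, 3.0f, 1.0f);
  std::printf("dx=%g dy=%g\n", g.dx, g.dy);  // prints: dx=1 dy=0
  return 0;
}
```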
@@ -131,14 +131,8 @@ PD_REGISTER_KERNEL(fmax_raw,
                    int,
                    int64_t) {}
-PD_REGISTER_KERNEL(fmin_raw,
-                   CPU,
-                   ALL_LAYOUT,
-                   phi::FMinRawKernel,
-                   float,
-                   double,
-                   int,
-                   int64_t) {}
+PD_REGISTER_KERNEL(
+    fmin, CPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
 PD_REGISTER_KERNEL(maximum_raw,
                    CPU,
......
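With the raw variant gone, the CPU side registers only the plain fmin entry. As a very rough mental model of what such a registration amounts to (a hypothetical map keyed by kernel name, backend, and dtype; the real PD_REGISTER_KERNEL macro expands to far more machinery):

```cpp
// Hypothetical, simplified model of kernel registration; not phi's registry.
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <tuple>

using Key = std::tuple<std::string, std::string, std::string>;  // name, backend, dtype

std::map<Key, std::function<void()>>& Registry() {
  static std::map<Key, std::function<void()>> reg;
  return reg;
}

int main() {
  // After this change the CPU build records only the plain "fmin" kernel.
  for (const char* dtype : {"float", "double", "int", "int64_t"}) {
    Registry()[Key{"fmin", "CPU", dtype}] = [] { /* would call phi::FMinKernel<T> */ };
  }
  std::cout << Registry().size() << " fmin CPU entries registered\n";  // prints: 4
  return 0;
}
```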
@@ -33,7 +33,6 @@ void ElementwiseFMinGradKernel(const Context& dev_ctx,
                                const DenseTensor& x,
                                const DenseTensor& y,
                                const DenseTensor& out_grad,
-                               int axis,
                                DenseTensor* x_grad,
                                DenseTensor* y_grad);
......
@@ -118,14 +118,6 @@ void FMaxKernel(const Context& dev_ctx,
   FMaxRawKernel<T, Context>(dev_ctx, x, y, -1, out);
 }
-template <typename T, typename Context>
-void FMinKernel(const Context& dev_ctx,
-                const DenseTensor& x,
-                const DenseTensor& y,
-                DenseTensor* out) {
-  FMinRawKernel<T, Context>(dev_ctx, x, y, -1, out);
-}
 }  // namespace phi
 using complex64 = ::phi::dtype::complex<float>;
 using complex128 = ::phi::dtype::complex<double>;
@@ -133,9 +125,6 @@ using complex128 = ::phi::dtype::complex<double>;
 PD_REGISTER_KERNEL(
     fmax, CPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
-PD_REGISTER_KERNEL(
-    fmin, CPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
 PD_REGISTER_KERNEL(maximum,
                    CPU,
                    ALL_LAYOUT,
@@ -242,16 +231,6 @@ PD_REGISTER_KERNEL(fmax,
                    phi::dtype::float16,
                    int64_t) {}
-PD_REGISTER_KERNEL(fmin,
-                   KPS,
-                   ALL_LAYOUT,
-                   phi::FMinKernel,
-                   float,
-                   double,
-                   int,
-                   phi::dtype::float16,
-                   int64_t) {}
 PD_REGISTER_KERNEL(maximum,
                    KPS,
                    ALL_LAYOUT,
......
@@ -32,13 +32,6 @@ void FMaxKernel(const Context& dev_ctx,
                 const DenseTensor& y,
                 DenseTensor* out);
-template <typename T, typename Context>
-void FMinRawKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const DenseTensor& y,
-                   int axis,
-                   DenseTensor* out);
 template <typename T, typename Context>
 void FMinKernel(const Context& dev_ctx,
                 const DenseTensor& x,
......
@@ -314,13 +314,13 @@ void ElementwiseFMinGradKernel(const Context& dev_ctx,
                                const DenseTensor& x,
                                const DenseTensor& y,
                                const DenseTensor& out_grad,
-                               int axis,
                                DenseTensor* x_grad,
                                DenseTensor* y_grad) {
   funcs::ElementwiseGradPreProcess(out_grad, x_grad);
   auto out = out_grad;  // Fake out, not used
   auto x_dim = x.dims();
   auto y_dim = y.dims();
+  int axis = -1;
   if (x.dims() == y.dims()) {
     funcs::ElemwiseGradComputeNoBroadcast<Context,
                                           T,
......
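The axis that used to be a kernel argument becomes a local constant here; -1 is the conventional "align y against the trailing dimensions of x" mode, so equal-rank and trailing-broadcast cases behave as before. A stand-alone illustration of that broadcast (plain loops, not funcs::ElementwiseCompute):

```cpp
// Illustrative trailing-dimension broadcast for fmin; not Paddle's machinery.
#include <cmath>
#include <cstdio>

int main() {
  const int rows = 2, cols = 3;
  float x[rows][cols] = {{1.f, 8.f, 3.f}, {4.f, 0.f, 9.f}};
  float y[cols] = {2.f, 5.f, 7.f};  // shape [3], broadcast over the leading axis
  float out[rows][cols];
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) {
      out[i][j] = std::fmin(x[i][j], y[j]);
    }
  }
  for (int i = 0; i < rows; ++i) {
    for (int j = 0; j < cols; ++j) std::printf("%g ", out[i][j]);
    std::printf("\n");  // prints: 1 5 3 / 2 0 7
  }
  return 0;
}
```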
@@ -78,14 +78,13 @@ void FMaxRawKernel(const Context& dev_ctx,
 }
 template <typename T, typename Context>
-void FMinRawKernel(const Context& dev_ctx,
-                   const DenseTensor& x,
-                   const DenseTensor& y,
-                   int axis,
-                   DenseTensor* out) {
+void FMinKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const DenseTensor& y,
+                DenseTensor* out) {
   dev_ctx.template Alloc<T>(out);
   funcs::ElementwiseCompute<funcs::FMinFunctor<T>, T, T>(
-      dev_ctx, x, y, axis, funcs::FMinFunctor<T>(), out);
+      dev_ctx, x, y, -1, funcs::FMinFunctor<T>(), out);
 }
 }  // namespace phi
@@ -103,10 +103,10 @@ PD_REGISTER_KERNEL(fmax_raw,
                    float16,
                    int64_t) {}
-PD_REGISTER_KERNEL(fmin_raw,
+PD_REGISTER_KERNEL(fmin,
                    KPS,
                    ALL_LAYOUT,
-                   phi::FMinRawKernel,
+                   phi::FMinKernel,
                    float,
                    double,
                    int,
......
@@ -162,7 +162,7 @@ KernelSignature ElementwiseDivGradOpArgumentMapping(
 KernelSignature ElementwiseFMinGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "fmin_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"});
+      "fmin_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
 }
 KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
@@ -186,7 +186,7 @@ KernelSignature ElementwiseFMaxOpArgumentMapping(
 KernelSignature ElementwiseFMinOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
-  return KernelSignature("fmin_raw", {"X", "Y"}, {"axis"}, {"Out"});
+  return KernelSignature("fmin", {"X", "Y"}, {}, {"Out"});
 }
 KernelSignature ElementwiseFMaxGradOpArgumentMapping(
......
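With no attribute left, the compat mapping for the legacy elementwise_fmin operator forwards inputs and outputs only. Below is a simplified stand-in for the kind of record such a mapping returns; SimpleSignature is a made-up struct used only to show the shape of the data, while phi::KernelSignature carries the equivalent fields.

```cpp
// Hypothetical, simplified stand-in for an argument-mapping result.
#include <iostream>
#include <string>
#include <vector>

struct SimpleSignature {
  std::string kernel_name;
  std::vector<std::string> inputs;
  std::vector<std::string> attrs;
  std::vector<std::string> outputs;
};

// After this change the fmin mapping carries no attributes at all.
SimpleSignature MapElementwiseFMin() {
  return {"fmin", {"X", "Y"}, {}, {"Out"}};
}

int main() {
  SimpleSignature sig = MapElementwiseFMin();
  std::cout << sig.kernel_name << ": " << sig.inputs.size() << " inputs, "
            << sig.attrs.size() << " attrs\n";  // prints: fmin: 2 inputs, 0 attrs
  return 0;
}
```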