未验证 提交 8df8cb10 编写于 作者: Z zyfncg 提交者: GitHub

Delete axis of fmin kernel (#50358)

* delete axis of fmin

* fix bug
上级 615d9f53
...@@ -515,7 +515,7 @@ ...@@ -515,7 +515,7 @@
- backward_op : fmin_grad - backward_op : fmin_grad
forward : fmin(Tensor x, Tensor y) -> Tensor(out) forward : fmin(Tensor x, Tensor y) -> Tensor(out)
args : (Tensor x, Tensor y, Tensor out_grad, int axis = -1) args : (Tensor x, Tensor y, Tensor out_grad)
output : Tensor(x_grad), Tensor(y_grad) output : Tensor(x_grad), Tensor(y_grad)
infer_meta : infer_meta :
func : GeneralBinaryGradInferMeta func : GeneralBinaryGradInferMeta
......
...@@ -131,14 +131,8 @@ PD_REGISTER_KERNEL(fmax_raw, ...@@ -131,14 +131,8 @@ PD_REGISTER_KERNEL(fmax_raw,
int, int,
int64_t) {} int64_t) {}
PD_REGISTER_KERNEL(fmin_raw, PD_REGISTER_KERNEL(
CPU, fmin, CPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
ALL_LAYOUT,
phi::FMinRawKernel,
float,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(maximum_raw, PD_REGISTER_KERNEL(maximum_raw,
CPU, CPU,
......
...@@ -33,7 +33,6 @@ void ElementwiseFMinGradKernel(const Context& dev_ctx, ...@@ -33,7 +33,6 @@ void ElementwiseFMinGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
const DenseTensor& out_grad, const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad, DenseTensor* x_grad,
DenseTensor* y_grad); DenseTensor* y_grad);
......
...@@ -118,14 +118,6 @@ void FMaxKernel(const Context& dev_ctx, ...@@ -118,14 +118,6 @@ void FMaxKernel(const Context& dev_ctx,
FMaxRawKernel<T, Context>(dev_ctx, x, y, -1, out); FMaxRawKernel<T, Context>(dev_ctx, x, y, -1, out);
} }
template <typename T, typename Context>
void FMinKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
FMinRawKernel<T, Context>(dev_ctx, x, y, -1, out);
}
} // namespace phi } // namespace phi
using complex64 = ::phi::dtype::complex<float>; using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>; using complex128 = ::phi::dtype::complex<double>;
...@@ -133,9 +125,6 @@ using complex128 = ::phi::dtype::complex<double>; ...@@ -133,9 +125,6 @@ using complex128 = ::phi::dtype::complex<double>;
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
fmax, CPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {} fmax, CPU, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(
fmin, CPU, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(maximum, PD_REGISTER_KERNEL(maximum,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -242,16 +231,6 @@ PD_REGISTER_KERNEL(fmax, ...@@ -242,16 +231,6 @@ PD_REGISTER_KERNEL(fmax,
phi::dtype::float16, phi::dtype::float16,
int64_t) {} int64_t) {}
PD_REGISTER_KERNEL(fmin,
KPS,
ALL_LAYOUT,
phi::FMinKernel,
float,
double,
int,
phi::dtype::float16,
int64_t) {}
PD_REGISTER_KERNEL(maximum, PD_REGISTER_KERNEL(maximum,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
......
...@@ -32,13 +32,6 @@ void FMaxKernel(const Context& dev_ctx, ...@@ -32,13 +32,6 @@ void FMaxKernel(const Context& dev_ctx,
const DenseTensor& y, const DenseTensor& y,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context>
void FMinRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis,
DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void FMinKernel(const Context& dev_ctx, void FMinKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
......
...@@ -314,13 +314,13 @@ void ElementwiseFMinGradKernel(const Context& dev_ctx, ...@@ -314,13 +314,13 @@ void ElementwiseFMinGradKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
const DenseTensor& out_grad, const DenseTensor& out_grad,
int axis,
DenseTensor* x_grad, DenseTensor* x_grad,
DenseTensor* y_grad) { DenseTensor* y_grad) {
funcs::ElementwiseGradPreProcess(out_grad, x_grad); funcs::ElementwiseGradPreProcess(out_grad, x_grad);
auto out = out_grad; // Fake out, not used auto out = out_grad; // Fake out, not used
auto x_dim = x.dims(); auto x_dim = x.dims();
auto y_dim = y.dims(); auto y_dim = y.dims();
int axis = -1;
if (x.dims() == y.dims()) { if (x.dims() == y.dims()) {
funcs::ElemwiseGradComputeNoBroadcast<Context, funcs::ElemwiseGradComputeNoBroadcast<Context,
T, T,
......
...@@ -78,14 +78,13 @@ void FMaxRawKernel(const Context& dev_ctx, ...@@ -78,14 +78,13 @@ void FMaxRawKernel(const Context& dev_ctx,
} }
template <typename T, typename Context> template <typename T, typename Context>
void FMinRawKernel(const Context& dev_ctx, void FMinKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
int axis, DenseTensor* out) {
DenseTensor* out) {
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
funcs::ElementwiseCompute<funcs::FMinFunctor<T>, T, T>( funcs::ElementwiseCompute<funcs::FMinFunctor<T>, T, T>(
dev_ctx, x, y, axis, funcs::FMinFunctor<T>(), out); dev_ctx, x, y, -1, funcs::FMinFunctor<T>(), out);
} }
} // namespace phi } // namespace phi
...@@ -103,10 +103,10 @@ PD_REGISTER_KERNEL(fmax_raw, ...@@ -103,10 +103,10 @@ PD_REGISTER_KERNEL(fmax_raw,
float16, float16,
int64_t) {} int64_t) {}
PD_REGISTER_KERNEL(fmin_raw, PD_REGISTER_KERNEL(fmin,
KPS, KPS,
ALL_LAYOUT, ALL_LAYOUT,
phi::FMinRawKernel, phi::FMinKernel,
float, float,
double, double,
int, int,
......
...@@ -162,7 +162,7 @@ KernelSignature ElementwiseDivGradOpArgumentMapping( ...@@ -162,7 +162,7 @@ KernelSignature ElementwiseDivGradOpArgumentMapping(
KernelSignature ElementwiseFMinGradOpArgumentMapping( KernelSignature ElementwiseFMinGradOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature( return KernelSignature(
"fmin_grad", {"X", "Y", "Out@GRAD"}, {"axis"}, {"X@GRAD", "Y@GRAD"}); "fmin_grad", {"X", "Y", "Out@GRAD"}, {}, {"X@GRAD", "Y@GRAD"});
} }
KernelSignature ElementwiseDivDoubleGradOpArgumentMapping( KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
...@@ -186,7 +186,7 @@ KernelSignature ElementwiseFMaxOpArgumentMapping( ...@@ -186,7 +186,7 @@ KernelSignature ElementwiseFMaxOpArgumentMapping(
KernelSignature ElementwiseFMinOpArgumentMapping( KernelSignature ElementwiseFMinOpArgumentMapping(
const ArgumentMappingContext& ctx) { const ArgumentMappingContext& ctx) {
return KernelSignature("fmin_raw", {"X", "Y"}, {"axis"}, {"Out"}); return KernelSignature("fmin", {"X", "Y"}, {}, {"Out"});
} }
KernelSignature ElementwiseFMaxGradOpArgumentMapping( KernelSignature ElementwiseFMaxGradOpArgumentMapping(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册