diff --git a/paddle/fluid/operators/elementwise/elementwise_functor.h b/paddle/fluid/operators/elementwise/elementwise_functor.h
index abac43a2616f09b34b56aadbac7c6614f7a59fad..6e53af41b657c8bebacc78d1fbbf0fe7137f0997 100644
--- a/paddle/fluid/operators/elementwise/elementwise_functor.h
+++ b/paddle/fluid/operators/elementwise/elementwise_functor.h
@@ -113,5 +113,45 @@ struct MinFunctor {
   }
 };
 
+// Fmax
+template <typename T>
+struct FMaxFunctor {
+  inline HOSTDEVICE T operator()(const T& a, const T& b) const {
+    return std::fmax(a, b);
+  }
+};
+
+template <>
+struct FMaxFunctor<paddle::platform::float16> {
+  inline HOSTDEVICE paddle::platform::float16 operator()(
+      const paddle::platform::float16& a,
+      const paddle::platform::float16& b) const {
+    float float_a = static_cast<float>(a);
+    float float_b = static_cast<float>(b);
+    auto result = std::fmax(float_a, float_b);
+    return static_cast<paddle::platform::float16>(result);
+  }
+};
+
+// Fmin
+template <typename T>
+struct FMinFunctor {
+  inline HOSTDEVICE T operator()(const T& a, const T& b) const {
+    return std::fmin(a, b);
+  }
+};
+
+template <>
+struct FMinFunctor<paddle::platform::float16> {
+  inline HOSTDEVICE paddle::platform::float16 operator()(
+      const paddle::platform::float16& a,
+      const paddle::platform::float16& b) const {
+    float float_a = static_cast<float>(a);
+    float float_b = static_cast<float>(b);
+    auto result = std::fmin(float_a, float_b);
+    return static_cast<paddle::platform::float16>(result);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cc b/paddle/fluid/operators/elementwise/elementwise_max_op.cc
index dde65c8199626bc7f9d72b44a0e46a7fa628ebef..e0686e815459a391cb14455bc8924d3d0df474a0 100644
--- a/paddle/fluid/operators/elementwise/elementwise_max_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cc
@@ -53,6 +53,27 @@ class ElementwiseMaxOpMaker : public ElementwiseOpMaker {
   }
 };
 
+class ElementwiseFMaxOpMaker : public ElementwiseOpMaker {
+ protected:
+  std::string GetName() const override { return "FMax"; }
+  std::string GetEquation() const override { return "Out = fmax(X, Y)"; }
+
+  void AddInputX() override {
+    AddInput("X", "The first tensor holding the elements to be compared.");
+  }
+
+  void AddInputY() override {
+    AddInput("Y", "The second tensor holding the elements to be compared.");
+  }
+
+  std::string GetOpFuntionality() const override {
+    return "Compare two tensors and returns a new tensor containing the "
+           "element-wise maxima. 
If the element of one tensor is nan, " + "return the element value of the other tensor, if both are nan, " + "return the first nan"; + } +}; + template class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker { public: @@ -70,6 +91,23 @@ class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker { } }; +template +class ElementwiseFMaxGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("elementwise_fmax_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + op->SetAttrMap(this->Attrs()); + } +}; + } // namespace operators } // namespace paddle @@ -103,3 +141,28 @@ REGISTER_OP_VERSION(elementwise_max) "In order to support the function of scaling the input Y when " "using the operator of elementwise_max.", 1.0f)); + +REGISTER_OPERATOR(elementwise_fmax, ops::ElementwiseOp, + ops::ElementwiseFMaxOpMaker, ops::ElementwiseOpInferVarType, + ops::ElementwiseFMaxGradOpMaker, + ops::ElementwiseFMaxGradOpMaker); + +REGISTER_OPERATOR(elementwise_fmax_grad, ops::ElementwiseOpGrad); + +REGISTER_OP_CPU_KERNEL( + elementwise_fmax, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel); +REGISTER_OP_CPU_KERNEL( + elementwise_fmax_grad, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.cu b/paddle/fluid/operators/elementwise/elementwise_max_op.cu index 65505381db17410d72a3c702e18df480513fcf2a..eb6f78bf270ad6d9c5fd379658a83d71a8de6322 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.cu @@ -56,3 +56,21 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMaxGradKernel, ops::ElementwiseMaxGradKernel); + +REGISTER_OP_CUDA_KERNEL( + elementwise_fmax, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel, + ops::ElementwiseFMaxKernel); +REGISTER_OP_CUDA_KERNEL( + elementwise_fmax_grad, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel, + ops::ElementwiseFMaxGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_max_op.h b/paddle/fluid/operators/elementwise/elementwise_max_op.h index 06269b12e8e209bc4120f99d8f5bd5a136cb2cf0..acb212e992a1d7a25cbf117f66055f07522ec21d 100644 --- a/paddle/fluid/operators/elementwise/elementwise_max_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_max_op.h @@ -14,9 +14,11 @@ limitations under the License. 
*/ #pragma once +#include #include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/platform/eigen_ext.h" namespace paddle { namespace operators { @@ -36,6 +38,21 @@ class ElementwiseMaxKernel : public framework::OpKernel { } }; +template +class ElementwiseFMaxKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + + z->mutable_data(ctx.GetPlace()); + int axis = ctx.Attr("axis"); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + FMaxFunctor(), z); + } +}; + template struct MaxGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { @@ -68,5 +85,89 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel { ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx(), MaxGradDy()); } }; + +template +struct FMaxGradDx { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast((x >= y) || isnan(y)); + } +}; + +template <> +struct FMaxGradDx { + HOSTDEVICE paddle::platform::float16 operator()( + paddle::platform::float16 x, paddle::platform::float16 y, + paddle::platform::float16 out, paddle::platform::float16 dout) const { + return dout * static_cast( + (x >= y) || paddle::platform::isnan(y)); + } +}; + +template <> +struct FMaxGradDx { + HOSTDEVICE int operator()(int x, int y, int out, int dout) const { + return dout * static_cast((x >= y)); + } +}; + +template <> +struct FMaxGradDx { + HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out, + int64_t dout) const { + return dout * static_cast((x >= y)); + } +}; + +template +struct FMaxGradDy { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast(!((x >= y) || isnan(y))); + } +}; + +template <> +struct FMaxGradDy { + HOSTDEVICE paddle::platform::float16 operator()( + paddle::platform::float16 x, paddle::platform::float16 y, + paddle::platform::float16 out, paddle::platform::float16 dout) const { + return dout * static_cast( + !((x >= y) || paddle::platform::isnan(y))); + } +}; + +template <> +struct FMaxGradDy { + HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out, + int64_t dout) const { + return dout * static_cast(!((x >= y))); + } +}; + +template <> +struct FMaxGradDy { + HOSTDEVICE int operator()(int x, int y, int out, int dout) const { + return dout * static_cast(!((x >= y))); + } +}; + +template +class ElementwiseFMaxGradKernel : public ElemwiseGradKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + auto* out = dout; // Fake out, not used + int axis = ctx.Attr("axis"); + ElemwiseGradCompute, FMaxGradDy>( + ctx, *x, *y, *out, *dout, axis, dx, dy, FMaxGradDx(), + FMaxGradDy()); + } +}; } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cc b/paddle/fluid/operators/elementwise/elementwise_min_op.cc index 174684e3c8476f7d887688cfa38750e03b08a1c4..1448520eca18f4b1c7bf0f5a90982b404c351270 100644 --- 
a/paddle/fluid/operators/elementwise/elementwise_min_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cc @@ -53,6 +53,27 @@ class ElementwiseMinOpMaker : public ElementwiseOpMaker { } }; +class ElementwiseFMinOpMaker : public ElementwiseOpMaker { + protected: + std::string GetName() const override { return "FMin"; } + std::string GetEquation() const override { return "Out = fmin(X, Y)"; } + + void AddInputX() override { + AddInput("X", "The first tensor holding the elements to be compared."); + } + + void AddInputY() override { + AddInput("Y", "The second tensor holding the elements to be compared."); + } + + std::string GetOpFuntionality() const override { + return "Compare two tensors and returns a new tensor containing the " + "element-wise minima. If the element of one tensor is nan, " + "return the element value of the other tensor, if both are nan, " + "return the first nan"; + } +}; + template class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker { public: @@ -70,6 +91,23 @@ class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker { } }; +template +class ElementwiseFMinGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("elementwise_fmin_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Y", this->Input("Y")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + op->SetAttrMap(this->Attrs()); + } +}; + } // namespace operators } // namespace paddle @@ -103,3 +141,28 @@ REGISTER_OP_VERSION(elementwise_min) "In order to support the function of scaling the input Y when " "using the operator of elementwise_min.", 1.0f)); + +REGISTER_OPERATOR(elementwise_fmin, ops::ElementwiseOp, + ops::ElementwiseFMinOpMaker, ops::ElementwiseOpInferVarType, + ops::ElementwiseFMinGradOpMaker, + ops::ElementwiseFMinGradOpMaker); + +REGISTER_OPERATOR(elementwise_fmin_grad, ops::ElementwiseOpGrad); + +REGISTER_OP_CPU_KERNEL( + elementwise_fmin, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel); +REGISTER_OP_CPU_KERNEL( + elementwise_fmin_grad, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.cu b/paddle/fluid/operators/elementwise/elementwise_min_op.cu index eed6f72b04fb9bc58e8a4a59a50102f276d6816a..a51398640579b966859b328ab7680d62f8f55f57 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.cu +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.cu @@ -52,3 +52,21 @@ REGISTER_OP_CUDA_KERNEL( ops::ElementwiseMinGradKernel, ops::ElementwiseMinGradKernel); + +REGISTER_OP_CUDA_KERNEL( + elementwise_fmin, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel, + ops::ElementwiseFMinKernel); +REGISTER_OP_CUDA_KERNEL( + elementwise_fmin_grad, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel, + ops::ElementwiseFMinGradKernel); diff --git a/paddle/fluid/operators/elementwise/elementwise_min_op.h 
b/paddle/fluid/operators/elementwise/elementwise_min_op.h index 648691063c59b237852a364fa002246de7ba6137..2f96ef747708bf5852e2de40fc593f8e913620b7 100644 --- a/paddle/fluid/operators/elementwise/elementwise_min_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_min_op.h @@ -14,9 +14,11 @@ limitations under the License. */ #pragma once +#include #include "paddle/fluid/operators/elementwise/elementwise_functor.h" #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h" +#include "paddle/fluid/platform/eigen_ext.h" namespace paddle { namespace operators { @@ -36,6 +38,21 @@ class ElementwiseMinKernel : public framework::OpKernel { } }; +template +class ElementwiseFMinKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + + z->mutable_data(ctx.GetPlace()); + int axis = ctx.Attr("axis"); + ElementwiseComputeEx, DeviceContext, T>(ctx, x, y, axis, + FMinFunctor(), z); + } +}; + template struct MinGradDx { HOSTDEVICE T operator()(T x, T y, T out, T dout) const { @@ -68,5 +85,89 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel { ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx(), MinGradDy()); } }; + +template +struct FMinGradDx { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast((x <= y) || isnan(y)); + } +}; + +template <> +struct FMinGradDx { + HOSTDEVICE paddle::platform::float16 operator()( + paddle::platform::float16 x, paddle::platform::float16 y, + paddle::platform::float16 out, paddle::platform::float16 dout) const { + return dout * static_cast( + (x <= y) || paddle::platform::isnan(y)); + } +}; + +template <> +struct FMinGradDx { + HOSTDEVICE int operator()(int x, int y, int out, int dout) const { + return dout * static_cast((x <= y)); + } +}; + +template <> +struct FMinGradDx { + HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out, + int64_t dout) const { + return dout * static_cast((x <= y)); + } +}; + +template +struct FMinGradDy { + HOSTDEVICE T operator()(T x, T y, T out, T dout) const { + return dout * static_cast(!((x <= y) || isnan(y))); + } +}; + +template <> +struct FMinGradDy { + HOSTDEVICE paddle::platform::float16 operator()( + paddle::platform::float16 x, paddle::platform::float16 y, + paddle::platform::float16 out, paddle::platform::float16 dout) const { + return dout * static_cast( + !((x <= y) || paddle::platform::isnan(y))); + } +}; + +template <> +struct FMinGradDy { + HOSTDEVICE int operator()(int x, int y, int out, int dout) const { + return dout * static_cast(!((x <= y))); + } +}; + +template <> +struct FMinGradDy { + HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out, + int64_t dout) const { + return dout * static_cast(!((x <= y))); + } +}; + +template +class ElementwiseFMinGradKernel : public ElemwiseGradKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElemwiseGradKernel::Compute(ctx); + using Tensor = framework::Tensor; + + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + auto* out = dout; // Fake out, not used + int axis = ctx.Attr("axis"); + ElemwiseGradCompute, FMinGradDy>( + ctx, *x, *y, *out, *dout, axis, dx, dy, FMinGradDx(), + 
FMinGradDy()); + } +}; } // namespace operators } // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 4b5aee696ee3ec176f262af5a739824cde155e41..221f1cc2902a8a7a488d573c2ff845cbd6ac5691 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -235,6 +235,8 @@ from .tensor.math import gcd # noqa: F401 from .tensor.math import lcm # noqa: F401 from .tensor.math import diff # noqa: F401 from .tensor.math import angle # noqa: F401 +from .tensor.math import fmax # noqa: F401 +from .tensor.math import fmin # noqa: F401 from .tensor.random import multinomial # noqa: F401 from .tensor.random import standard_normal # noqa: F401 @@ -568,5 +570,7 @@ __all__ = [ # noqa 'as_real', 'diff', 'angle', + 'fmax', + 'fmin', 'moveaxis', ] diff --git a/python/paddle/fluid/tests/unittests/test_fmax_op.py b/python/paddle/fluid/tests/unittests/test_fmax_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3981d63c00582ebd900279478fd935dd2602232d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fmax_op.py @@ -0,0 +1,189 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +from op_test import OpTest + + +class ApiFMaxTest(unittest.TestCase): + """ApiFMaxTest""" + + def setUp(self): + """setUp""" + if core.is_compiled_with_cuda(): + self.place = core.CUDAPlace(0) + else: + self.place = core.CPUPlace() + + self.input_x = np.random.rand(10, 15).astype("float32") + self.input_y = np.random.rand(10, 15).astype("float32") + self.input_z = np.random.rand(15).astype("float32") + self.input_a = np.array([0, np.nan, np.nan]).astype('int64') + self.input_b = np.array([2, np.inf, -np.inf]).astype('int64') + self.input_c = np.array([4, 1, 3]).astype('int64') + + self.np_expected1 = np.fmax(self.input_x, self.input_y) + self.np_expected2 = np.fmax(self.input_x, self.input_z) + self.np_expected3 = np.fmax(self.input_a, self.input_c) + self.np_expected4 = np.fmax(self.input_b, self.input_c) + + def test_static_api(self): + """test_static_api""" + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_x = paddle.static.data("x", shape=[10, 15], dtype="float32") + data_y = paddle.static.data("y", shape=[10, 15], dtype="float32") + result_fmax = paddle.fmax(data_x, data_y) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"x": self.input_x, + "y": self.input_y}, + fetch_list=[result_fmax]) + self.assertTrue(np.allclose(res, self.np_expected1)) + + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_x = paddle.static.data("x", shape=[10, 15], dtype="float32") + data_z = paddle.static.data("z", shape=[15], dtype="float32") + result_fmax = paddle.fmax(data_x, data_z) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"x": 
self.input_x, + "z": self.input_z}, + fetch_list=[result_fmax]) + self.assertTrue(np.allclose(res, self.np_expected2)) + + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_a = paddle.static.data("a", shape=[3], dtype="int64") + data_c = paddle.static.data("c", shape=[3], dtype="int64") + result_fmax = paddle.fmax(data_a, data_c) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"a": self.input_a, + "c": self.input_c}, + fetch_list=[result_fmax]) + self.assertTrue(np.allclose(res, self.np_expected3)) + + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_b = paddle.static.data("b", shape=[3], dtype="int64") + data_c = paddle.static.data("c", shape=[3], dtype="int64") + result_fmax = paddle.fmax(data_b, data_c) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"b": self.input_b, + "c": self.input_c}, + fetch_list=[result_fmax]) + self.assertTrue(np.allclose(res, self.np_expected4)) + + def test_dynamic_api(self): + """test_dynamic_api""" + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + y = paddle.to_tensor(self.input_y) + z = paddle.to_tensor(self.input_z) + + a = paddle.to_tensor(self.input_a) + b = paddle.to_tensor(self.input_b) + c = paddle.to_tensor(self.input_c) + + res = paddle.fmax(x, y) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected1)) + + # test broadcast + res = paddle.fmax(x, z) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected2)) + + res = paddle.fmax(a, c) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected3)) + + res = paddle.fmax(b, c) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected4)) + + +class TestElementwiseFmaxOp(OpTest): + """TestElementwiseFmaxOp""" + + def setUp(self): + """setUp""" + self.op_type = "elementwise_fmax" + # If x and y have the same value, the max() is not differentiable. + # So we generate test data by the following method + # to avoid them being too close to each other. + x = np.random.uniform(0.1, 1, [13, 17]).astype("float64") + sgn = np.random.choice([-1, 1], [13, 17]).astype("float64") + y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.fmax(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + """test_check_output""" + self.check_output() + + def test_check_grad_normal(self): + """test_check_grad_normal""" + self.check_grad(['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + """test_check_grad_ingore_x""" + self.check_grad( + ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + """test_check_grad_ingore_y""" + self.check_grad( + ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + + +class TestElementwiseFmax2Op(OpTest): + """TestElementwiseFmax2Op""" + + def setUp(self): + """setUp""" + self.op_type = "elementwise_fmax" + # If x and y have the same value, the max() is not differentiable. + # So we generate test data by the following method + # to avoid them being too close to each other. 
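+        # Note on the semantics this test relies on (the values here are
+        # only an illustration, not part of the test data):
+        #   np.fmax([1.0, np.nan], [np.nan, np.nan]) -> [1.0, nan]
+        # i.e. a NaN in a single operand is ignored and only a double NaN
+        # propagates. For the backward pass, FMaxGradDx/FMaxGradDy route
+        # dout entirely to X wherever x >= y or y is NaN, and to Y
+        # otherwise, so the NaN entries written into y below are expected
+        # to receive a zero gradient.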
+ x = np.random.uniform(0.1, 1, [13, 17]).astype("float64") + sgn = np.random.choice([-1, 1], [13, 17]).astype("float64") + y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64") + + y[2, 10:] = np.nan + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.fmax(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + """test_check_output""" + self.check_output() + + def test_check_grad_normal(self): + """test_check_grad_normal""" + self.check_grad(['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + """test_check_grad_ingore_x""" + self.check_grad( + ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + """test_check_grad_ingore_y""" + self.check_grad( + ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) diff --git a/python/paddle/fluid/tests/unittests/test_fmin_op.py b/python/paddle/fluid/tests/unittests/test_fmin_op.py new file mode 100644 index 0000000000000000000000000000000000000000..5cdf096be6708c47dd1f56dc97243be70c6d63d5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fmin_op.py @@ -0,0 +1,191 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +from op_test import OpTest + +paddle.enable_static() + + +class ApiFMinTest(unittest.TestCase): + """ApiFMinTest""" + + def setUp(self): + """setUp""" + if core.is_compiled_with_cuda(): + self.place = core.CUDAPlace(0) + else: + self.place = core.CPUPlace() + + self.input_x = np.random.rand(10, 15).astype("float32") + self.input_y = np.random.rand(10, 15).astype("float32") + self.input_z = np.random.rand(15).astype("float32") + self.input_a = np.array([0, np.nan, np.nan]).astype('int64') + self.input_b = np.array([2, np.inf, -np.inf]).astype('int64') + self.input_c = np.array([4, 1, 3]).astype('int64') + + self.np_expected1 = np.fmin(self.input_x, self.input_y) + self.np_expected2 = np.fmin(self.input_x, self.input_z) + self.np_expected3 = np.fmin(self.input_a, self.input_c) + self.np_expected4 = np.fmin(self.input_b, self.input_c) + + def test_static_api(self): + """test_static_api""" + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_x = paddle.static.data("x", shape=[10, 15], dtype="float32") + data_y = paddle.static.data("y", shape=[10, 15], dtype="float32") + result_fmin = paddle.fmin(data_x, data_y) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"x": self.input_x, + "y": self.input_y}, + fetch_list=[result_fmin]) + self.assertTrue(np.allclose(res, self.np_expected1)) + + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_x = paddle.static.data("x", shape=[10, 15], dtype="float32") + data_z = paddle.static.data("z", shape=[15], dtype="float32") + result_fmin = paddle.fmin(data_x, data_z) + exe = 
paddle.static.Executor(self.place) + res, = exe.run(feed={"x": self.input_x, + "z": self.input_z}, + fetch_list=[result_fmin]) + self.assertTrue(np.allclose(res, self.np_expected2)) + + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_a = paddle.static.data("a", shape=[3], dtype="int64") + data_c = paddle.static.data("c", shape=[3], dtype="int64") + result_fmin = paddle.fmin(data_a, data_c) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"a": self.input_a, + "c": self.input_c}, + fetch_list=[result_fmin]) + self.assertTrue(np.allclose(res, self.np_expected3)) + + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + data_b = paddle.static.data("b", shape=[3], dtype="int64") + data_c = paddle.static.data("c", shape=[3], dtype="int64") + result_fmin = paddle.fmin(data_b, data_c) + exe = paddle.static.Executor(self.place) + res, = exe.run(feed={"b": self.input_b, + "c": self.input_c}, + fetch_list=[result_fmin]) + self.assertTrue(np.allclose(res, self.np_expected4)) + + def test_dynamic_api(self): + """test_dynamic_api""" + paddle.disable_static() + x = paddle.to_tensor(self.input_x) + y = paddle.to_tensor(self.input_y) + z = paddle.to_tensor(self.input_z) + + a = paddle.to_tensor(self.input_a) + b = paddle.to_tensor(self.input_b) + c = paddle.to_tensor(self.input_c) + + res = paddle.fmin(x, y) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected1)) + + # test broadcast + res = paddle.fmin(x, z) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected2)) + + res = paddle.fmin(a, c) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected3)) + + res = paddle.fmin(b, c) + res = res.numpy() + self.assertTrue(np.allclose(res, self.np_expected4)) + + +class TestElementwiseFminOp(OpTest): + """TestElementwiseFminOp""" + + def setUp(self): + """setUp""" + self.op_type = "elementwise_fmin" + # If x and y have the same value, the min() is not differentiable. + # So we generate test data by the following method + # to avoid them being too close to each other. + x = np.random.uniform(0.1, 1, [13, 17]).astype("float64") + sgn = np.random.choice([-1, 1], [13, 17]).astype("float64") + y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64") + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.fmin(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + """test_check_output""" + self.check_output() + + def test_check_grad_normal(self): + """test_check_grad_normal""" + self.check_grad(['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + """test_check_grad_ingore_x""" + self.check_grad( + ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + """test_check_grad_ingore_y""" + self.check_grad( + ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + + +class TestElementwiseFmin2Op(OpTest): + """TestElementwiseFmin2Op""" + + def setUp(self): + """setUp""" + self.op_type = "elementwise_fmin" + # If x and y have the same value, the min() is not differentiable. + # So we generate test data by the following method + # to avoid them being too close to each other. 
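+        # As with the fmax test above, the NaN entries written into y below
+        # exercise the gradient routing of FMinGradDx/FMinGradDy: dout goes
+        # entirely to X wherever x <= y or y is NaN, and to Y otherwise, so
+        # the NaN positions in y are expected to receive a zero gradient.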
+ x = np.random.uniform(0.1, 1, [13, 17]).astype("float64") + sgn = np.random.choice([-1, 1], [13, 17]).astype("float64") + y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64") + + y[2, 10:] = np.nan + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': np.fmin(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + """test_check_output""" + self.check_output() + + def test_check_grad_normal(self): + """test_check_grad_normal""" + self.check_grad(['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + """test_check_grad_ingore_x""" + self.check_grad( + ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + """test_check_grad_ingore_y""" + self.check_grad( + ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) diff --git a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py index 29374a979650404341e39a341415aee64f657288..725ad4e93824f178bd0593e80c8c8f3475a73049 100644 --- a/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/no_grad_set_white_list.py @@ -42,6 +42,8 @@ NEED_TO_FIX_OP_LIST = [ 'elementwise_mul', 'elementwise_sub', 'elementwise_pow', + 'elementwise_fmin', + 'elementwise_fmax', 'filter_by_instag', 'fused_elemwise_activation', 'fused_emb_seq_pool', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index ea3e7f00a24dbacaee1f4d008ec98d320717a598..2d35f9cd6893ae2aa83879b1d862cd7992bb31c9 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -202,6 +202,8 @@ from .math import gcd # noqa: F401 from .math import lcm # noqa: F401 from .math import diff # noqa: F401 from .math import angle # noqa: F401 +from .math import fmax # noqa: F401 +from .math import fmin # noqa: F401 from .random import multinomial # noqa: F401 from .random import standard_normal # noqa: F401 @@ -305,6 +307,8 @@ tensor_method_func = [ #noqa 'maximum', 'min', 'minimum', + 'fmax', + 'fmin', 'mm', 'divide', 'floor_divide', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 9d61dee263a29b88eefd56e43080243dedef120f..931e7a6787fff1f3905d326ff09fefe1b24cd665 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -631,6 +631,128 @@ def minimum(x, y, name=None): x, y, axis=axis, act=act, op_name=op_type) return _elementwise_op(LayerHelper(op_type, **locals())) +def fmax(x, y, name=None): + """ + Compares the elements at the corresponding positions of the two tensors and returns a new tensor containing the maximum value of the element. + If one of them is a nan value, the other value is directly returned, if both are nan values, then the first nan value is returned. + The equation is: + + .. math:: + out = fmax(x, y) + + **Note**: + ``paddle.fmax`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` . + + Args: + x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + N-D Tensor. A location into which the result is stored. 
If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = paddle.to_tensor([[1, 2], [7, 8]]) + y = paddle.to_tensor([[3, 4], [5, 6]]) + res = paddle.fmax(x, y) + print(res) + # [[3, 4], + # [7, 8]] + + x = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) + y = paddle.to_tensor([3, 0, 4]) + res = paddle.fmax(x, y) + print(res) + # [[3, 2, 4], + # [3, 2, 4]] + + x = paddle.to_tensor([2, 3, 5], dtype='float32') + y = paddle.to_tensor([1, np.nan, np.nan], dtype='float32') + res = paddle.fmax(x, y) + print(res) + # [ 2., 3., 5.] + + x = paddle.to_tensor([5, 3, np.inf], dtype='float32') + y = paddle.to_tensor([1, -np.inf, 5], dtype='float32') + res = paddle.fmax(x, y) + print(res) + # [ 5., 3., inf.] + """ + op_type = 'elementwise_fmax' + axis = -1 + act = None + if in_dygraph_mode(): + return _elementwise_op_in_dygraph( + x, y, axis=axis, act=act, op_name=op_type) + return _elementwise_op(LayerHelper(op_type, **locals())) + +def fmin(x, y, name=None): + """ + Compares the elements at the corresponding positions of the two tensors and returns a new tensor containing the minimum value of the element. + If one of them is a nan value, the other value is directly returned, if both are nan values, then the first nan value is returned. + The equation is: + + .. math:: + out = fmin(x, y) + + **Note**: + ``paddle.fmin`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` . + + Args: + x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y. + + Examples: + + .. code-block:: python + + import numpy as np + import paddle + + x = paddle.to_tensor([[1, 2], [7, 8]]) + y = paddle.to_tensor([[3, 4], [5, 6]]) + res = paddle.fmin(x, y) + print(res) + # [[1, 2], + # [5, 6]] + + x = paddle.to_tensor([[[1, 2, 3], [1, 2, 3]]]) + y = paddle.to_tensor([3, 0, 4]) + res = paddle.fmin(x, y) + print(res) + # [[[1, 0, 3], + # [1, 0, 3]]] + + x = paddle.to_tensor([2, 3, 5], dtype='float32') + y = paddle.to_tensor([1, np.nan, np.nan], dtype='float32') + res = paddle.fmin(x, y) + print(res) + # [ 1., 3., 5.] + + x = paddle.to_tensor([5, 3, np.inf], dtype='float64') + y = paddle.to_tensor([1, -np.inf, 5], dtype='float64') + res = paddle.fmin(x, y) + print(res) + # [ 1., -inf., 5.] + """ + op_type = 'elementwise_fmin' + axis = -1 + act = None + if in_dygraph_mode(): + return _elementwise_op_in_dygraph( + x, y, axis=axis, act=act, op_name=op_type) + return _elementwise_op(LayerHelper(op_type, **locals())) + for func in [ add, multiply