未验证 提交 dd3afc9d 编写于 作者: L LJQ❤️ 提交者: GitHub

Add fmax and fmin operators (#37826)

Add elementwise_fmax and elementwise_fmin operators
上级 fa463b90
......@@ -113,5 +113,45 @@ struct MinFunctor {
}
};
// Fmax
template <typename T>
struct FMaxFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const {
return std::fmax(a, b);
}
};
template <>
struct FMaxFunctor<paddle::platform::float16> {
inline HOSTDEVICE paddle::platform::float16 operator()(
const paddle::platform::float16& a,
const paddle::platform::float16& b) const {
float float_a = static_cast<float>(a);
float float_b = static_cast<float>(b);
auto result = std::fmax(float_a, float_b);
return static_cast<paddle::platform::float16>(result);
}
};
// Fmin
template <typename T>
struct FMinFunctor {
inline HOSTDEVICE T operator()(const T& a, const T& b) const {
return std::fmin(a, b);
}
};
template <>
struct FMinFunctor<paddle::platform::float16> {
inline HOSTDEVICE paddle::platform::float16 operator()(
const paddle::platform::float16& a,
const paddle::platform::float16& b) const {
float float_a = static_cast<float>(a);
float float_b = static_cast<float>(b);
auto result = std::fmin(float_a, float_b);
return static_cast<paddle::platform::float16>(result);
}
};
} // namespace operators
} // namespace paddle
......@@ -53,6 +53,27 @@ class ElementwiseMaxOpMaker : public ElementwiseOpMaker {
}
};
class ElementwiseFMaxOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "FMax"; }
std::string GetEquation() const override { return "Out = fmax(X, Y)"; }
void AddInputX() override {
AddInput("X", "The first tensor holding the elements to be compared.");
}
void AddInputY() override {
AddInput("Y", "The second tensor holding the elements to be compared.");
}
std::string GetOpFuntionality() const override {
return "Compare two tensors and returns a new tensor containing the "
"element-wise maxima. If the element of one tensor is nan, "
"return the element value of the other tensor, if both are nan, "
"return the first nan";
}
};
template <typename T>
class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -70,6 +91,23 @@ class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
template <typename T>
class ElementwiseFMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_fmax_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
......@@ -103,3 +141,28 @@ REGISTER_OP_VERSION(elementwise_max)
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_max.",
1.0f));
REGISTER_OPERATOR(elementwise_fmax, ops::ElementwiseOp,
ops::ElementwiseFMaxOpMaker, ops::ElementwiseOpInferVarType,
ops::ElementwiseFMaxGradOpMaker<paddle::framework::OpDesc>,
ops::ElementwiseFMaxGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_fmax_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_fmax,
ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseFMaxKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
elementwise_fmax_grad,
ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
......@@ -56,3 +56,21 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_fmax,
ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseFMaxKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_fmax_grad,
ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseFMaxGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
......@@ -14,9 +14,11 @@ limitations under the License. */
#pragma once
#include <cmath>
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/eigen_ext.h"
namespace paddle {
namespace operators {
......@@ -36,6 +38,21 @@ class ElementwiseMaxKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class ElementwiseFMaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<FMaxFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
FMaxFunctor<T>(), z);
}
};
template <typename T>
struct MaxGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
......@@ -68,5 +85,89 @@ class ElementwiseMaxGradKernel : public ElemwiseGradKernel<T> {
ctx, *x, *y, *out, *dout, axis, dx, dy, MaxGradDx<T>(), MaxGradDy<T>());
}
};
template <typename T>
struct FMaxGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>((x >= y) || isnan(y));
}
};
template <>
struct FMaxGradDx<paddle::platform::float16> {
HOSTDEVICE paddle::platform::float16 operator()(
paddle::platform::float16 x, paddle::platform::float16 y,
paddle::platform::float16 out, paddle::platform::float16 dout) const {
return dout * static_cast<paddle::platform::float16>(
(x >= y) || paddle::platform::isnan(y));
}
};
template <>
struct FMaxGradDx<int> {
HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
return dout * static_cast<int>((x >= y));
}
};
template <>
struct FMaxGradDx<int64_t> {
HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
int64_t dout) const {
return dout * static_cast<int64_t>((x >= y));
}
};
template <typename T>
struct FMaxGradDy {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(!((x >= y) || isnan(y)));
}
};
template <>
struct FMaxGradDy<paddle::platform::float16> {
HOSTDEVICE paddle::platform::float16 operator()(
paddle::platform::float16 x, paddle::platform::float16 y,
paddle::platform::float16 out, paddle::platform::float16 dout) const {
return dout * static_cast<paddle::platform::float16>(
!((x >= y) || paddle::platform::isnan(y)));
}
};
template <>
struct FMaxGradDy<int64_t> {
HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
int64_t dout) const {
return dout * static_cast<int64_t>(!((x >= y)));
}
};
template <>
struct FMaxGradDy<int> {
HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
return dout * static_cast<int>(!((x >= y)));
}
};
template <typename DeviceContext, typename T>
class ElementwiseFMaxGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* out = dout; // Fake out, not used
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, FMaxGradDx<T>, FMaxGradDy<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, FMaxGradDx<T>(),
FMaxGradDy<T>());
}
};
} // namespace operators
} // namespace paddle
......@@ -53,6 +53,27 @@ class ElementwiseMinOpMaker : public ElementwiseOpMaker {
}
};
class ElementwiseFMinOpMaker : public ElementwiseOpMaker {
protected:
std::string GetName() const override { return "FMin"; }
std::string GetEquation() const override { return "Out = fmin(X, Y)"; }
void AddInputX() override {
AddInput("X", "The first tensor holding the elements to be compared.");
}
void AddInputY() override {
AddInput("Y", "The second tensor holding the elements to be compared.");
}
std::string GetOpFuntionality() const override {
return "Compare two tensors and returns a new tensor containing the "
"element-wise minima. If the element of one tensor is nan, "
"return the element value of the other tensor, if both are nan, "
"return the first nan";
}
};
template <typename T>
class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -70,6 +91,23 @@ class ElementwiseMinGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
template <typename T>
class ElementwiseFMinGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("elementwise_fmin_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Y", this->Input("Y"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
......@@ -103,3 +141,28 @@ REGISTER_OP_VERSION(elementwise_min)
"In order to support the function of scaling the input Y when "
"using the operator of elementwise_min.",
1.0f));
REGISTER_OPERATOR(elementwise_fmin, ops::ElementwiseOp,
ops::ElementwiseFMinOpMaker, ops::ElementwiseOpInferVarType,
ops::ElementwiseFMinGradOpMaker<paddle::framework::OpDesc>,
ops::ElementwiseFMinGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_fmin_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_fmin,
ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseFMinKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
elementwise_fmin_grad,
ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseFMinGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
......@@ -52,3 +52,21 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_fmin,
ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseFMinKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_fmin_grad,
ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseFMinGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
......@@ -14,9 +14,11 @@ limitations under the License. */
#pragma once
#include <cmath>
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/eigen_ext.h"
namespace paddle {
namespace operators {
......@@ -36,6 +38,21 @@ class ElementwiseMinKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class ElementwiseFMinKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::LoDTensor>("X");
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
int axis = ctx.Attr<int>("axis");
ElementwiseComputeEx<FMinFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
FMinFunctor<T>(), z);
}
};
template <typename T>
struct MinGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
......@@ -68,5 +85,89 @@ class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
ctx, *x, *y, *out, *dout, axis, dx, dy, MinGradDx<T>(), MinGradDy<T>());
}
};
template <typename T>
struct FMinGradDx {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>((x <= y) || isnan(y));
}
};
template <>
struct FMinGradDx<paddle::platform::float16> {
HOSTDEVICE paddle::platform::float16 operator()(
paddle::platform::float16 x, paddle::platform::float16 y,
paddle::platform::float16 out, paddle::platform::float16 dout) const {
return dout * static_cast<paddle::platform::float16>(
(x <= y) || paddle::platform::isnan(y));
}
};
template <>
struct FMinGradDx<int> {
HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
return dout * static_cast<int>((x <= y));
}
};
template <>
struct FMinGradDx<int64_t> {
HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
int64_t dout) const {
return dout * static_cast<int64_t>((x <= y));
}
};
template <typename T>
struct FMinGradDy {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
return dout * static_cast<T>(!((x <= y) || isnan(y)));
}
};
template <>
struct FMinGradDy<paddle::platform::float16> {
HOSTDEVICE paddle::platform::float16 operator()(
paddle::platform::float16 x, paddle::platform::float16 y,
paddle::platform::float16 out, paddle::platform::float16 dout) const {
return dout * static_cast<paddle::platform::float16>(
!((x <= y) || paddle::platform::isnan(y)));
}
};
template <>
struct FMinGradDy<int> {
HOSTDEVICE int operator()(int x, int y, int out, int dout) const {
return dout * static_cast<int>(!((x <= y)));
}
};
template <>
struct FMinGradDy<int64_t> {
HOSTDEVICE int64_t operator()(int64_t x, int64_t y, int64_t out,
int64_t dout) const {
return dout * static_cast<int64_t>(!((x <= y)));
}
};
template <typename DeviceContext, typename T>
class ElementwiseFMinGradKernel : public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* out = dout; // Fake out, not used
int axis = ctx.Attr<int>("axis");
ElemwiseGradCompute<DeviceContext, T, FMinGradDx<T>, FMinGradDy<T>>(
ctx, *x, *y, *out, *dout, axis, dx, dy, FMinGradDx<T>(),
FMinGradDy<T>());
}
};
} // namespace operators
} // namespace paddle
......@@ -235,6 +235,8 @@ from .tensor.math import gcd # noqa: F401
from .tensor.math import lcm # noqa: F401
from .tensor.math import diff # noqa: F401
from .tensor.math import angle # noqa: F401
from .tensor.math import fmax # noqa: F401
from .tensor.math import fmin # noqa: F401
from .tensor.random import multinomial # noqa: F401
from .tensor.random import standard_normal # noqa: F401
......@@ -568,5 +570,7 @@ __all__ = [ # noqa
'as_real',
'diff',
'angle',
'fmax',
'fmin',
'moveaxis',
]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest
class ApiFMaxTest(unittest.TestCase):
"""ApiFMaxTest"""
def setUp(self):
"""setUp"""
if core.is_compiled_with_cuda():
self.place = core.CUDAPlace(0)
else:
self.place = core.CPUPlace()
self.input_x = np.random.rand(10, 15).astype("float32")
self.input_y = np.random.rand(10, 15).astype("float32")
self.input_z = np.random.rand(15).astype("float32")
self.input_a = np.array([0, np.nan, np.nan]).astype('int64')
self.input_b = np.array([2, np.inf, -np.inf]).astype('int64')
self.input_c = np.array([4, 1, 3]).astype('int64')
self.np_expected1 = np.fmax(self.input_x, self.input_y)
self.np_expected2 = np.fmax(self.input_x, self.input_z)
self.np_expected3 = np.fmax(self.input_a, self.input_c)
self.np_expected4 = np.fmax(self.input_b, self.input_c)
def test_static_api(self):
"""test_static_api"""
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_y = paddle.static.data("y", shape=[10, 15], dtype="float32")
result_fmax = paddle.fmax(data_x, data_y)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"x": self.input_x,
"y": self.input_y},
fetch_list=[result_fmax])
self.assertTrue(np.allclose(res, self.np_expected1))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_z = paddle.static.data("z", shape=[15], dtype="float32")
result_fmax = paddle.fmax(data_x, data_z)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"x": self.input_x,
"z": self.input_z},
fetch_list=[result_fmax])
self.assertTrue(np.allclose(res, self.np_expected2))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_a = paddle.static.data("a", shape=[3], dtype="int64")
data_c = paddle.static.data("c", shape=[3], dtype="int64")
result_fmax = paddle.fmax(data_a, data_c)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"a": self.input_a,
"c": self.input_c},
fetch_list=[result_fmax])
self.assertTrue(np.allclose(res, self.np_expected3))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_b = paddle.static.data("b", shape=[3], dtype="int64")
data_c = paddle.static.data("c", shape=[3], dtype="int64")
result_fmax = paddle.fmax(data_b, data_c)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"b": self.input_b,
"c": self.input_c},
fetch_list=[result_fmax])
self.assertTrue(np.allclose(res, self.np_expected4))
def test_dynamic_api(self):
"""test_dynamic_api"""
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
y = paddle.to_tensor(self.input_y)
z = paddle.to_tensor(self.input_z)
a = paddle.to_tensor(self.input_a)
b = paddle.to_tensor(self.input_b)
c = paddle.to_tensor(self.input_c)
res = paddle.fmax(x, y)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected1))
# test broadcast
res = paddle.fmax(x, z)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected2))
res = paddle.fmax(a, c)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected3))
res = paddle.fmax(b, c)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected4))
class TestElementwiseFmaxOp(OpTest):
"""TestElementwiseFmaxOp"""
def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmax"
# If x and y have the same value, the max() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
sgn = np.random.choice([-1, 1], [13, 17]).astype("float64")
y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64")
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.fmax(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
"""test_check_output"""
self.check_output()
def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self):
"""test_check_grad_ingore_x"""
self.check_grad(
['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
"""test_check_grad_ingore_y"""
self.check_grad(
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
class TestElementwiseFmax2Op(OpTest):
"""TestElementwiseFmax2Op"""
def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmax"
# If x and y have the same value, the max() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
sgn = np.random.choice([-1, 1], [13, 17]).astype("float64")
y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64")
y[2, 10:] = np.nan
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.fmax(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
"""test_check_output"""
self.check_output()
def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self):
"""test_check_grad_ingore_x"""
self.check_grad(
['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
"""test_check_grad_ingore_y"""
self.check_grad(
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
from op_test import OpTest
paddle.enable_static()
class ApiFMinTest(unittest.TestCase):
"""ApiFMinTest"""
def setUp(self):
"""setUp"""
if core.is_compiled_with_cuda():
self.place = core.CUDAPlace(0)
else:
self.place = core.CPUPlace()
self.input_x = np.random.rand(10, 15).astype("float32")
self.input_y = np.random.rand(10, 15).astype("float32")
self.input_z = np.random.rand(15).astype("float32")
self.input_a = np.array([0, np.nan, np.nan]).astype('int64')
self.input_b = np.array([2, np.inf, -np.inf]).astype('int64')
self.input_c = np.array([4, 1, 3]).astype('int64')
self.np_expected1 = np.fmin(self.input_x, self.input_y)
self.np_expected2 = np.fmin(self.input_x, self.input_z)
self.np_expected3 = np.fmin(self.input_a, self.input_c)
self.np_expected4 = np.fmin(self.input_b, self.input_c)
def test_static_api(self):
"""test_static_api"""
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_y = paddle.static.data("y", shape=[10, 15], dtype="float32")
result_fmin = paddle.fmin(data_x, data_y)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"x": self.input_x,
"y": self.input_y},
fetch_list=[result_fmin])
self.assertTrue(np.allclose(res, self.np_expected1))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_x = paddle.static.data("x", shape=[10, 15], dtype="float32")
data_z = paddle.static.data("z", shape=[15], dtype="float32")
result_fmin = paddle.fmin(data_x, data_z)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"x": self.input_x,
"z": self.input_z},
fetch_list=[result_fmin])
self.assertTrue(np.allclose(res, self.np_expected2))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_a = paddle.static.data("a", shape=[3], dtype="int64")
data_c = paddle.static.data("c", shape=[3], dtype="int64")
result_fmin = paddle.fmin(data_a, data_c)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"a": self.input_a,
"c": self.input_c},
fetch_list=[result_fmin])
self.assertTrue(np.allclose(res, self.np_expected3))
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data_b = paddle.static.data("b", shape=[3], dtype="int64")
data_c = paddle.static.data("c", shape=[3], dtype="int64")
result_fmin = paddle.fmin(data_b, data_c)
exe = paddle.static.Executor(self.place)
res, = exe.run(feed={"b": self.input_b,
"c": self.input_c},
fetch_list=[result_fmin])
self.assertTrue(np.allclose(res, self.np_expected4))
def test_dynamic_api(self):
"""test_dynamic_api"""
paddle.disable_static()
x = paddle.to_tensor(self.input_x)
y = paddle.to_tensor(self.input_y)
z = paddle.to_tensor(self.input_z)
a = paddle.to_tensor(self.input_a)
b = paddle.to_tensor(self.input_b)
c = paddle.to_tensor(self.input_c)
res = paddle.fmin(x, y)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected1))
# test broadcast
res = paddle.fmin(x, z)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected2))
res = paddle.fmin(a, c)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected3))
res = paddle.fmin(b, c)
res = res.numpy()
self.assertTrue(np.allclose(res, self.np_expected4))
class TestElementwiseFminOp(OpTest):
"""TestElementwiseFminOp"""
def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmin"
# If x and y have the same value, the min() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
sgn = np.random.choice([-1, 1], [13, 17]).astype("float64")
y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64")
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.fmin(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
"""test_check_output"""
self.check_output()
def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self):
"""test_check_grad_ingore_x"""
self.check_grad(
['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
"""test_check_grad_ingore_y"""
self.check_grad(
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
class TestElementwiseFmin2Op(OpTest):
"""TestElementwiseFmin2Op"""
def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmin"
# If x and y have the same value, the min() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
sgn = np.random.choice([-1, 1], [13, 17]).astype("float64")
y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float64")
y[2, 10:] = np.nan
self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.fmin(self.inputs['X'], self.inputs['Y'])}
def test_check_output(self):
"""test_check_output"""
self.check_output()
def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self):
"""test_check_grad_ingore_x"""
self.check_grad(
['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
"""test_check_grad_ingore_y"""
self.check_grad(
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
......@@ -42,6 +42,8 @@ NEED_TO_FIX_OP_LIST = [
'elementwise_mul',
'elementwise_sub',
'elementwise_pow',
'elementwise_fmin',
'elementwise_fmax',
'filter_by_instag',
'fused_elemwise_activation',
'fused_emb_seq_pool',
......
......@@ -202,6 +202,8 @@ from .math import gcd # noqa: F401
from .math import lcm # noqa: F401
from .math import diff # noqa: F401
from .math import angle # noqa: F401
from .math import fmax # noqa: F401
from .math import fmin # noqa: F401
from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
......@@ -305,6 +307,8 @@ tensor_method_func = [ #noqa
'maximum',
'min',
'minimum',
'fmax',
'fmin',
'mm',
'divide',
'floor_divide',
......
......@@ -631,6 +631,128 @@ def minimum(x, y, name=None):
x, y, axis=axis, act=act, op_name=op_type)
return _elementwise_op(LayerHelper(op_type, **locals()))
def fmax(x, y, name=None):
"""
Compares the elements at the corresponding positions of the two tensors and returns a new tensor containing the maximum value of the element.
If one of them is a nan value, the other value is directly returned, if both are nan values, then the first nan value is returned.
The equation is:
.. math::
out = fmax(x, y)
**Note**:
``paddle.fmax`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
Args:
x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = paddle.to_tensor([[1, 2], [7, 8]])
y = paddle.to_tensor([[3, 4], [5, 6]])
res = paddle.fmax(x, y)
print(res)
# [[3, 4],
# [7, 8]]
x = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
y = paddle.to_tensor([3, 0, 4])
res = paddle.fmax(x, y)
print(res)
# [[3, 2, 4],
# [3, 2, 4]]
x = paddle.to_tensor([2, 3, 5], dtype='float32')
y = paddle.to_tensor([1, np.nan, np.nan], dtype='float32')
res = paddle.fmax(x, y)
print(res)
# [ 2., 3., 5.]
x = paddle.to_tensor([5, 3, np.inf], dtype='float32')
y = paddle.to_tensor([1, -np.inf, 5], dtype='float32')
res = paddle.fmax(x, y)
print(res)
# [ 5., 3., inf.]
"""
op_type = 'elementwise_fmax'
axis = -1
act = None
if in_dygraph_mode():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type)
return _elementwise_op(LayerHelper(op_type, **locals()))
def fmin(x, y, name=None):
"""
Compares the elements at the corresponding positions of the two tensors and returns a new tensor containing the minimum value of the element.
If one of them is a nan value, the other value is directly returned, if both are nan values, then the first nan value is returned.
The equation is:
.. math::
out = fmin(x, y)
**Note**:
``paddle.fmin`` supports broadcasting. If you want know more about broadcasting, please refer to :ref:`user_guide_broadcasting` .
Args:
x (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
y (Tensor): the input tensor, it's data type should be float32, float64, int32, int64.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
N-D Tensor. A location into which the result is stored. If x, y have different shapes and are "broadcastable", the resulting tensor shape is the shape of x and y after broadcasting. If x, y have the same shape, its shape is the same as x and y.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = paddle.to_tensor([[1, 2], [7, 8]])
y = paddle.to_tensor([[3, 4], [5, 6]])
res = paddle.fmin(x, y)
print(res)
# [[1, 2],
# [5, 6]]
x = paddle.to_tensor([[[1, 2, 3], [1, 2, 3]]])
y = paddle.to_tensor([3, 0, 4])
res = paddle.fmin(x, y)
print(res)
# [[[1, 0, 3],
# [1, 0, 3]]]
x = paddle.to_tensor([2, 3, 5], dtype='float32')
y = paddle.to_tensor([1, np.nan, np.nan], dtype='float32')
res = paddle.fmin(x, y)
print(res)
# [ 1., 3., 5.]
x = paddle.to_tensor([5, 3, np.inf], dtype='float64')
y = paddle.to_tensor([1, -np.inf, 5], dtype='float64')
res = paddle.fmin(x, y)
print(res)
# [ 1., -inf., 5.]
"""
op_type = 'elementwise_fmin'
axis = -1
act = None
if in_dygraph_mode():
return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type)
return _elementwise_op(LayerHelper(op_type, **locals()))
for func in [
add,
multiply
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册