From 6b59b58cd19d6d7a1e1a4a6fdbda4d5f740eb5b6 Mon Sep 17 00:00:00 2001 From: wuhuanzhou Date: Thu, 23 Dec 2021 16:23:51 +0800 Subject: [PATCH] Add erfinv API (#38295) * add erfinv API, test=develop * fix gradient accuracy error, test=develop * fix cuda compilation error on Windows, test=develop * fix M_2_SQRTPI undeclared identifier on Windows, test=develop --- paddle/fluid/operators/erfinv_op.cc | 100 ++++++++++++++++ paddle/fluid/operators/erfinv_op.cu | 28 +++++ paddle/fluid/operators/erfinv_op.h | 65 ++++++++++ python/paddle/__init__.py | 2 + .../fluid/tests/unittests/test_erfinv_op.py | 111 ++++++++++++++++++ python/paddle/tensor/__init__.py | 4 + python/paddle/tensor/math.py | 45 +++++++ 7 files changed, 355 insertions(+) create mode 100644 paddle/fluid/operators/erfinv_op.cc create mode 100644 paddle/fluid/operators/erfinv_op.cu create mode 100644 paddle/fluid/operators/erfinv_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_erfinv_op.py diff --git a/paddle/fluid/operators/erfinv_op.cc b/paddle/fluid/operators/erfinv_op.cc new file mode 100644 index 00000000000..56a6a80b45d --- /dev/null +++ b/paddle/fluid/operators/erfinv_op.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/erfinv_op.h" + +namespace paddle { +namespace operators { + +class ErfinvOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "erfinv"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "erfinv"); + + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); + } +}; + +class ErfinvOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The input tensor of erfinv op."); + AddOutput("Out", "(Tensor), The output tensor of erfinv op."); + AddComment(R"DOC( +Erfinv Operator. + +This operator is used to compute inverse error function of input $X$. + +The equation is: + +$$erfinv(x) = {ndtri({x \over 2} + 0.5)} \over {\sqrt{2}}$$ + +The input `X` can carry the LoD (Level of Details) information, +or not. And the output shares the LoD information with input `X`. +)DOC"); + } +}; + +class ErfinvGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); + } +}; + +template +class ErfinvGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + void Apply(GradOpPtr op) const override { + op->SetType("erfinv_grad"); + op->SetInput("Out", this->Output("Out")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +DECLARE_INPLACE_OP_INFERER(ErfinvInplaceInferer, {"X", "Out"}); + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR( + erfinv, paddle::operators::ErfinvOp, paddle::operators::ErfinvOpMaker, + paddle::operators::ErfinvGradMaker, + paddle::operators::ErfinvGradMaker, + paddle::operators::ErfinvInplaceInferer); + +REGISTER_OPERATOR(erfinv_grad, paddle::operators::ErfinvGradOp); + +REGISTER_OP_CPU_KERNEL( + erfinv, + paddle::operators::ErfinvKernel, + paddle::operators::ErfinvKernel); + +REGISTER_OP_CPU_KERNEL( + erfinv_grad, + paddle::operators::ErfinvGradKernel, + paddle::operators::ErfinvGradKernel); diff --git a/paddle/fluid/operators/erfinv_op.cu b/paddle/fluid/operators/erfinv_op.cu new file mode 100644 index 00000000000..1fb2dbb97a2 --- /dev/null +++ b/paddle/fluid/operators/erfinv_op.cu @@ -0,0 +1,28 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/erfinv_op.h" + +REGISTER_OP_CUDA_KERNEL( + erfinv, + paddle::operators::ErfinvKernel, + paddle::operators::ErfinvKernel); + +REGISTER_OP_CUDA_KERNEL( + erfinv_grad, + paddle::operators::ErfinvGradKernel, + paddle::operators::ErfinvGradKernel); diff --git a/paddle/fluid/operators/erfinv_op.h b/paddle/fluid/operators/erfinv_op.h new file mode 100644 index 00000000000..934d0f4a5a7 --- /dev/null +++ b/paddle/fluid/operators/erfinv_op.h @@ -0,0 +1,65 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#ifndef _USE_MATH_DEFINES +#define _USE_MATH_DEFINES // use M_2_SQRTPI on Windows +#endif +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +// ndtri(x * 0.5 + 0.5) / sqrt(2) +template +class ErfinvKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto in = ctx.Input("X"); + auto out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + + auto eigen_in = framework::EigenVector::Flatten(*in); + auto eigen_out = framework::EigenVector::Flatten(*out); + auto& place = *ctx.template device_context().eigen_device(); + constexpr T half = static_cast(0.5); + constexpr T half_sqrt = static_cast(M_SQRT1_2); + eigen_out.device(place) = (eigen_in * half + half).ndtri() * half_sqrt; + } +}; + +// sqrt(pi) / 2 * exp(square(out)) * grad +template +class ErfinvGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto out = ctx.Input("Out"); + auto dout = ctx.Input(framework::GradVarName("Out")); + auto dx = ctx.Output(framework::GradVarName("X")); + dx->mutable_data(ctx.GetPlace()); + + auto eigen_out = framework::EigenVector::Flatten(*out); + auto eigen_dout = framework::EigenVector::Flatten(*dout); + auto eigen_dx = framework::EigenVector::Flatten(*dx); + auto& place = *ctx.template device_context().eigen_device(); + + constexpr T half_sqrt_pi = static_cast(1 / M_2_SQRTPI); + eigen_dx.device(place) = + half_sqrt_pi * eigen_dout * eigen_out.square().exp(); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index e0e33d3805e..4e6fe5a686f 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -237,6 +237,7 @@ from .tensor.math import acosh # noqa: F401 from .tensor.math import asinh # noqa: F401 from .tensor.math import atanh # noqa: F401 from .tensor.math import lerp # noqa: F401 +from .tensor.math import erfinv # noqa: F401 from .tensor.math import rad2deg # noqa: F401 from .tensor.math import deg2rad # noqa: F401 from .tensor.math import gcd # noqa: F401 @@ -493,6 +494,7 @@ __all__ = [ # noqa 'neg', 'lgamma', 'lerp', + 'erfinv', 'square', 'divide', 'ceil', diff --git a/python/paddle/fluid/tests/unittests/test_erfinv_op.py b/python/paddle/fluid/tests/unittests/test_erfinv_op.py new file mode 100644 index 00000000000..847a868dd6c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_erfinv_op.py @@ -0,0 +1,111 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from scipy.special import erfinv +from op_test import OpTest +import paddle +import paddle.fluid.core as core + +paddle.enable_static() +np.random.seed(0) + + +class TestErfinv(OpTest): + def setUp(self): + self.op_type = "erfinv" + self.init_dtype() + self.shape = [11, 17] + self.x = np.random.uniform(-1, 1, size=self.shape).astype(self.dtype) + self.res_ref = erfinv(self.x).astype(self.dtype) + self.grad_out = np.ones(self.shape, self.dtype) + self.gradient = np.sqrt(np.pi) / 2 * np.exp(np.square( + self.res_ref)) * self.grad_out + self.inputs = {'X': self.x} + self.outputs = {'Out': self.res_ref} + + def init_dtype(self): + self.dtype = np.float64 + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad( + ['X'], + 'Out', + user_defined_grads=[self.gradient], + user_defined_grad_outputs=self.grad_out) + + +class TestErfinvFP32(TestErfinv): + def init_dtype(self): + self.dtype = np.float32 + + +class TestErfinvAPI(unittest.TestCase): + def init_dtype(self): + self.dtype = 'float32' + + def setUp(self): + self.init_dtype() + self.x = np.random.rand(5).astype(self.dtype) + self.res_ref = erfinv(self.x) + self.place = [paddle.CPUPlace()] + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def test_static_api(self): + paddle.enable_static() + + def run(place): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('x', [1, 5], dtype=self.dtype) + out = paddle.erfinv(x) + exe = paddle.static.Executor(place) + res = exe.run(feed={'x': self.x.reshape([1, 5])}) + for r in res: + self.assertEqual(np.allclose(self.res_ref, r), True) + + for place in self.place: + run(place) + + def test_dygraph_api(self): + def run(place): + paddle.disable_static(place) + x = paddle.to_tensor(self.x) + out = paddle.erfinv(x) + self.assertEqual(np.allclose(self.res_ref, out.numpy()), True) + paddle.enable_static() + + for place in self.place: + run(place) + + def test_inplace_api(self): + def run(place): + paddle.disable_static(place) + x = paddle.to_tensor(self.x) + x.erfinv_() + self.assertEqual(np.allclose(self.res_ref, x.numpy()), True) + paddle.enable_static() + + for place in self.place: + run(place) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 314938ad732..ef9425f6db8 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -203,6 +203,8 @@ from .math import asinh # noqa: F401 from .math import atanh # noqa: F401 from .math import lerp # noqa: F401 from .math import lerp_ # noqa: F401 +from .math import erfinv # noqa: F401 +from .math import erfinv_ # noqa: F401 from .math import rad2deg # noqa: F401 from .math import deg2rad # noqa: F401 from .math import gcd # noqa: F401 @@ -441,6 +443,8 @@ tensor_method_func = [ #noqa 'diff', 'lerp', 'lerp_', + 'erfinv', + 'erfinv_', 'angle', 'moveaxis', 'repeat_interleave', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 9e59fbc56ad..ffe5f580bc6 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2860,6 +2860,51 @@ def lerp_(x, y, weight, name=None): raise ValueError("The shape of broadcast output {} is different from that of inplace tensor {} in the Inplace operation.".format(out_shape, x.shape)) return _C_ops.lerp_(x, y, weight) +def erfinv(x, name=None): + r""" + The inverse error function of x, . + + Equation: + .. math:: + + erfinv(erf(x)) = x. + + Args: + x (Tensor): An N-D Tensor, the data type is float32, float64. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): An N-D Tensor, the shape and data type is the same with input. + + Example: + .. code-block:: python + + import paddle + + x = paddle.to_tensor([0, 0.5, -1.], dtype="float32") + out = paddle.erfinv(x) + # out: [0, 0.4769, -inf] + + """ + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'erfinv') + + if in_dygraph_mode(): + return _C_ops.erfinv(x) + + helper = LayerHelper('erfinv', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op(type='erfinv', inputs={'X': x}, outputs={'Out': out}) + return out + +@inplace_apis_in_dygraph_only +def erfinv_(x, name=None): + r""" + Inplace version of ``erfinv`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_tensor_erfinv`. + """ + check_type(x, 'x', (paddle.Tensor, Variable), 'erfinv') + return _C_ops.erfinv_(x) + def rad2deg(x, name=None): """ Convert each of the elements of input x from angles in radians to degrees. -- GitLab