diff --git a/paddle/fluid/operators/erf_op.cc b/paddle/fluid/operators/erf_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c6ff306b9579caab36726a2d78109c6b4f798c5d
--- /dev/null
+++ b/paddle/fluid/operators/erf_op.cc
@@ -0,0 +1,133 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "paddle/fluid/operators/erf_op.h"
+#include "paddle/fluid/platform/float16.h"
+
+namespace paddle {
+namespace operators {
+
+class ErfOp : public framework::OperatorWithKernel {
+ public:
+  ErfOp(const std::string &type, const framework::VariableNameMap &inputs,
+        const framework::VariableNameMap &outputs,
+        const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(%s) of ErfOp should not be null.", "X"));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      platform::errors::InvalidArgument(
+                          "Output(%s) of ErfOp should not be null.", "Out"));
+
+    ctx->ShareDim("X", /*->*/ "Out");
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class ErfGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput(framework::GradVarName("Out")), true,
+        platform::errors::InvalidArgument(
+            "Input(%s) of ErfGradOp should not be null.", "DOut"));
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      platform::errors::InvalidArgument(
+                          "Input(%s) of ErfGradOp should not be null.", "X"));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
+                      platform::errors::InvalidArgument(
+                          "Output(%s) of ErfGradOp should not be null.", "DX"));
+    auto x_grad_name = framework::GradVarName("X");
+    ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
+    ctx->ShareLoD("X", /*->*/ x_grad_name);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class ErfOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input tensor of the erf operator.");
+    AddOutput("Out", "The output tensor of the erf operator.");
+    AddComment(R"DOC(
+Erf Operator.
+
+The equation is:
+$$
+f(x) = \frac{2}{\sqrt{\pi}} \int_{0}^{x}e^{- \eta^{2}}d\eta
+$$
+
+The input `X` can carry the LoD (Level of Details) information,
+or not. The output shares the LoD information with the input `X`.
+)DOC"); + } +}; + +template +class ErfGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + std::unique_ptr Apply() const override { + auto *grad_op = new T(); + grad_op->SetType("erf_grad"); + grad_op->SetInput("X", this->Input("X")); + grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + grad_op->SetAttrMap(this->Attrs()); + return std::unique_ptr(grad_op); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(erf, ops::ErfOp, ops::ErfOpMaker, + ops::ErfGradOpMaker, + ops::ErfGradOpMaker); +REGISTER_OPERATOR(erf_grad, ops::ErfGradOp); +REGISTER_OP_CPU_KERNEL( + erf, ops::ErfKernel, + ops::ErfKernel, + ops::ErfKernel); +REGISTER_OP_CPU_KERNEL( + erf_grad, ops::ErfGradKernel, + ops::ErfGradKernel, + ops::ErfGradKernel); diff --git a/paddle/fluid/operators/erf_op.cu b/paddle/fluid/operators/erf_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..357b9e79c4e72854549f11ab49735fac65a400be --- /dev/null +++ b/paddle/fluid/operators/erf_op.cu @@ -0,0 +1,28 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/erf_op.h" +#include "paddle/fluid/platform/float16.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + erf, ops::ErfKernel, + ops::ErfKernel, + ops::ErfKernel); +REGISTER_OP_CUDA_KERNEL( + erf_grad, ops::ErfGradKernel, + ops::ErfGradKernel, + ops::ErfGradKernel); diff --git a/paddle/fluid/operators/erf_op.h b/paddle/fluid/operators/erf_op.h new file mode 100644 index 0000000000000000000000000000000000000000..08c827df95d9bfa4f01f3c7af9e657b7b3a360a8 --- /dev/null +++ b/paddle/fluid/operators/erf_op.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+
+#pragma once
+
+#ifndef _USE_MATH_DEFINES
+#define _USE_MATH_DEFINES
+#endif
+#include <cmath>
+
+#include "paddle/fluid/framework/eigen.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class ErfKernel : public framework::OpKernel<T> {
+ public:
+  virtual void Compute(const framework::ExecutionContext& context) const {
+    auto* out = context.Output<framework::Tensor>("Out");
+    auto* in = context.Input<framework::Tensor>("X");
+    out->mutable_data<T>(in->place());
+
+    auto eigen_out = framework::EigenVector<T>::Flatten(*out);
+    auto eigen_in = framework::EigenVector<T>::Flatten(*in);
+    auto& place =
+        *context.template device_context<DeviceContext>().eigen_device();
+    eigen_out.device(place) = eigen_in.erf();
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ErfGradKernel : public framework::OpKernel<T> {
+ public:
+  virtual void Compute(const framework::ExecutionContext& context) const {
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* dout =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
+
+    dx->mutable_data<T>(dout->place());
+
+    auto eigen_x = framework::EigenVector<T>::Flatten(*x);
+    auto eigen_dout = framework::EigenVector<T>::Flatten(*dout);
+    auto eigen_dx = framework::EigenVector<T>::Flatten(*dx);
+    auto& place =
+        *context.template device_context<DeviceContext>().eigen_device();
+    // d(erf(x))/dx = (2 / sqrt(pi)) * exp(-x^2); M_2_SQRTPI is 2 / sqrt(pi).
+    eigen_dx.device(place) =
+        eigen_dout * static_cast<T>(M_2_SQRTPI) * (-(eigen_x.square())).exp();
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
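
The backward formula in ErfGradKernel follows from the fundamental theorem of calculus: differentiating f(x) = (2/sqrt(pi)) * integral from 0 to x of exp(-eta^2) d(eta) gives f'(x) = (2/sqrt(pi)) * exp(-x^2), and the C constant M_2_SQRTPI is exactly 2/sqrt(pi). A minimal standalone sketch, not part of this patch and assuming only NumPy and SciPy (the same dependencies the unit test below uses), checks that closed form against a central finite difference of scipy.special.erf:

import numpy as np
from scipy.special import erf

x = np.random.uniform(-1, 1, size=(11, 17))
eps = 1e-6

# Closed form used by ErfGradKernel (M_2_SQRTPI == 2 / sqrt(pi)).
analytic = (2.0 / np.sqrt(np.pi)) * np.exp(-np.square(x))
# Central finite difference of the forward function.
numeric = (erf(x + eps) - erf(x - eps)) / (2.0 * eps)

assert np.allclose(analytic, numeric, atol=1e-8)
print("erf_grad closed form matches finite differences")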
diff --git a/python/paddle/fluid/layers/ops.py b/python/paddle/fluid/layers/ops.py
index 072b95969c7bea272e38e00d0ff6be7cd9ef3a04..97d8a5bb3cd12fea2149031f201b50d727e04948 100644
--- a/python/paddle/fluid/layers/ops.py
+++ b/python/paddle/fluid/layers/ops.py
@@ -318,3 +318,82 @@ Examples:
 # array([[ 0.70456535, -0.15380788, -0.13207214],
 #        [ 0.08796856,  0.20387867,  0.2080159 ]], dtype=float32)
 """
+
+__all__ += ['erf']
+
+_erf_ = generate_layer_fn('erf')
+
+
+def erf(x):
+    locals_var = locals().copy()
+    kwargs = dict()
+    for name, val in locals_var.items():
+        if val is not None:
+            kwargs[name] = val
+    return _erf_(**kwargs)
+
+
+erf.__doc__ = """
+:strong:`Erf Operator`
+For more details, see [Error function](https://en.wikipedia.org/wiki/Error_function).
+
+Equation:
+    .. math::
+        out = \\frac{2}{\\sqrt{\\pi}} \\int_{0}^{x}e^{- \\eta^{2}}d\\eta
+
+Args:
+
+    x (Variable): The input of the erf op, a Tensor or LoDTensor with dtype float32 or float64.
+
+Returns:
+
+    Variable: The output of the erf op, a Tensor or LoDTensor with the same shape and dtype as the input.
+
+Examples:
+
+    .. code-block:: python
+
+        # declarative mode
+        import numpy as np
+        from paddle import fluid
+
+        x = fluid.data(name="x", shape=(-1, 3), dtype="float32")
+        y = fluid.layers.erf(x)
+
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        start = fluid.default_startup_program()
+        main = fluid.default_main_program()
+
+        data = np.random.randn(2, 3).astype("float32")
+        exe.run(start)
+
+        y_np, = exe.run(main, feed={"x": data}, fetch_list=[y])
+
+        data
+        # array([[ 0.4643714 , -1.1509596 ,  1.2538221 ],
+        #        [ 0.34369683,  0.27478245,  1.1805398 ]], dtype=float32)
+        y_np
+        # array([[ 0.48863927, -0.8964121 ,  0.9237998 ],
+        #        [ 0.37307587,  0.30242872,  0.9049887 ]], dtype=float32)
+
+    .. code-block:: python
+
+        # imperative mode
+        import numpy as np
+        from paddle import fluid
+        import paddle.fluid.dygraph as dg
+
+        data = np.random.randn(2, 3).astype("float32")
+        place = fluid.CPUPlace()
+        with dg.guard(place) as g:
+            x = dg.to_variable(data)
+            y = fluid.layers.erf(x)
+            y_np = y.numpy()
+        data
+        # array([[ 0.4643714 , -1.1509596 ,  1.2538221 ],
+        #        [ 0.34369683,  0.27478245,  1.1805398 ]], dtype=float32)
+        y_np
+        # array([[ 0.48863927, -0.8964121 ,  0.9237998 ],
+        #        [ 0.37307587,  0.30242872,  0.9049887 ]], dtype=float32)
+"""
diff --git a/python/paddle/fluid/tests/unittests/test_erf_op.py b/python/paddle/fluid/tests/unittests/test_erf_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..93ab0212f136adfedacb52a2fde47e15edf279d3
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_erf_op.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from scipy.special import erf
+from op_test import OpTest
+
+import paddle.fluid as fluid
+import paddle.fluid.dygraph as dg
+
+
+class TestErfOp(OpTest):
+    def setUp(self):
+        self.op_type = "erf"
+        self.dtype = self._init_dtype()
+        self.x_shape = [11, 17]
+        x = np.random.uniform(-1, 1, size=self.x_shape).astype(self.dtype)
+        y_ref = erf(x).astype(self.dtype)
+        self.inputs = {'X': x}
+        self.outputs = {'Out': y_ref}
+
+    def _init_dtype(self):
+        return "float64"
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestErfLayer(unittest.TestCase):
+    def _test_case(self, place):
+        x = np.random.uniform(-1, 1, size=(11, 17)).astype(np.float64)
+        y_ref = erf(x)
+        with dg.guard(place) as g:
+            x_var = dg.to_variable(x)
+            y_var = fluid.layers.erf(x_var)
+            y_test = y_var.numpy()
+        self.assertTrue(np.allclose(y_ref, y_test))
+
+    def test_case(self):
+        self._test_case(fluid.CPUPlace())
+        if fluid.is_compiled_with_cuda():
+            self._test_case(fluid.CUDAPlace(0))
+
+
+if __name__ == '__main__':
+    unittest.main()
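
As a cross-check of the documentation itself, the integral stated in both the op comment and the Python docstring, out = (2/sqrt(pi)) * integral from 0 to x of exp(-eta^2) d(eta), can be evaluated numerically and compared against scipy.special.erf, the reference function the unit test relies on. A standalone sketch, not part of this patch:

import numpy as np
from scipy.special import erf

def erf_by_quadrature(x, n=100001):
    # Trapezoid-rule integration of exp(-eta^2) from 0 to x.
    eta = np.linspace(0.0, x, n)
    return (2.0 / np.sqrt(np.pi)) * np.trapz(np.exp(-np.square(eta)), eta)

for x in (-1.5, -0.3, 0.0, 0.7, 2.0):
    assert np.isclose(erf_by_quadrature(x), erf(x), atol=1e-8)
print("integral definition matches scipy.special.erf")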