diff --git a/paddle/fluid/operators/digamma_op.cc b/paddle/fluid/operators/digamma_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b1a58817e060434d0e309da3476edb5e96b5dfa3
--- /dev/null
+++ b/paddle/fluid/operators/digamma_op.cc
@@ -0,0 +1,100 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/digamma_op.h"
+
+namespace paddle {
+namespace operators {
+
+class DigammaOp : public framework::OperatorWithKernel {
+ public:
+  DigammaOp(const std::string &type, const framework::VariableNameMap &inputs,
+            const framework::VariableNameMap &outputs,
+            const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Digamma");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Digamma");
+
+    auto in_dims = ctx->GetInputDim("X");
+
+    ctx->SetOutputDim("Out", in_dims);
+    ctx->ShareLoD("X", "Out");
+  }
+};
+
+class DigammaOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of digamma operator.");
+    AddOutput("Out", "(Tensor), The output tensor of digamma operator.");
+    AddComment(R"DOC(
+Digamma Operator.
+
+This operator performs elementwise digamma on the input $X$.
+
+$$out = \Psi(x) = \frac{ \Gamma^{'}(x) }{ \Gamma(x) }$$
+
+)DOC");
+  }
+};
+
+class DigammaGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   "Out@Grad", "DigammaGrad");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "DigammaGrad");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
+                   "X@Grad", "DigammaGrad");
+
+    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
+    ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
+  }
+};
+
+template <typename T>
+class DigammaGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+  void Apply(GradOpPtr<T> retv) const override {
+    retv->SetType("digamma_grad");
+    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    retv->SetInput("X", this->Input("X"));
+    retv->SetAttrMap(this->Attrs());
+    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(digamma, ops::DigammaOp, ops::DigammaOpMaker,
+                  ops::DigammaGradOpMaker<paddle::framework::OpDesc>,
+                  ops::DigammaGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(digamma_grad, ops::DigammaGradOp);
+
+REGISTER_OP_CPU_KERNEL(
+    digamma, ops::DigammaKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::DigammaKernel<paddle::platform::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    digamma_grad,
+    ops::DigammaGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::DigammaGradKernel<paddle::platform::CPUDeviceContext, double>);
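For reference, the backward pass wired up by `DigammaGradOpMaker` above feeds `Out@GRAD` and `X` into `digamma_grad`, whose kernel (see `digamma_op.h` below) relies on the standard identity that the derivative of the digamma function is the first-order polygamma (trigamma) function:

$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial out} \cdot \Psi^{(1)}(x)$$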
diff --git a/paddle/fluid/operators/digamma_op.cu b/paddle/fluid/operators/digamma_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..5f2f59ba520d0fb1e2c083c211bceba0e4a25715
--- /dev/null
+++ b/paddle/fluid/operators/digamma_op.cu
@@ -0,0 +1,26 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/digamma_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    digamma, ops::DigammaKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::DigammaKernel<paddle::platform::CUDADeviceContext, double>);
+
+REGISTER_OP_CUDA_KERNEL(
+    digamma_grad,
+    ops::DigammaGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::DigammaGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/digamma_op.h b/paddle/fluid/operators/digamma_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f82628f020480f5eca22079b13e586e1ebf13643
--- /dev/null
+++ b/paddle/fluid/operators/digamma_op.h
@@ -0,0 +1,99 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <unsupported/Eigen/SpecialFunctions>
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct DigammaFunctor {
+  DigammaFunctor(const T* input, T* output, int64_t numel)
+      : input_(input), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    output_[idx] = Eigen::numext::digamma(input_[idx]);
+  }
+
+ private:
+  const T* input_;
+  T* output_;
+  int64_t numel_;
+};
+
+template <typename T>
+struct DigammaGradFunctor {
+  DigammaGradFunctor(const T* dout, const T* x, T* output, int64_t numel)
+      : dout_(dout), x_(x), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    output_[idx] = dout_[idx] * Eigen::numext::polygamma(T(1), x_[idx]);
+  }
+
+ private:
+  const T* dout_;
+  const T* x_;
+  T* output_;
+  int64_t numel_;
+};
+
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class DigammaKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* x = context.Input<Tensor>("X");
+    Tensor* out = context.Output<Tensor>("Out");
+
+    auto numel = x->numel();
+    auto* x_data = x->data<T>();
+    auto* out_data = out->mutable_data<T>(context.GetPlace(),
+                                          size_t(x->numel() * sizeof(T)));
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
+    DigammaFunctor<T> functor(x_data, out_data, numel);
+    for_range(functor);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class DigammaGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
+    const Tensor* x = context.Input<Tensor>("X");
+    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
+
+    auto numel = d_out->numel();
+    auto* dout_data = d_out->data<T>();
+    auto* x_data = x->data<T>();
+    auto* dx_data = d_x->mutable_data<T>(
+        context.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
+    DigammaGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
+    for_range(functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 3c16f327df4c24a669e5ab85d1c7addfa270725f..738de4e393d7782631fd1d2ca256e6afcb23fa9a 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -205,6 +205,7 @@ from .tensor.math import isnan  # noqa: F401
 from .tensor.math import prod  # noqa: F401
 from .tensor.math import broadcast_shape  # noqa: F401
 from .tensor.math import conj  # noqa: F401
+from .tensor.math import digamma  # noqa: F401
 from .tensor.math import neg  # noqa: F401
 from .tensor.math import lgamma  # noqa: F401
 
@@ -489,5 +490,6 @@ __all__ = [  # noqa
     'log10',
     'concat',
     'check_shape',
+    'digamma',
     'standard_normal'
 ]
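The element-wise math implemented by `DigammaFunctor` and `DigammaGradFunctor` above maps directly onto SciPy's special functions, which the unit tests below also use as the reference. A minimal NumPy/SciPy sketch of the same forward and backward rules (illustrative only, assuming SciPy is available):

    import numpy as np
    from scipy.special import polygamma, psi

    x = np.array([[1.0, 1.5], [2.0, 3.3]], dtype=np.float64)
    dout = np.ones_like(x)  # upstream gradient dL/dOut

    out = psi(x)                 # forward: Out = digamma(X), as in DigammaFunctor
    dx = dout * polygamma(1, x)  # backward: dX = dOut * polygamma(1, X), as in DigammaGradFunctor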
diff --git a/python/paddle/fluid/tests/unittests/test_digamma_op.py b/python/paddle/fluid/tests/unittests/test_digamma_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..86f59af19346cb23d76012a19c4a02690449a61b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_digamma_op.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import math
+import numpy as np
+from scipy.special import psi
+import paddle
+import paddle.fluid as fluid
+import paddle.static as static
+from op_test import OpTest
+
+
+class TestDigammaOp(OpTest):
+    def setUp(self):
+        # switch to static
+        paddle.enable_static()
+
+        self.op_type = 'digamma'
+        self.init_dtype_type()
+        shape = (5, 32)
+        data = np.random.random(shape).astype(self.dtype) + 1
+        self.inputs = {'X': data}
+        result = np.ones(shape).astype(self.dtype)
+        result = psi(data)
+        self.outputs = {'Out': result}
+
+    def init_dtype_type(self):
+        self.dtype = np.float64
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestDigammaOpFp32(TestDigammaOp):
+    def init_dtype_type(self):
+        self.dtype = np.float32
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X'], 'Out')
+
+
+class TestDigammaAPI(unittest.TestCase):
+    def setUp(self):
+        # switch to static
+        paddle.enable_static()
+        # prepare test attrs
+        self.dtypes = ["float32", "float64"]
+        self.places = [paddle.CPUPlace()]
+        if paddle.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+        self._shape = [8, 3, 32, 32]
+
+    def test_in_static_mode(self):
+        def init_input_output(dtype):
+            input = np.random.random(self._shape).astype(dtype)
+            return {'x': input}, psi(input)
+
+        for dtype in self.dtypes:
+            input_dict, sc_res = init_input_output(dtype)
+            for place in self.places:
+                with static.program_guard(static.Program()):
+                    x = static.data(name="x", shape=self._shape, dtype=dtype)
+                    out = paddle.digamma(x)
+
+                    exe = static.Executor(place)
+                    out_value = exe.run(feed=input_dict, fetch_list=[out.name])
+                    self.assertEqual(
+                        np.allclose(out_value[0], sc_res, rtol=1e-5), True)
+
+    def test_in_dynamic_mode(self):
+        for dtype in self.dtypes:
+            input = np.random.random(self._shape).astype(dtype)
+            sc_res = psi(input)
+            for place in self.places:
+                # it is more convenient to use `guard` than `enable/disable_**` here
+                with fluid.dygraph.guard(place):
+                    input_t = paddle.to_tensor(input)
+                    res = paddle.digamma(input_t).numpy()
+                    self.assertEqual(np.allclose(res, sc_res, rtol=1e-05), True)
+
+    def test_name_argument(self):
+        with static.program_guard(static.Program()):
+            x = static.data(name="x", shape=self._shape, dtype=self.dtypes[0])
+            out = paddle.digamma(x, name="digamma_res")
+            self.assertTrue("digamma_res" in out.name)
+
+    def test_dtype_error(self):
+        # in static mode
+        with self.assertRaises(TypeError):
+            with static.program_guard(static.Program()):
+                x = static.data(name="x", shape=self._shape, dtype="int32")
+                out = paddle.digamma(x, name="digamma_res")
+
+        # in dynamic mode
+        with self.assertRaises(RuntimeError):
+            with fluid.dygraph.guard():
+                input = np.random.random(self._shape).astype("int32")
+                input_t = paddle.to_tensor(input)
+                res = paddle.digamma(input_t)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 0b8d2be24f3cb4336ced0b58e55ad7a989884c74..8c83b1786b01e058b84211136750b271833dbb79 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -162,6 +162,7 @@ from .math import all  # noqa: F401
 from .math import any  # noqa: F401
 from .math import broadcast_shape  # noqa: F401
 from .math import conj  # noqa: F401
+from .math import digamma  # noqa: F401
 from .math import neg  # noqa: F401
 from .math import lgamma  # noqa: F401
 
@@ -347,5 +348,6 @@ tensor_method_func = [  #noqa
     'rank',
     'shape',
     'real',
-    'imag'
+    'imag',
+    'digamma'
 ]
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 15d0cd0146ab09587ea7c482d08c9dc7cbf7f99f..a9e24949aae2b18e4dcba30bc98613278ef52977 100755
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -2283,6 +2283,42 @@ def conj(x, name=None):
     helper.append_op(type='conj', inputs={'X': x}, outputs={'Out': [out]})
     return out
 
+def digamma(x, name=None):
+    r"""
+    Calculates the digamma of the given input tensor, element-wise.
+
+    .. math::
+        Out = \Psi(x) = \frac{ \Gamma^{'}(x) }{ \Gamma(x) }
+
+    Args:
+        x (Tensor): Input Tensor. Must be one of the following types: float32, float64.
+        name (str, optional): The default value is None. Normally there is no need for
+            users to set this property. For more information, please refer to :ref:`api_guide_Name`.
+    Returns:
+        Tensor, the digamma of the input Tensor, with the same shape and data type as the input.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            data = paddle.to_tensor([[1, 1.5], [0, -2.2]], dtype='float32')
+            res = paddle.digamma(data)
+            print(res)
+            # Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
+            #        [[-0.57721591,  0.03648996],
+            #         [ nan       ,  5.32286835]])
+    """
+
+    if in_dygraph_mode():
+        return core.ops.digamma(x)
+
+    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'digamma')
+    helper = LayerHelper('digamma', **locals())
+    out = helper.create_variable_for_type_inference(x.dtype)
+    helper.append_op(type='digamma', inputs={'X': x}, outputs={'Out': out})
+    return out
+
 def neg(x, name=None):
     """
     This function computes the negative of the Tensor elementwisely.
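Taken together, the operator, kernels, and Python binding above let `paddle.digamma` be called like any other element-wise math op. A minimal dynamic-graph usage sketch, assuming a Paddle build that includes this operator and that SciPy is installed for the cross-check:

    import numpy as np
    import paddle
    from scipy.special import psi

    x_np = np.random.random((3, 4)).astype("float32") + 1  # keep inputs positive
    x = paddle.to_tensor(x_np)

    out = paddle.digamma(x)  # element-wise digamma
    print(np.allclose(out.numpy(), psi(x_np), rtol=1e-5))  # expected: True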