diff --git a/paddle/fluid/operators/lgamma_op.cc b/paddle/fluid/operators/lgamma_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..148fb05afcfd9a4ef1fcbc587a2bd33947a41000
--- /dev/null
+++ b/paddle/fluid/operators/lgamma_op.cc
@@ -0,0 +1,99 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/lgamma_op.h"
+
+namespace paddle {
+namespace operators {
+
+class LgammaOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of lgamma op.");
+    AddOutput("Out", "(Tensor), The output tensor of lgamma op.");
+    AddComment(R"DOC(
+Lgamma Operator.
+
+This operator performs elementwise lgamma for input $X$.
+$$out = log\Gamma(x)$$
+
+)DOC");
+  }
+};
+
+class LgammaOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Lgamma");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Lgamma");
+
+    auto in_dims = ctx->GetInputDim("X");
+
+    ctx->SetOutputDim("Out", in_dims);
+    ctx->ShareLoD("X", "Out");
+  }
+};
+
+template <typename T>
+class LgammaGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+  void Apply(GradOpPtr<T> retv) const override {
+    retv->SetType("lgamma_grad");
+    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    retv->SetInput("X", this->Input("X"));
+    retv->SetAttrMap(this->Attrs());
+    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+  }
+};
+
+class LgammaGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   "Out@Grad", "LgammaGrad");
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LgammaGrad");
+    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
+                   "X@Grad", "LgammaGrad");
+
+    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
+    ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(lgamma, ops::LgammaOp, ops::LgammaOpMaker,
+                  ops::LgammaGradMaker<paddle::framework::OpDesc>,
+                  ops::LgammaGradMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(lgamma_grad, ops::LgammaGradOp);
+
+REGISTER_OP_CPU_KERNEL(
+    lgamma, ops::LgammaKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::LgammaKernel<paddle::platform::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    lgamma_grad,
+    ops::LgammaGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::LgammaGradKernel<paddle::platform::CPUDeviceContext, double>);
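Note: the grad maker above forwards both `X` and `Out@GRAD` to `lgamma_grad` because the backward rule, d/dx log Γ(x) = ψ(x) (the digamma function), must be evaluated at the original input. A quick NumPy/SciPy sketch of the math the kernels implement — illustration only, not part of the patch, and it assumes SciPy is installed:

```python
import numpy as np
from scipy.special import digamma, gammaln  # assumption: SciPy available

x = np.random.rand(5, 20) + 1     # same shifted range the unit test uses
out = gammaln(x)                  # forward:  out = log|Gamma(x)|
dout = np.ones_like(x)            # upstream gradient
dx = dout * digamma(x)            # backward: dx = dout * psi(x)

# Central-difference check that d/dx lgamma(x) really is digamma(x).
eps = 1e-6
numeric = (gammaln(x + eps) - gammaln(x - eps)) / (2 * eps)
assert np.allclose(numeric, dx, atol=1e-5)
```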
diff --git a/paddle/fluid/operators/lgamma_op.cu b/paddle/fluid/operators/lgamma_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..befd31e3bd8b1898ad6c59dca80dac3ae6de339d
--- /dev/null
+++ b/paddle/fluid/operators/lgamma_op.cu
@@ -0,0 +1,64 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <unsupported/Eigen/SpecialFunctions>
+#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
+#include "paddle/fluid/operators/lgamma_op.h"
+#include "paddle/fluid/operators/math/complex_functors.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T, typename Enable = void>
+struct CudaLgammaFunctor;
+
+template <typename T>
+struct CudaLgammaFunctor<T, math::NoComplex<T, math::Real<T>>> {
+  __device__ __forceinline__ T operator()(const T* args) const {
+    return Eigen::numext::lgamma(args[0]);
+  }
+};
+
+template <typename T>
+class LgammaKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* x = context.Input<Tensor>("X");
+    Tensor* out = context.Output<Tensor>("Out");
+    out->mutable_data<math::Real<T>>(context.GetPlace());
+
+    auto& dev_ctx = context.device_context<platform::CUDADeviceContext>();
+    std::vector<const framework::Tensor*> ins = {x};
+    std::vector<framework::Tensor*> outs = {out};
+    auto functor = CudaLgammaFunctor<T>();
+    LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, T,
+                                        math::Real<T>>(dev_ctx, ins, &outs,
+                                                       functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    lgamma, ops::LgammaKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LgammaKernel<paddle::platform::CUDADeviceContext, double>);
+
+REGISTER_OP_CUDA_KERNEL(
+    lgamma_grad,
+    ops::LgammaGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::LgammaGradKernel<paddle::platform::CUDADeviceContext, double>);
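The CUDA file specializes `LgammaKernel` for `CUDADeviceContext` so the forward pass goes through the vectorized same-dims elementwise launcher rather than the generic `ForRange` loop in the header; the backward pass reuses the header's `LgammaGradKernel`. If a Python-side `paddle.lgamma` binding is exposed (it is not part of this diff), a hypothetical CPU/GPU parity check could look like:

```python
import numpy as np
import paddle  # hypothetical: assumes a CUDA build and a paddle.lgamma binding

x = np.random.rand(5, 20).astype("float32") + 1
cpu = paddle.lgamma(paddle.to_tensor(x, place=paddle.CPUPlace()))
gpu = paddle.lgamma(paddle.to_tensor(x, place=paddle.CUDAPlace(0)))
np.testing.assert_allclose(cpu.numpy(), gpu.numpy(), rtol=1e-5)
```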
diff --git a/paddle/fluid/operators/lgamma_op.h b/paddle/fluid/operators/lgamma_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..674054e74573208ea9bbd537419d202e1a30d8c0
--- /dev/null
+++ b/paddle/fluid/operators/lgamma_op.h
@@ -0,0 +1,100 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <unsupported/Eigen/SpecialFunctions>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct LgammaFunctor {
+  LgammaFunctor(const T* input, T* output, int64_t numel)
+      : input_(input), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    output_[idx] = Eigen::numext::lgamma(input_[idx]);
+  }
+
+ private:
+  const T* input_;
+  T* output_;
+  int64_t numel_;
+};
+
+template <typename T>
+struct LgammaGradFunctor {
+  LgammaGradFunctor(const T* dout, const T* x, T* output, int64_t numel)
+      : dout_(dout), x_(x), output_(output), numel_(numel) {}
+
+  HOSTDEVICE void operator()(int64_t idx) const {
+    output_[idx] = dout_[idx] * Eigen::numext::digamma(x_[idx]);
+  }
+
+ private:
+  const T* dout_;
+  const T* x_;
+  T* output_;
+  int64_t numel_;
+};
+
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class LgammaKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const Tensor* x = context.Input<Tensor>("X");
+    Tensor* out = context.Output<Tensor>("Out");
+
+    auto numel = x->numel();
+    auto* x_data = x->data<T>();
+    auto* out_data = out->mutable_data<T>(context.GetPlace(),
+                                          size_t(x->numel() * sizeof(T)));
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
+    LgammaFunctor<T> functor(x_data, out_data, numel);
+    for_range(functor);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class LgammaGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const framework::Tensor* d_out =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
+    framework::Tensor* d_x =
+        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+    auto numel = d_out->numel();
+    auto* dout_data = d_out->data<T>();
+    auto* x_data = x->data<T>();
+    auto* dx_data = d_x->mutable_data<T>(
+        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
+
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
+    LgammaGradFunctor<T> functor(dout_data, x_data, dx_data, numel);
+    for_range(functor);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
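Both header kernels are device-agnostic: `platform::ForRange` applies the functor once per flat element index, so tensor layout never matters for this elementwise op. A minimal Python mirror of that dispatch, for intuition only:

```python
import math

def for_range(numel, functor):
    # Stand-in for platform::ForRange: call the functor at each flat index.
    for idx in range(numel):
        functor(idx)

x = [0.5, 1.0, 2.5, 7.0]
out = [0.0] * len(x)
for_range(len(x), lambda i: out.__setitem__(i, math.lgamma(x[i])))
print(out)  # log Gamma of each element, e.g. lgamma(1.0) == 0.0
```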
diff --git a/python/paddle/fluid/tests/unittests/test_lgamma_op.py b/python/paddle/fluid/tests/unittests/test_lgamma_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..686d5b1eb6dfefc024ffb435f802dea25fe1d2e0
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_lgamma_op.py
@@ -0,0 +1,56 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import math
+import numpy as np
+import paddle
+from op_test import OpTest
+
+paddle.enable_static()
+
+
+class TestLgammaOp(OpTest):
+    def setUp(self):
+        self.op_type = 'lgamma'
+        self.init_dtype_type()
+        shape = (5, 20)
+        data = np.random.random(shape).astype(self.dtype) + 1
+        self.inputs = {'X': data}
+        result = np.ones(shape).astype(self.dtype)
+        for i in range(shape[0]):
+            for j in range(shape[1]):
+                result[i][j] = math.lgamma(data[i][j])
+        self.outputs = {'Out': result}
+
+    def init_dtype_type(self):
+        self.dtype = np.float64
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X'], 'Out', numeric_grad_delta=1e-7)
+
+
+class TestLgammaOpFp32(TestLgammaOp):
+    def init_dtype_type(self):
+        self.dtype = np.float32
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X'], 'Out', numeric_grad_delta=0.005)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
index 6076e9dc9f60405c4b5e4dde002191e9f1fdcd5b..c771531b7b61be7933b5355204c532b847b13dc5 100644
--- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
+++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
@@ -45,6 +45,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
     'bilateral_slice',\
     'cudnn_lstm', \
     'rnn', \
+    'lgamma', \
 ]
 
 NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\
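The whitelist entry relaxes the fp64 gradient-check threshold for `lgamma`, and the test shifts `np.random.random` samples (range [0, 1)) by +1 for a related reason: log Γ(x) has a pole at x = 0, so numeric gradients near the origin would be unstable. For example:

```python
import math

print(math.lgamma(1e-8))  # ~18.42, and it grows without bound as x -> 0+
print(math.lgamma(1.0))   # 0.0, since Gamma(1) = 1
print(math.lgamma(2.0))   # 0.0, since Gamma(2) = 1
```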