Unverified commit 511a2c1c authored by Charles-hit, committed by GitHub

Move api(lgamma) from legacy_api.yaml to api.yaml (#44355)

* Move api(lgamma) from legacy_api.yaml to api.yaml

* Move api(lgamma) from legacy_api.yaml to api.yaml

* Move api(lgamma) from legacy_api.yaml to api.yaml

* modify code style

* add x to X mapping

* add definition of lgamma

* delete redundant lgamma definitions

* Modify code comments

* Modify ops.py code format

* add lgamma single test and lgamma api in fluid

* Optimized lgamma unittest
Parent 9a3e1bce
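The move is mechanical: the operator's public behavior is unchanged, only its registration shifts to the declarative api.yaml pipeline. As a minimal smoke check, a sketch assuming a Paddle build containing this change and SciPy installed for the reference value:

# Hypothetical smoke check: paddle.lgamma should keep matching
# scipy.special.gammaln (log of |Gamma(x)|) after the YAML move.
import numpy as np
import paddle
from scipy import special

x_np = np.array([-0.4, -0.2, 0.1, 0.3], dtype="float32")
out = paddle.lgamma(paddle.to_tensor(x_np))
ref = special.gammaln(x_np)  # SciPy reference implementation
assert np.allclose(out.numpy(), ref, atol=1e-5)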
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {

class LgammaOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of lgamma op.");
    AddOutput("Out", "(Tensor), The output tensor of lgamma op.");
    AddComment(R"DOC(
Lgamma Operator.

This operator performs elementwise lgamma for input $X$.
$$out = \log|\Gamma(x)|$$
)DOC");
  }
};

class LgammaOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
};

template <typename T>
class LgammaGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> retv) const override {
    retv->SetType("lgamma_grad");
    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    retv->SetInput("X", this->Input("X"));
    retv->SetAttrMap(this->Attrs());
    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
  }
};

class LgammaGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")),
                   "Input",
                   "Out@Grad",
                   "LgammaGrad");
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LgammaGrad");
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")),
                   "Output",
                   "X@Grad",
                   "LgammaGrad");

    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
    ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
    ctx->ShareLoD(framework::GradVarName("Out"), framework::GradVarName("X"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

DECLARE_INFER_SHAPE_FUNCTOR(lgamma,
                            LgammaInferShapeFunctor,
                            PD_INFER_META(phi::UnchangedInferMeta));

REGISTER_OPERATOR(lgamma,
                  ops::LgammaOp,
                  ops::LgammaOpMaker,
                  ops::LgammaGradMaker<paddle::framework::OpDesc>,
                  ops::LgammaGradMaker<paddle::imperative::OpBase>,
                  LgammaInferShapeFunctor);

REGISTER_OPERATOR(lgamma_grad, ops::LgammaGradOp);
@@ -98,6 +98,15 @@
    func : erf
  backward : erf_grad

- api : lgamma
  args : (Tensor x)
  output : Tensor(out)
  infer_meta :
    func : UnchangedInferMeta
  kernel :
    func : lgamma
  backward : lgamma_grad

- api : mv
  args : (Tensor x, Tensor vec)
  output : Tensor
......
@@ -77,6 +77,12 @@
  outputs :
    out : Out

- api : lgamma
  inputs :
    x : X
  outputs :
    out : Out

- api : mv
  inputs :
    {x : X, vec : Vec}
......
@@ -105,6 +105,16 @@
    func : erf_grad
    data_type : out_grad

- backward_api : lgamma_grad
  forward : lgamma(Tensor x) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : UnchangedInferMeta
    param: [x]
  kernel :
    func : lgamma_grad

- backward_api : mv_grad
  forward : mv (Tensor x, Tensor vec) -> Tensor(out)
  args : (Tensor x, Tensor vec, Tensor out_grad)
......
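The backward entry above runs UnchangedInferMeta on x because the gradient of lgamma is elementwise: x_grad = out_grad * digamma(x). A hedged sketch of verifying that rule numerically (assumes SciPy; special.psi is digamma):

# Numeric check of the lgamma backward rule: d/dx log|Gamma(x)| = digamma(x).
import numpy as np
import paddle
from scipy import special

x_np = np.array([0.5, 1.5, 2.5], dtype="float32")
x = paddle.to_tensor(x_np, stop_gradient=False)
paddle.sum(paddle.lgamma(x)).backward()  # upstream gradient of ones
assert np.allclose(x.grad.numpy(), special.psi(x_np), atol=1e-5)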
@@ -1263,15 +1263,6 @@
  kernel :
    func : less_than

- api : lgamma
  args : (Tensor x)
  output : Tensor(out)
  infer_meta :
    func : UnchangedInferMeta
  kernel :
    func : lgamma
  backward : lgamma_grad

- api : linspace
  args : (Tensor start, Tensor stop, Tensor number, DataType dtype)
  output : Tensor
......
@@ -1118,16 +1118,6 @@
  kernel :
    func : lerp_grad

- backward_api : lgamma_grad
  forward : lgamma(Tensor x) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
  output : Tensor(x_grad)
  infer_meta :
    func : UnchangedInferMeta
    param: [x]
  kernel :
    func : lgamma_grad

- backward_api : log10_grad
  forward : log10 (Tensor x) -> Tensor(out)
  args : (Tensor x, Tensor out_grad)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {

KernelSignature LgammaGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("lgamma_grad", {"X", "Out@GRAD"}, {}, {"X@GRAD"});
}

}  // namespace phi
PD_REGISTER_ARG_MAPPING_FN(lgamma_grad, phi::LgammaGradOpArgumentMapping);
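This mapping lets the legacy static-graph operator (inputs X, Out@GRAD) dispatch onto the phi lgamma_grad kernel. A sketch of a static-graph program exercising both the forward op and this gradient path, assuming the Paddle 2.x static APIs of this era:

# Static-graph sketch: building lgamma plus its gradient op triggers the
# legacy-operator path served by the argument mapping above.
import numpy as np
import paddle

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    x = paddle.static.data(name="x", shape=[4], dtype="float32")
    x.stop_gradient = False
    loss = paddle.sum(paddle.lgamma(x))
    (x_grad,) = paddle.static.gradients([loss], [x])

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
res = exe.run(main,
              feed={"x": np.array([0.5, 1.5, 2.5, 3.5], "float32")},
              fetch_list=[loss, x_grad])
print(res)  # [loss value, digamma(x) values]
paddle.disable_static()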
@@ -20,6 +20,7 @@ from ..framework import convert_np_dtype_to_dtype_, Variable, in_dygraph_mode
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from paddle.utils import deprecated
from paddle import _C_ops
import paddle
__deprecated_func_name__ = {
    'tanh_shrink': 'tanhshrink',
@@ -37,28 +38,9 @@ __activations_noattr__ = [
]
__unary_func__ = [
    'exp',
    'expm1',
    'atan',
    'sqrt',
    'rsqrt',
    'abs',
    'ceil',
    'floor',
    'cos',
    'tan',
    'acos',
    'sin',
    'sinh',
    'asin',
    'cosh',
    'round',
    'reciprocal',
    'square',
    'lgamma',
    'acosh',
    'asinh',
    'atanh',
    'exp', 'expm1', 'atan', 'sqrt', 'rsqrt', 'abs', 'ceil', 'floor', 'cos',
    'tan', 'acos', 'sin', 'sinh', 'asin', 'cosh', 'round', 'reciprocal',
    'square', 'acosh', 'asinh', 'atanh', 'lgamma'
]
__inplace_unary_func__ = [
@@ -480,20 +462,6 @@ Examples:
""")
add_sample_code(
    globals()["lgamma"], r"""
Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
        out = paddle.lgamma(x)
        print(out)
        # [1.31452441, 1.76149750, 2.25271273, 1.09579802]

""")

add_sample_code(
    globals()["softplus"], r"""
Examples:
@@ -860,3 +828,31 @@ Examples:
        print(out)
        # [-0.42839236 -0.22270259  0.11246292  0.32862676]
"""
def lgamma(x, name=None):
    r"""
    Calculates the lgamma of the given input tensor, element-wise.

    This operator performs elementwise lgamma for input $X$.
    :math:`out = \log|\Gamma(x)|`

    Args:
        x (Tensor): Input Tensor. Must be one of the following types: float32, float64.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, the lgamma of the input Tensor, with the same shape and data type as the input.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
            out = paddle.lgamma(x)
            print(out)
            # [1.31452441, 1.76149750, 2.25271273, 1.09579802]
    """
    return paddle.Tensor.lgamma(x)
@@ -17,6 +17,7 @@ import math
import numpy as np
import paddle
from op_test import OpTest
from scipy import special
paddle.enable_static()
@@ -58,5 +59,19 @@ class TestLgammaOpFp32(TestLgammaOp):
check_eager=True)
class TestLgammaOpApi(unittest.TestCase):

    def test_lgamma(self):
        paddle.disable_static()
        self.dtype = "float32"
        shape = (1, 4)
        data = np.random.random(shape).astype(self.dtype) + 1
        data_ = paddle.to_tensor(data)
        out = paddle.fluid.layers.lgamma(data_)
        result = special.gammaln(data)
        self.assertTrue(np.allclose(result, out.numpy()))
        paddle.enable_static()
if __name__ == "__main__":
    unittest.main()
@@ -63,7 +63,6 @@ from .ops import erf # noqa: F401
from .ops import sqrt # noqa: F401
from .ops import sqrt_ # noqa: F401
from .ops import sin # noqa: F401
from .ops import lgamma # noqa: F401
from .ops import asinh # noqa: F401
from .ops import acosh # noqa: F401
from .ops import atanh # noqa: F401
@@ -3713,6 +3712,43 @@ def digamma(x, name=None):
    helper.append_op(type='digamma', inputs={'X': x}, outputs={'Out': out})
    return out
def lgamma(x, name=None):
    r"""
    Calculates the lgamma of the given input tensor, element-wise.

    This operator performs elementwise lgamma for input $X$.
    :math:`out = \log|\Gamma(x)|`

    Args:
        x (Tensor): Input Tensor. Must be one of the following types: float32, float64.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, the lgamma of the input Tensor, with the same shape and data type as the input.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
            out = paddle.lgamma(x)
            print(out)
            # [1.31452441, 1.76149750, 2.25271273, 1.09579802]
    """
    if in_dygraph_mode():
        return _C_ops.final_state_lgamma(x)
    elif _in_legacy_dygraph():
        return _C_ops.lgamma(x)

    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'lgamma')
    helper = LayerHelper('lgamma', **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(type='lgamma', inputs={'X': x}, outputs={'Out': out})
    return out
def neg(x, name=None):
    """
    This function computes the negative of the Tensor element-wise.
......
@@ -54,7 +54,6 @@ __unary_func__ = [
    'round',
    'reciprocal',
    'square',
    'lgamma',
    'acosh',
    'asinh',
    'atanh',
@@ -475,20 +474,6 @@ Examples:
""")
add_sample_code(
    globals()["lgamma"], r"""
Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([-0.4, -0.2, 0.1, 0.3])
        out = paddle.lgamma(x)
        print(out)
        # [1.31452441, 1.76149750, 2.25271273, 1.09579802]

""")

add_sample_code(
    globals()["softplus"], r"""
Examples:
......