未验证 提交 b197bfe6 编写于 作者: W wangzhen38 提交者: GitHub

add logit API (#37844)

* add Logit API

* add unittest

* conflict

* pull conflit

* pull conflit logit

* fix unititest

* fix code style

* update docs style of

* update en doc

* fix docs en style

* fix docs en style1

* fix docs en style2

* fix docs en style3

* fix docs en style4

* fix docs en style5

* fix docs en style6

* fix docs en style7

* fix docs en style8

* update by review

* fix nan bug
上级 cba84f88
......@@ -600,6 +600,39 @@ class ELUGradOpMaker : public framework::SingleGradOpMaker<T> {
}
};
class LogitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of Logit operator");
AddOutput("Out", "Output of Logit operator");
AddAttr<float>("eps",
"(float, default 1e-6f) the epsilon for input clamp bound")
.SetDefault(1e-6f);
AddComment(R"DOC(
Logit Operator.
this function is defined as follow:
$ logit=ln\left ( {\frac {x} {1-x}} \right ) $
)DOC");
}
};
template <typename T>
class LogitGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("logit_grad");
grad_op->SetInput("X", this->Input("X"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetAttrMap(this->Attrs());
}
};
class CELUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -1192,6 +1225,67 @@ DECLARE_INPLACE_OP_INFERER(ActivationDoubleGradOpInplaceInferer,
DECLARE_INPLACE_OP_INFERER(ActivationTripleGradOpInplaceInferer,
{"DDX", "D_DOut"});
class LogitOp : public framework::OperatorWithKernel {
public:
LogitOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(%s) of LogitOp should not be null.", "X"));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument(
"Output(%s) of LogitOp should not be null.", "Out"));
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}
};
class LogitGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(%s) of LogitGradOp should not be null.", "DOut"));
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(%s) of LogitGradOp should not be null.", "X"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"Output(%s) of LogitGradOp should not be null.", "DX"));
auto x_grad_name = framework::GradVarName("X");
ctx->SetOutputDim(x_grad_name, ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ x_grad_name);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}
};
template <typename T>
class PowGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -1491,6 +1585,20 @@ REGISTER_OP_CPU_KERNEL(
/* ========================================================================== */
/* ======================== logit register ============================
*/
REGISTER_OPERATOR(logit, ops::LogitOp, ops::LogitOpMaker,
ops::LogitGradOpMaker<paddle::framework::OpDesc>,
ops::LogitGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(logit_grad, ops::LogitGradOp);
REGISTER_OP_CPU_KERNEL(
logit, ops::LogitKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogitKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
logit_grad, ops::LogitGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LogitGradKernel<paddle::platform::CPUDeviceContext, double>);
/* ========================================================================== */
/* ======================== celu register ============================
*/
REGISTER_OPERATOR(
......
......@@ -1623,6 +1623,21 @@ REGISTER_OP_CUDA_KERNEL(
ops::PowGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== logit register ============================ */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
logit, ops::LogitKernel<paddle::platform::CUDADeviceContext, float>,
ops::LogitKernel<paddle::platform::CUDADeviceContext, double>,
ops::LogitKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
REGISTER_OP_CUDA_KERNEL(
logit_grad,
ops::LogitGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LogitGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::LogitGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
/* ========================================================================== */
/* ========================== exp register ============================ */
REGISTER_OP_CUDA_KERNEL(
exp, ops::ActivationCudaKernel<plat::CUDADeviceContext,
......
......@@ -1563,6 +1563,36 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct LogitFunctor {
template <typename Device, typename X, typename Out, typename P>
void operator()(Device d, X x, Out out, P p, float eps) const {
// logit(x) = ln(x/(1-x))
auto tmp_x =
(x.cwiseMin(static_cast<T>(1.0 - eps))).cwiseMax(static_cast<T>(eps));
if (!eps) {
out.device(d) = (x < static_cast<T>(0.0) || x > static_cast<T>(1.0))
.select(p.constant(static_cast<T>(NAN)),
(tmp_x / (static_cast<T>(1) - tmp_x)).log());
} else {
out.device(d) = (tmp_x / (static_cast<T>(1) - tmp_x)).log();
}
}
};
template <typename T>
struct LogitGradFunctor {
template <typename Device, typename X, typename dOut, typename dX, typename P>
void operator()(Device d, X x, dOut dout, dX dx, P p, float eps) const {
// logit(x)' = 1/(x*(1-x))
dx.device(d) =
(x < static_cast<T>(eps) || x > static_cast<T>(1.0 - eps))
.select(p.constant(static_cast<T>(0)),
dout * (static_cast<T>(1) / ((static_cast<T>(1) - x) * x)));
}
};
template <typename T>
struct STanhFunctor : public BaseActivationFunctor<T> {
float scale_a;
......@@ -2599,6 +2629,49 @@ class PowGradKernel
}
};
template <typename DeviceContext, typename T>
class LogitKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Output<framework::Tensor>("Out");
auto* in = context.Input<framework::Tensor>("X");
auto eps = context.Attr<float>("eps");
out->mutable_data<T>(in->place());
auto eigen_out = framework::EigenVector<T>::Flatten(*out);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto eigen_p = framework::EigenVector<T>::Flatten(*out);
LogitFunctor<T> functor;
functor(place, eigen_in, eigen_out, eigen_p, eps);
}
};
template <typename DeviceContext, typename T>
class LogitGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* dout =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto eps = context.Attr<float>("eps");
dx->mutable_data<T>(dout->place());
auto eigen_x = framework::EigenVector<T>::Flatten(*x);
auto eigen_dout = framework::EigenVector<T>::Flatten(*dout);
auto eigen_dx = framework::EigenVector<T>::Flatten(*dx);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto eigen_p = framework::EigenVector<T>::Flatten(*x);
LogitGradFunctor<T> functor;
functor(place, eigen_x, eigen_dout, eigen_dx, eigen_p, eps);
}
};
template <typename T>
struct LogGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
......
......@@ -170,6 +170,7 @@ from .tensor.math import tan # noqa: F401
from .tensor.math import cosh # noqa: F401
from .tensor.math import cumsum # noqa: F401
from .tensor.math import cumprod # noqa: F401
from .tensor.math import logit # noqa: F401
from .tensor.math import exp # noqa: F401
from .tensor.math import expm1 # noqa: F401
from .tensor.math import floor # noqa: F401
......@@ -365,6 +366,7 @@ __all__ = [ # noqa
'eye',
'cumsum',
'cumprod',
'logit',
'sign',
'is_empty',
'equal',
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle
np.random.seed(10)
def logit(x, eps):
x_min = np.minimum(x, 1. - eps)
x_max = np.maximum(x_min, eps)
return np.log(x_max / (1. - x_max))
def logit_grad(x, eps=1e-8):
tmp_x = np.select([x < eps, x > (1. - eps)], [x * 0., x * 0.], default=-1.0)
x_1 = 1. - x
_x = np.select([tmp_x == -1.0], [np.reciprocal(x * x_1)], default=0.0)
dout = np.full_like(x, fill_value=1. / _x.size)
dx = dout * _x
return dx
class TestLogitOp(OpTest):
def setUp(self):
self.op_type = 'logit'
self.dtype = np.float64
self.shape = [120]
self.eps = 1e-8
self.set_attrs()
x = np.random.uniform(-1., 1., self.shape).astype(self.dtype)
out = logit(x, self.eps)
self.x_grad = logit_grad(x, self.eps)
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {'eps': self.eps}
def set_attrs(self):
pass
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], ['Out'], user_defined_grads=[self.x_grad])
class TestLogitShape(TestLogitOp):
def set_attrs(self):
self.shape = [2, 60]
class TestLogitEps(TestLogitOp):
def set_attrs(self):
self.eps = 1e-8
class TestLogitAPI(unittest.TestCase):
def setUp(self):
self.x_shape = [120]
self.x = np.random.uniform(0., 1., self.x_shape).astype(np.float32)
self.place = paddle.CUDAPlace(0) \
if paddle.fluid.core.is_compiled_with_cuda() \
else paddle.CPUPlace()
def check_api(self, eps=1e-8):
ref_out = logit(self.x, eps)
# test static api
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data(name='x', shape=self.x_shape)
y = paddle.logit(x, eps)
exe = paddle.static.Executor(self.place)
out = exe.run(feed={'x': self.x}, fetch_list=[y])
self.assertTrue(np.allclose(out[0], ref_out))
# test dygrapg api
paddle.disable_static()
x = paddle.to_tensor(self.x)
y = paddle.logit(x, 1e-8)
self.assertTrue(np.allclose(y.numpy(), ref_out))
paddle.enable_static()
def test_check_api(self):
paddle.enable_static()
for eps in [1e-6, 0.0]:
self.check_api(eps)
def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data(name='X1', shape=[100], dtype='int32')
self.assertRaises(TypeError, paddle.logit, x)
x = paddle.fluid.data(name='X2', shape=[100], dtype='float32')
self.assertRaises(TypeError, paddle.logit, x, dtype='int32')
if __name__ == "__main__":
unittest.main()
......@@ -125,6 +125,7 @@ from .math import tan # noqa: F401
from .math import cosh # noqa: F401
from .math import cumsum # noqa: F401
from .math import cumprod # noqa: F401
from .math import logit # noqa: F401
from .math import exp # noqa: F401
from .math import exp_ # noqa: F401
from .math import expm1 # noqa: F401
......@@ -268,6 +269,7 @@ tensor_method_func = [ #noqa
'cosh',
'cumsum',
'cumprod',
'logit',
'exp',
'exp_',
'floor',
......
......@@ -2614,6 +2614,62 @@ def atan2(x, y, name=None):
type='atan2', inputs=inputs, outputs={'Out': out})
return out
def logit(x, eps=None, name=None):
r"""
This function generates a new tensor with the logit of the elements of input x. x is clamped to [eps, 1-eps] when eps is not zero. When eps is zero and x < 0 or x > 1, the function will yields NaN.
.. math::
logit(x) = ln(\frac{x}{1 - x})
where
.. math::
x_i=
\left\{\begin{array}{rcl}
x_i & &\text{if } eps == Default \\
eps & &\text{if } x_i < eps \\
x_i & &\text{if } eps <= x_i <= 1-eps \\
1-eps & &\text{if } x_i > 1-eps
\end{array}\right.
Args:
x (Tensor): The input Tensor with data type float32, float64.
eps (float, optional): the epsilon for input clamp bound. Default is None.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
out(Tensor): A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([0.2635, 0.0106, 0.2780, 0.2097, 0.8095])
out1 = paddle.logit(x)
print(out1)
# [-1.0277, -4.5365, -0.9544, -1.3269, 1.4468]
"""
if eps == None:
eps = 0.0
if in_dygraph_mode():
return _C_ops.logit(x, 'eps', eps)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'logit')
helper = LayerHelper("logit", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='logit',
inputs={'X': x},
outputs={'Out': out},
attrs={'eps': eps})
return out
def lerp(x, y, weight, name=None):
r"""
Does a linear interpolation between x and y based on weight.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册