未验证 提交 2f351ed5 编写于 作者: M minghaoBD 提交者: GitHub

add silu op, test=develop (#32384)

上级 7ef1de67
......@@ -35,7 +35,7 @@ void IsTestPass::ApplyImpl(ir::Graph* graph) const {
"hard_shrink", "hard_sigmoid", "relu6",
"soft_relu", "swish", "thresholded_relu",
"log", "square", "softplus",
"softsign"};
"softsign", "silu"};
for (const Node* n : graph->Nodes()) {
if (n->IsOp()) {
auto* op = n->Op();
......
......@@ -162,6 +162,12 @@ $$out = \\frac{1}{1 + e^{-x}}$$
)DOC";
UNUSED constexpr char SiluDoc[] = R"DOC(
Silu Activation Operator
$$out = x * \\frac{1}{1 + e^{-x}}$$
)DOC";
UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator
......@@ -697,6 +703,7 @@ It is recommended to use the defaults for this activation.
};
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Silu, SiluDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
......
......@@ -258,6 +258,31 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
// silu(x) = x / (1 + exp(-x))
template <typename T>
struct SiluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto temp = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
out.device(d) = x * temp;
}
};
// silu'(x) = (1 / (1 + e^{-x})) * (1 + out * e^{-x}))
template <typename T>
struct SiluGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = static_cast<T>(1) + (-x).exp(); // 1+e^(-x)
auto temp2 = x * (-x).exp(); // x*e^(-x)
dx.device(d) = dout * ((static_cast<T>(1) / temp1) *
(static_cast<T>(1) + (temp2 / temp1)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
// Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
......@@ -2129,6 +2154,7 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
#define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(silu, Silu, SiluFunctor, SiluGradFunctor); \
__macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(atan, Atan, AtanFunctor, AtanGradFunctor); \
__macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
......
......@@ -27,6 +27,7 @@ __deprecated_func_name__ = {
__activations_noattr__ = [
'sigmoid',
'silu',
'logsigmoid',
'tanh_shrink',
'softplus',
......@@ -100,6 +101,20 @@ Examples:
""")
add_sample_code(globals()["silu"], r"""
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
out = F.silu(x)
print(out)
# [ 0.7310586 1.7615942 2.8577224, 3.9280552 ]
""")
add_sample_code(globals()["logsigmoid"], r"""
Examples:
.. code-block:: python
......
......@@ -119,6 +119,72 @@ class TestSigmoid(TestActivation):
self.check_grad(['X'], 'Out', max_relative_error=0.01)
class TestSilu(TestActivation):
def setUp(self):
self.op_type = "silu"
self.init_dtype()
np.random.seed(1024)
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = x / (np.exp(-x) + 1)
self.inputs = {'X': x}
self.outputs = {'Out': out}
def init_dtype(self):
self.dtype = np.float32
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
class TestSiluAPI(unittest.TestCase):
# test paddle.nn.Silu, paddle.nn.functional.silu
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [11, 17]).astype('float32')
self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', [11, 17])
out1 = F.silu(x)
m = paddle.nn.Silu()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = self.x_np / (1 + np.exp(-self.x_np))
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.silu(x)
m = paddle.nn.Silu()
out2 = m(x)
out_ref = self.x_np / (1 + np.exp(-self.x_np))
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.silu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(
name='x_int32', shape=[11, 17], dtype='int32')
self.assertRaises(TypeError, F.silu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[11, 17], dtype='float16')
F.silu(x_fp16)
class TestLogSigmoid(TestActivation):
def setUp(self):
self.op_type = "logsigmoid"
......@@ -2629,6 +2695,7 @@ def create_test_act_fp16_class(parent,
create_test_act_fp16_class(TestActivation)
create_test_act_fp16_class(TestSigmoid)
create_test_act_fp16_class(TestSilu)
create_test_act_fp16_class(TestLogSigmoid)
create_test_act_fp16_class(TestTanh)
create_test_act_fp16_class(TestTanhshrink)
......
......@@ -55,6 +55,7 @@ from .layer.activation import PReLU #DEFINE_ALIAS
from .layer.activation import ReLU #DEFINE_ALIAS
from .layer.activation import ReLU6 #DEFINE_ALIAS
from .layer.activation import SELU #DEFINE_ALIAS
from .layer.activation import Silu #DEFINE_ALIAS
from .layer.activation import LeakyReLU #DEFINE_ALIAS
from .layer.activation import Sigmoid #DEFINE_ALIAS
from .layer.activation import Hardsigmoid #DEFINE_ALIAS
......
......@@ -46,6 +46,7 @@ from .activation import relu_ #DEFINE_ALIAS
from .activation import relu6 #DEFINE_ALIAS
from .activation import selu #DEFINE_ALIAS
from .activation import sigmoid #DEFINE_ALIAS
from .activation import silu #DEFINE_ALIAS
# from .activation import soft_relu #DEFINE_ALIAS
from .activation import softmax #DEFINE_ALIAS
from .activation import softmax_ #DEFINE_ALIAS
......
......@@ -49,6 +49,7 @@ __all__ = [
'softshrink',
'softsign',
'sigmoid',
'silu'
'swish',
'tanh',
'tanh_',
......@@ -761,6 +762,39 @@ def selu(x,
return out
def silu(x, name=None):
"""
silu activation.
.. math:
silu(x) = \frac{x}{1 + e^{-x}}
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
out = F.silu(x) # [ 0.731059, 1.761594, 2.857722, 3.928055 ]
"""
if in_dygraph_mode():
return core.ops.silu(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'silu')
helper = LayerHelper("silu", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(type='silu', inputs={'X': x}, outputs={'Out': out})
return out
def softmax(x, axis=-1, dtype=None, name=None):
r"""
This operator implements the softmax layer. The calculation process is as follows:
......
......@@ -27,6 +27,7 @@ __all__ = [
'SELU',
'LeakyReLU',
'Sigmoid',
'Silu',
'Hardsigmoid',
'Softmax',
'Softplus',
......@@ -919,6 +920,44 @@ class ThresholdedReLU(layers.Layer):
return 'threshold={}{}'.format(self._threshold, name_str)
class Silu(layers.Layer):
"""
Silu Activation.
.. math::
Silu(x) = \frac{x}{1 + e^{-x}}
Parameters:
x (Tensor): The input Tensor with data type float32, or float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
m = paddle.nn.Silu()
out = m(x) # [ 0.731059, 1.761594, 2.857722, 3.928055 ]
"""
def __init__(self, name=None):
super(Silu, self).__init__()
self._name = name
def forward(self, x):
return F.silu(x, self._name)
def extra_repr(self):
name_str = 'name={}'.format(self._name) if self._name else ''
return name_str
class LogSigmoid(layers.Layer):
r"""
LogSigmoid Activation.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册