activation.py 50.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Z
zhiboniu 已提交
15 16 17
from ...fluid.layers import sigmoid  # noqa: F401
from ...tensor.math import tanh  # noqa: F401
from ...tensor.math import tanh_  # noqa: F401
18

19
from ...fluid.dygraph.inplace_utils import inplace_apis_in_dygraph_only
F
Feiyu Chan 已提交
20 21
from ...tensor.manipulation import chunk
from ...tensor.math import multiply
22

23 24
import warnings
from ...fluid.layer_helper import LayerHelper
J
Jiabin Yang 已提交
25
from ...fluid.framework import convert_np_dtype_to_dtype_
26
from ...fluid.framework import _in_legacy_dygraph, in_dygraph_mode
27
from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
28
import paddle
Z
zhiboniu 已提交
29 30
from paddle import _C_ops, in_dynamic_mode
from paddle.framework import core
31

32 33
__all__ = []

34

35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
def celu(x, alpha=1.0, name=None):
    r"""
    celu activation.

    Applies the continuously differentiable exponential linear unit (CELU)
    element-wise:

    .. math::

        celu(x) = max(0, x) + min(0, \alpha * (e^{x/\alpha}-1))

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        alpha (float, optional): The 'alpha' value of the CELU formulation.
            It must not be 0. Default is 1.0.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Raises:
        ZeroDivisionError: If ``alpha`` is 0.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            x = paddle.to_tensor([[-1., 6.], [1., 15.6]])
            out = F.celu(x, alpha=0.2)
            # [[-0.19865242,  6.        ],
            #  [ 1.        , 15.60000038]]
    """
    # alpha is a divisor in the formula, so reject 0 before dispatching in
    # either execution mode.
    if alpha == 0:
        raise ZeroDivisionError("alpha cannot be 0 for celu")

    if not in_dynamic_mode():
        # Static graph: validate the input and append a `celu` op.
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'celu')
        helper = LayerHelper("celu", **locals())
        celu_out = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(
            type='celu',
            inputs={'X': x},
            outputs={'Out': celu_out},
            attrs={'alpha': alpha})
        return celu_out

    # Dynamic mode: call the C++ op directly, attrs given as name/value pairs.
    return _C_ops.celu(x, 'alpha', alpha)


79
def elu(x, alpha=1.0, name=None):
    r"""
    elu activation.

    Applies the exponential linear unit element-wise:

    .. math::

        elu(x)=
            \left\{
                \begin{array}{lcl}
                x,& &\text{if } \ x > 0 \\
                alpha * (e^{x} - 1),& &\text{if } \ x <= 0
                \end{array}
            \right.

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        alpha (float, optional): The 'alpha' value of the ELU formulation. Default is 1.0.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([[-1., 6.], [1., 15.6]])
            out = F.elu(x, alpha=0.2)
            # [[-0.12642411  6.        ]
            #  [ 1.          15.6      ]]
    """
    if in_dynamic_mode():
        # Eager path: dispatch straight to the C++ op (attr name, attr value).
        return _C_ops.elu(x, 'alpha', alpha)

    # Static-graph path: check dtype, then append an `elu` op.
    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'elu')
    helper = LayerHelper("elu", **locals())
    elu_attrs = {'alpha': alpha}
    y = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='elu', inputs={'X': x}, outputs={'Out': y}, attrs=elu_attrs)
    return y


128
@inplace_apis_in_dygraph_only
def elu_(x, alpha=1.0, name=None):
    r"""
    Inplace version of ``elu`` API, the output Tensor will be inplaced with input ``x``.
    Please refer to :ref:`api_nn_cn_elu`.

    Only a non-negative ``alpha`` is accepted here; use ``elu`` otherwise.
    """
    # Negative alpha is rejected for the in-place variant. Note that this is
    # an ``assert`` and is therefore skipped when Python runs with -O.
    assert alpha >= 0., "elu_ only support alpha >= 0, please use elu instead."
    inplaced = _C_ops.elu_(x, 'alpha', alpha)
    return inplaced
136 137


138
def gelu(x, approximate=False, name=None):
    r"""
    gelu activation.

    If ``approximate`` is True, the tanh-based approximation is used:

    .. math::

        gelu(x) = 0.5 * x * (1 + tanh(\sqrt{\frac{2}{\pi}} * (x + 0.044715x^{3})))

    Otherwise, the exact erf-based form is used:

    .. math::

        gelu(x) = 0.5 * x * (1 + erf(\frac{x}{\sqrt{2}}))

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        approximate (bool, optional): Whether to enable approximation. Default is False.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([[-1, 0.5], [1, 1.5]])
            out1 = F.gelu(x)
            # [[-0.15865529,  0.34573123],
            #  [ 0.84134471,  1.39978933]]
            out2 = F.gelu(x, True)
            # [[-0.15880799,  0.34571400],
            #  [ 0.84119201,  1.39957154]]
    """

    if in_dynamic_mode():
        # Eager path: the approximation flag is forwarded as an op attribute.
        return _C_ops.gelu(x, 'approximate', approximate)

    # Static-graph path: check dtype, then append a `gelu` op.
    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'gelu')
    helper = LayerHelper("gelu", **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='gelu',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'approximate': approximate})
    return out


192
def hardshrink(x, threshold=0.5, name=None):
    r"""
    hard shrinkage activation

    .. math::

        hardshrink(x)=
            \left\{
                \begin{array}{rcl}
                x,&  &if \ {x > threshold}  \\
                x,&  &if \ {x < -threshold}   \\
                0,&  &if \ {others} &
                \end{array}
            \right.

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        threshold (float, optional): The value of threshold for hardshrink. Default is 0.5
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([-1, 0.3, 2.5])
            out = F.hardshrink(x) # [-1., 0., 2.5]

    """
    if in_dynamic_mode():
        # The underlying op is named `hard_shrink` (with an underscore).
        return _C_ops.hard_shrink(x, 'threshold', threshold)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'hardshrink')
    helper = LayerHelper('hardshrink', **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='hard_shrink',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'threshold': threshold})
    return out


241
def hardtanh(x, min=-1.0, max=1.0, name=None):
    r"""
    hardtanh activation

    Clips every element of ``x`` into the range ``[min, max]``:

    .. math::

        hardtanh(x)=
            \left\{
                \begin{array}{cll}
                    max,& & \text{if } x > max \\
                    min,& & \text{if } x < min \\
                    x,& & \text{otherwise}
                \end{array}
            \right.

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        min (float, optional): The minimum value of the linear region range. Default is -1.
        max (float, optional): The maximum value of the linear region range. Default is 1.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            x = paddle.to_tensor(np.array([-1.5, 0.3, 2.5]))
            out = F.hardtanh(x) # [-1., 0.3, 1.]
    """
    if not in_dynamic_mode():
        # Static graph: hardtanh is implemented by the `brelu` op, whose
        # bounds are the attributes t_min / t_max.
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'hardtanh')
        helper = LayerHelper('hardtanh', **locals())
        clipped = helper.create_variable_for_type_inference(dtype=x.dtype)
        helper.append_op(
            type='brelu',
            inputs={'X': x},
            outputs={'Out': clipped},
            attrs={'t_min': min, 't_max': max})
        return clipped

    return _C_ops.brelu(x, 't_min', min, 't_max', max)


294
def hardsigmoid(x, slope=0.1666667, offset=0.5, name=None):
    r"""
    hardsigmoid activation.

    A 3-part piecewise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
    which is much faster than sigmoid. The linear part ``slope * x + offset`` is
    clamped into the range [0, 1]:

    .. math::

        hardsigmoid(x)=
            \left\{
                \begin{array}{lcl}
                0, & &\text{if } \ slope * x + offset \leq 0 \\
                1, & &\text{if } \ slope * x + offset \geq 1 \\
                slope * x + offset, & &\text{otherwise}
                \end{array}
            \right.

    With the default ``slope`` and ``offset`` the breakpoints lie at about
    :math:`x = -3` and :math:`x = 3`.

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        slope (float, optional): The slope of hardsigmoid function. Default is 0.1666667.
        offset (float, optional): The offset of hardsigmoid function. Default is 0.5.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([-4., 5., 1.])
            out = F.hardsigmoid(x) # [0., 1., 0.666667]
    """

    if in_dynamic_mode():
        return _C_ops.hard_sigmoid(x, 'slope', slope, 'offset', offset)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'hardsigmoid')

    helper = LayerHelper('hardsigmoid', **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='hard_sigmoid',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'slope': slope,
               'offset': offset})
    return out


def hardswish(x, name=None):
    r"""
    hardswish activation

    hardswish was proposed in MobileNetV3 and is more computationally stable
    and efficient than the swish function. For more details please refer
    to: https://arxiv.org/pdf/1905.02244.pdf

    .. math::

        hardswish(x)=
            \left\{
                \begin{array}{cll}
                0 &, & \text{if } x \leq -3 \\
                x &, & \text{if } x \geq 3 \\
                \frac{x(x+3)}{6} &, & \text{otherwise}
                \end{array}
            \right.

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([-4., 5., 1.])
            out = F.hardswish(x) # [0., 5., 0.666667]
    """
    if not in_dynamic_mode():
        # Static graph: append a `hard_swish` op; no attributes are passed,
        # so the op's own defaults apply.
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'hardswish')
        helper = LayerHelper('hardswish', **locals())
        activated = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(
            type='hard_swish', inputs={'X': x}, outputs={'Out': activated})
        return activated

    return _C_ops.hard_swish(x)


398
def leaky_relu(x, negative_slope=0.01, name=None):
    r"""
    leaky_relu activation

    .. math::

        leaky\_relu(x)=
        \left\{
            \begin{array}{rcl}
                x, & & if \ x >= 0 \\
                negative\_slope * x, & & otherwise \\
            \end{array}
        \right.

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        negative_slope (float, optional): Slope of the activation function at
            :math:`x < 0` . Default is 0.01.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([-2., 0., 1.])
            out = F.leaky_relu(x) # [-0.02, 0., 1.]

    """
    if in_dynamic_mode():
        # The underlying op names the negative slope attribute `alpha`.
        return _C_ops.leaky_relu(x, 'alpha', negative_slope)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'leaky_relu')
    helper = LayerHelper('leaky_relu', **locals())
    op_attrs = {'alpha': negative_slope}
    leaky = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type='leaky_relu', inputs={'X': x}, outputs={'Out': leaky},
        attrs=op_attrs)
    return leaky


446
def prelu(x, weight, data_format="NCHW", name=None):
    """
    prelu activation.

    .. math::

        prelu(x) = max(0, x) + weight * min(0, x)

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the dtype check also accepts float16).
        weight (Tensor): The learnable parameter with data type same as ``x``.
            The weight shape is [1] or [in], where `in` is the input channel of ``x``.
        data_format (str, optional): Data format that specifies the layout of input.
            It may be "NC", "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
            It is only validated (and mapped to "NCHW" or "NHWC") when ``weight``
            has more than one element.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Raises:
        ValueError: If ``weight`` has more than one element and ``data_format``
            is not one of the supported layouts.
        AssertionError: If ``weight`` is not 1-D, or (per-channel case) ``x`` has
            fewer than 2 dims or its channel size does not match ``weight``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            data = np.array([[[[-2.0,  3.0, -4.0,  5.0],
                               [ 3.0, -4.0,  5.0, -6.0],
                               [-7.0, -8.0,  8.0,  9.0]],
                              [[ 1.0, -2.0, -3.0,  4.0],
                               [-5.0,  6.0,  7.0, -8.0],
                               [ 6.0,  7.0,  8.0,  9.0]]]], 'float32')
            x = paddle.to_tensor(data)
            w = paddle.to_tensor(np.array([0.25]).astype('float32'))
            out = F.prelu(x, w)
            # [[[[-0.5 ,  3.  , -1.  ,  5.  ],
            #    [ 3.  , -1.  ,  5.  , -1.5 ],
            #    [-1.75, -2.  ,  8.  ,  9.  ]],
            #   [[ 1.  , -0.5 , -0.75,  4.  ],
            #    [-1.25,  6.  ,  7.  , -2.  ],
            #    [ 6.  ,  7.  ,  8.  ,  9.  ]]]]
    """
    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'prelu')
    check_variable_and_dtype(weight, 'weight',
                             ['float16', 'float32', 'float64'], 'prelu')

    assert len(weight.shape
               ) == 1, "The dim count of weight shape should be 1 in prelu()."

    # A single-element weight is shared by all elements ('all' mode); a longer
    # weight holds one slope per input channel ('channel' mode).
    mode = 'all'
    if weight.shape[0] > 1:

        true_data_format = [
            'NC', 'NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC'
        ]
        if data_format not in true_data_format:
            raise ValueError(
                "data_format must be one of 'NC', 'NCL', 'NCHW', 'NCDHW', "
                "'NLC', 'NHWC', 'NDHWC' but receive {}".format(data_format))

        # Collapse every supported layout to channel-first ('NCHW') or
        # channel-last ('NHWC') based on where 'C' sits.
        data_format = 'NCHW' if data_format[1] == 'C' else 'NHWC'

        assert len(
            x.shape
        ) > 1, "The dim count of x should be equal or larger than 2 in prelu() when weight shape is not [1]."

        #NOTE(GuoxiaWang): support NHWC data format
        if data_format == 'NHWC':
            assert weight.shape[0] == x.shape[
                -1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
        else:
            assert weight.shape[0] == x.shape[
                1], "The weight size should be equal to x input channel in prelu() when weight shape is not [1]."
        mode = 'channel'

    if in_dynamic_mode():
        return _C_ops.prelu(x, weight, 'mode', mode, 'data_format', data_format)

    helper = LayerHelper('prelu', **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type="prelu",
        inputs={"X": x,
                "Alpha": weight},
        outputs={"Out": out},
        attrs={"mode": mode,
               "data_format": data_format})
    return out


537
def relu(x, name=None):
    """
    relu activation.

    Replaces every negative element of ``x`` with zero:

    .. math::

        out = max(x, 0)

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            x = paddle.to_tensor(np.array([-2, 0, 1]).astype('float32'))
            out = F.relu(x) # [0., 0., 1.]
    """
    if in_dygraph_mode():
        # Eager (new dygraph) mode uses the `final_state_` variant of the op.
        return _C_ops.final_state_relu(x)
    elif _in_legacy_dygraph():
        return _C_ops.relu(x)
    else:
        # Static graph: check dtype and append a `relu` op.
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'relu')
        helper = LayerHelper('relu', **locals())
        rectified = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(
            type='relu', inputs={'X': x}, outputs={'Out': rectified})
        return rectified


575
@inplace_apis_in_dygraph_only
def relu_(x, name=None):
    """
    Inplace version of ``relu`` API, the output Tensor will be inplaced with input ``x``.
    Please refer to :ref:`api_nn_cn_relu`.
    """
    # Legacy dygraph keeps the classic in-place op; the `final_state_` variant
    # is only looked up when the eager-mode flag is set.
    if not paddle.fluid.framework._in_eager_mode_:
        return _C_ops.relu_(x)
    return _C_ops.final_state_relu_(x)
584 585


586
def log_sigmoid(x, name=None):
    r"""
    log_sigmoid activation.

    Computes the natural logarithm of the sigmoid of every element:

    .. math::

        log\_sigmoid(x) = log \frac{1}{1 + e^{-x}}

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
            out = F.log_sigmoid(x) # [-0.313262 -0.126928 -0.0485874 -0.0181499]
    """
    if not in_dynamic_mode():
        # Static graph: note the op type is `logsigmoid` (no underscore)
        # while the helper is registered as "log_sigmoid".
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'log_sigmoid')
        helper = LayerHelper("log_sigmoid", **locals())
        log_sig = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(
            type='logsigmoid', inputs={'X': x}, outputs={'Out': log_sig})
        return log_sig

    return _C_ops.logsigmoid(x)
621 622


623
def maxout(x, groups, axis=1, name=None):
    r"""
    maxout activation.

    Assume the input shape is (N, Ci, H, W) and the output shape is
    (N, Co, H, W), where Co = Ci / groups. The operator formula is as follows:

    .. math::

        \begin{array}{l}
        &out_{si+j} = \max_{k} x_{gsi + sk + j} \\
        &g = groups \\
        &s = \frac{input.size}{num\_channels} \\
        &0 \le i < \frac{num\_channels}{groups} \\
        &0 \le j < s \\
        &0 \le k < groups
        \end{array}


    Parameters:
        x (Tensor): The input is 4-D Tensor with shape [N, C, H, W] or [N, H, W, C], the data type
            of input is float32 or float64.
        groups (int): The number of groups the channel dimension is split into;
            the maximum is taken over each group. This must be a factor of the
            number of channels.
        axis (int, optional): The axis along which to perform maxout calculations.
            It should be 1 when data format is NCHW, be -1 or 3 when data format
            is NHWC. If ``axis`` < 0, it works the same way as :math:`axis + D` ,
            where D is the dimensions of ``x`` . ``axis`` only supports 1, 3 or -1.
            Default is 1.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type as ``x`` .

    Raises:
        ValueError: In static-graph mode, if ``axis`` is not 1, 3 or -1.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.rand([1, 2, 3, 4])
            # [[[[0.5002636  0.22272532 0.17402348 0.2874594 ]
            #    [0.95313174 0.6228939  0.7129065  0.7087491 ]
            #    [0.02879342 0.88725346 0.61093384 0.38833922]]
            #   [[0.5231306  0.03807496 0.91661984 0.15602879]
            #    [0.666127   0.616567   0.30741522 0.24044901]
            #    [0.7142536  0.7351477  0.31588817 0.23782359]]]]
            out = F.maxout(x, groups=2)
            # [[[[0.5231306  0.22272532 0.91661984 0.2874594 ]
            #    [0.95313174 0.6228939  0.7129065  0.7087491 ]
            #    [0.7142536  0.88725346 0.61093384 0.38833922]]]]
    """

    if in_dynamic_mode():
        # NOTE(review): axis is forwarded unvalidated here; the static path
        # below validates it and maps -1 to 3. Confirm the C++ op accepts -1.
        return _C_ops.maxout(x, 'groups', groups, 'axis', axis)

    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'maxout')
    if axis not in [1, -1, 3]:
        raise ValueError(
            "Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received "
            "Attr(axis): %s." % str(axis))
    # The static op is given a non-negative axis.
    if axis == -1:
        axis = 3

    helper = LayerHelper('maxout', **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='maxout',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'groups': groups,
               'axis': axis})
    return out


701 702 703 704 705 706
def relu6(x, name=None):
    """
    relu6 activation

    Clamps every element of ``x`` into the range [0, 6]:

    .. math::

        relu6(x) = min(max(0,x), 6)

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            x = paddle.to_tensor(np.array([-1, 0.3, 6.5]))
            out = F.relu6(x) # [0, 0.3, 6]
    """
    # Fixed upper bound passed to the `relu6` op as its 'threshold' attribute.
    threshold = 6.0

    if not in_dynamic_mode():
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'relu6')
        helper = LayerHelper('relu6', **locals())
        capped = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(
            type='relu6',
            inputs={'X': x},
            outputs={'Out': capped},
            attrs={'threshold': threshold})
        return capped

    return _C_ops.relu6(x, 'threshold', threshold)


def selu(x,
         scale=1.0507009873554804934193349852946,
         alpha=1.6732632423543772848170429916717,
         name=None):
    r"""
    selu activation

    Applies the scaled exponential linear unit element-wise:

    .. math::

        selu(x)= scale *
            \left\{
                \begin{array}{lcl}
                x,& &\text{if } \ x > 0 \\
                alpha * e^{x} - alpha,& &\text{if } \ x <= 0
                \end{array}
            \right.

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        scale (float, optional): The value of scale(must be greater than 1.0) for selu. Default is 1.0507009873554804934193349852946
        alpha (float, optional): The value of alpha(must be no less than zero) for selu. Default is 1.6732632423543772848170429916717
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Raises:
        ValueError: If ``scale`` <= 1.0 or ``alpha`` < 0.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            x = paddle.to_tensor(np.array([[0.0, 1.0],[2.0, 3.0]]))
            out = F.selu(x) # [[0, 1.050701],[2.101402, 3.152103]]
    """
    # Argument validation happens before dispatch, in every execution mode.
    if scale <= 1.0:
        raise ValueError(
            "The scale must be greater than 1.0. Received: {}.".format(scale))
    if alpha < 0:
        raise ValueError(
            "The alpha must be no less than zero. Received: {}.".format(alpha))

    if in_dygraph_mode():
        # Eager mode: final-state op takes scale/alpha positionally.
        return _C_ops.final_state_selu(x, scale, alpha)
    elif _in_legacy_dygraph():
        # Legacy dygraph: attributes are passed as name/value pairs.
        return _C_ops.selu(x, 'scale', scale, 'alpha', alpha)
    else:
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'selu')
        helper = LayerHelper('selu', **locals())
        scaled = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(
            type='selu',
            inputs={'X': x},
            outputs={'Out': scaled},
            attrs={'scale': scale, 'alpha': alpha})
        return scaled


M
minghaoBD 已提交
804
def silu(x, name=None):
    r"""
    silu activation

    Multiplies every element of ``x`` by its sigmoid:

    .. math::

        silu(x) = \frac{x}{1 + e^{-x}}

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64
            (the static-graph dtype check also accepts float16).
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
            out = F.silu(x) # [ 0.731059, 1.761594, 2.857722, 3.928055 ]
    """
    if not in_dynamic_mode():
        # Static graph: check dtype and append a `silu` op.
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'silu')
        helper = LayerHelper("silu", **locals())
        gated = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(type='silu', inputs={'X': x}, outputs={'Out': gated})
        return gated

    return _C_ops.silu(x)


840
def softmax(x, axis=-1, dtype=None, name=None):
    r"""
    This operator implements the softmax layer. The calculation process is as follows:

    1. The dimension :attr:`axis` of ``x`` will be permuted to the last.

    2. Then ``x`` will be logically flattened to a 2-D matrix. The matrix's second
    dimension(row length) is the same as the dimension :attr:`axis` of ``x``,
    and the first dimension(column length) is the product of all other dimensions
    of ``x``. For each row of the matrix, the softmax operator squashes the
    K-dimensional(K is the width of the matrix, which is also the size of ``x``'s
    dimension :attr:`axis`) vector of arbitrary real values to a K-dimensional
    vector of real values in the range [0, 1] that add up to 1.

    3. After the softmax operation is completed, the inverse operations of steps 1 and 2
    are performed to restore the two-dimensional matrix to the same dimension as the ``x`` .

    It computes the exponential of every element of the K-dimensional vector
    and divides it by the sum of the exponentials of all elements of that same
    vector; the resulting ratios are the output of the softmax operator.

    For each row :math:`i` and each column :math:`j` in the matrix, we have:

    .. math::

        softmax[i, j] = \frac{\exp(x[i, j])}{\sum_j(\exp(x[i, j]))}

    Example:

    .. code-block:: text

        Case 1:
          Input:
            x.shape = [2, 3, 4]
            x.data = [[[2.0, 3.0, 4.0, 5.0],
                       [3.0, 4.0, 5.0, 6.0],
                       [7.0, 8.0, 8.0, 9.0]],
                      [[1.0, 2.0, 3.0, 4.0],
                       [5.0, 6.0, 7.0, 8.0],
                       [6.0, 7.0, 8.0, 9.0]]]

          Attrs:
            axis = -1

          Output:
            out.shape = [2, 3, 4]
            out.data = [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
                         [0.0320586 , 0.08714432, 0.23688282, 0.64391426],
                         [0.07232949, 0.19661193, 0.19661193, 0.53444665]],
                        [[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
                         [0.0320586 , 0.08714432, 0.23688282, 0.64391426],
                         [0.0320586 , 0.08714432, 0.23688282, 0.64391426]]]

        Case 2:
          Input:
            x.shape = [2, 3, 4]
            x.data = [[[2.0, 3.0, 4.0, 5.0],
                       [3.0, 4.0, 5.0, 6.0],
                       [7.0, 8.0, 8.0, 9.0]],
                      [[1.0, 2.0, 3.0, 4.0],
                       [5.0, 6.0, 7.0, 8.0],
                       [6.0, 7.0, 8.0, 9.0]]]
          Attrs:
            axis = 1

          Output:
            out.shape = [2, 3, 4]
            out.data = [[[0.00657326, 0.00657326, 0.01714783, 0.01714783],
                         [0.01786798, 0.01786798, 0.04661262, 0.04661262],
                         [0.97555875, 0.97555875, 0.93623955, 0.93623955]],
                        [[0.00490169, 0.00490169, 0.00490169, 0.00490169],
                         [0.26762315, 0.26762315, 0.26762315, 0.26762315],
                         [0.72747516, 0.72747516, 0.72747516, 0.72747516]]]

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        axis (int, optional): The axis along which to perform softmax
            calculations. It should be in range [-D, D), where D is the
            dimensions of ``x`` . If ``axis`` < 0, it works the same way as
            :math:`axis + D` . Default is -1.
        dtype (str|np.dtype|core.VarDesc.VarType, optional): The desired data
            type of the output tensor. If dtype is specified, ``x`` is casted
            to ``dtype`` before the operation is performed. Supported dtype:
            float32, float64. If ``dtype`` is None, the output Tensor has the
            same dtype as x. Default is None.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same shape and data type (use ``dtype`` if it is
        specified) as x.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            x = np.array([[[2.0, 3.0, 4.0, 5.0],
                        [3.0, 4.0, 5.0, 6.0],
                        [7.0, 8.0, 8.0, 9.0]],
                        [[1.0, 2.0, 3.0, 4.0],
                        [5.0, 6.0, 7.0, 8.0],
                        [6.0, 7.0, 8.0, 9.0]]], 'float32')
            x = paddle.to_tensor(x)
            out1 = F.softmax(x)
            out2 = F.softmax(x, dtype='float64')
            # out1's data type is float32; out2's data type is float64
            # out1 and out2's value is as follows:
            # [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
            #   [0.0320586 , 0.08714432, 0.23688282, 0.64391426],
            #   [0.07232949, 0.19661193, 0.19661193, 0.53444665]],
            # [[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
            #   [0.0320586 , 0.08714432, 0.23688282, 0.64391426],
            #   [0.0320586 , 0.08714432, 0.23688282, 0.64391426]]]
    """

    # Normalize str / numpy dtypes to the framework's VarType enum.
    if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
        dtype = convert_np_dtype_to_dtype_(dtype)
    use_cudnn = True

    if in_dygraph_mode():
        # Cast first (if requested) so the softmax runs in the target dtype.
        outs_cast = x if dtype is None \
            else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
        return _C_ops.final_state_softmax(outs_cast, axis)

    if _in_legacy_dygraph():
        outs_cast = x if dtype is None \
            else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
        return _C_ops.softmax(outs_cast, 'axis', axis, 'use_cudnn', use_cudnn)

    # Static graph: only float32/float64 are accepted as an explicit target
    # dtype; without one, the input itself may also be float16.
    if dtype is None:
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'softmax')
    else:
        check_dtype(dtype, 'dtype', ['float32', 'float64'], 'softmax',
                    'If dtype is not None, it only support float32 or float64.')

    # NOTE: locals() (x, axis, dtype, name, use_cudnn) is forwarded to
    # LayerHelper; avoid binding new locals above this line.
    helper = LayerHelper("softmax", **locals())
    outs_cast = x
    if dtype is not None:
        outs_cast = helper.create_variable_for_type_inference(dtype)
        helper.append_op(
            type='cast',
            inputs={'X': x},
            outputs={'Out': outs_cast},
            attrs={'in_dtype': x.dtype,
                   'out_dtype': dtype})

    outs_softmax = helper.create_variable_for_type_inference(outs_cast.dtype)
    helper.append_op(
        type='softmax',
        inputs={'X': outs_cast},
        outputs={'Out': outs_softmax},
        attrs={'axis': axis,
               'use_cudnn': use_cudnn})

    return outs_softmax
997 998


999
@inplace_apis_in_dygraph_only
def softmax_(x, axis=-1, dtype=None, name=None):
    r"""
    Inplace version of ``softmax`` API, the output Tensor will be inplaced with input ``x``.
    Please refer to :ref:`api_nn_cn_softmax`.

    If ``dtype`` is given, ``x`` is first cast to ``dtype`` (as ``softmax``
    does) and the in-place softmax is applied to that cast tensor, which is
    returned; ``x`` itself is left unchanged in that case.
    """
    # Normalize str / numpy dtypes to the framework's VarType enum.
    if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
        dtype = convert_np_dtype_to_dtype_(dtype)
    use_cudnn = True
    # Previously ``dtype`` was converted but never used, so it was silently
    # ignored. Honor it the same way ``softmax`` does.
    outs_cast = x if dtype is None \
        else _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
    return _C_ops.softmax_(outs_cast, 'axis', axis, 'use_cudnn', use_cudnn)
1009 1010


1011
def softplus(x, beta=1, threshold=20, name=None):
    r"""
    softplus activation

    .. math::

        softplus(x) = \frac{1}{beta} * \log(1 + e^{beta * x}) \\
        \text{For numerical stability, the implementation reverts to the linear function when: beta * x > threshold.}

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        beta (float, optional): The value of beta for softplus. Default is 1
        threshold (float, optional): The value of threshold for softplus. Default is 20
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
            out = F.softplus(x) # [0.513015, 0.598139, 0.744397, 0.854355]
    """
    if not in_dynamic_mode():
        # Static graph. locals() (x, beta, threshold, name) goes to LayerHelper,
        # so nothing new is bound before that call.
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'softplus')
        helper = LayerHelper('softplus', **locals())
        result = helper.create_variable_for_type_inference(x.dtype)
        op_attrs = {'beta': beta, 'threshold': threshold}
        helper.append_op(
            type='softplus',
            inputs={'X': x},
            outputs={'Out': result},
            attrs=op_attrs)
        return result

    # Dynamic mode: attributes are passed as alternating name/value args.
    return _C_ops.softplus(x, 'beta', beta, 'threshold', threshold)


def softshrink(x, threshold=0.5, name=None):
    r"""
    softshrink activation

    .. math::

        softshrink(x)=
            \left\{
                \begin{array}{rcl}
                x - threshold,& & \text{if } x > threshold \\
                x + threshold,& & \text{if } x < -threshold \\
                0,& &  \text{otherwise}
            \end{array}
            \right.

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        threshold (float, optional): The value of threshold (must be no less than zero) for softshrink. Default is 0.5
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Raises:
        ValueError: If ``threshold`` is less than zero.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            x = paddle.to_tensor(np.array([-0.9, -0.2, 0.1, 0.8]))
            out = F.softshrink(x) # [-0.4, 0, 0, 0.3]
    """
    # Validate eagerly so both dynamic and static modes reject bad thresholds.
    if threshold < 0:
        raise ValueError(
            "The threshold must be no less than zero. Received: {}.".format(
                threshold))

    if in_dynamic_mode():
        # The underlying op names this attribute 'lambda'.
        return _C_ops.softshrink(x, 'lambda', threshold)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'softshrink')
    # NOTE: locals() (x, threshold, name) is forwarded to LayerHelper.
    helper = LayerHelper('softshrink', **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='softshrink',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'lambda': threshold})
    return out


def softsign(x, name=None):
    r"""
    softsign activation

    Scales each element by one over one plus its absolute value:

    .. math::

        softsign(x) = \frac{x}{1 + |x|}

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
            out = F.softsign(x) # [-0.285714, -0.166667, 0.0909091, 0.230769]
    """
    if not in_dynamic_mode():
        # Static graph: locals() (x, name) is forwarded to LayerHelper.
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'softsign')
        helper = LayerHelper('softsign', **locals())
        result = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(
            type='softsign', inputs={'X': x}, outputs={'Out': result})
        return result

    return _C_ops.softsign(x)


1147
def swish(x, name=None):
    r"""
    swish activation.

    This functional form uses a fixed ``beta`` of 1.0:

    .. math::

        swish(x) = \frac{x}{1 + e^{-x}}

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            x = paddle.to_tensor(np.array([-2., 0., 1.]))
            out = F.swish(x) # [-0.238406, 0., 0.731059]
    """
    if not in_dynamic_mode():
        # Static graph: locals() (x, name) is forwarded to LayerHelper.
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'swish')
        helper = LayerHelper('swish', **locals())
        result = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(
            type='swish',
            inputs={'X': x},
            outputs={'Out': result},
            attrs={'beta': 1.0})
        return result

    # Both modes pin beta to 1.0.
    return _C_ops.swish(x, 'beta', 1.0)


1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214
def mish(x, name=None):
    r"""
    mish activation.

    .. math::

        softplus(x) = \begin{cases}
                x, \text{if } x > \text{threshold} \\
                \ln(1 + e^{x}),  \text{otherwise}
            \end{cases}

        mish(x) = x * \tanh(softplus(x))

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([-5., 0., 5.])
            out = F.mish(x) # [-0.03357624, 0., 4.99955208]
    """
    if not in_dynamic_mode():
        # Static graph: locals() (x, name) is forwarded to LayerHelper; no
        # attributes are set here, so the op's own defaults apply.
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'mish')
        helper = LayerHelper('mish', **locals())
        result = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(type='mish', inputs={'X': x}, outputs={'Out': result})
        return result

    return _C_ops.mish(x)


1228 1229 1230 1231 1232 1233
def tanhshrink(x, name=None):
    """
    tanhshrink activation

    Subtracts the hyperbolic tangent of each element from the element itself:

    .. math::

        tanhshrink(x) = x - tanh(x)

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
            out = F.tanhshrink(x) # [-0.020051, -0.00262468, 0.000332005, 0.00868739]
    """
    if not in_dynamic_mode():
        # Static graph: the registered op type is 'tanh_shrink' (with an
        # underscore); locals() (x, name) is forwarded to LayerHelper.
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'tanhshrink')
        helper = LayerHelper('tanh_shrink', **locals())
        result = helper.create_variable_for_type_inference(x.dtype)
        helper.append_op(
            type='tanh_shrink', inputs={'X': x}, outputs={'Out': result})
        return result

    return _C_ops.tanh_shrink(x)


1265
def thresholded_relu(x, threshold=1.0, name=None):
    r"""
    thresholded relu activation.

    Keeps elements strictly greater than ``threshold`` and zeroes the rest:

    .. math::

        thresholded\_relu(x) =
            \left\{
                \begin{array}{rl}
                x,& \text{if } \ x > threshold \\
                0,& \text{otherwise}
                \end{array}
            \right.


    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        threshold (float, optional): The value of threshold for thresholded_relu. Default is 1.0
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F
            import numpy as np

            x = paddle.to_tensor(np.array([2., 0., 1.]))
            out = F.thresholded_relu(x) # [2., 0., 0.]
    """
    if not in_dynamic_mode():
        # Static graph: locals() (x, threshold, name) is forwarded to
        # LayerHelper, so nothing new is bound before that call.
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'thresholded_relu')
        helper = LayerHelper('thresholded_relu', **locals())
        result = helper.create_variable_for_type_inference(x.dtype)
        op_attrs = {'threshold': threshold}
        helper.append_op(
            type='thresholded_relu',
            inputs={'X': x},
            outputs={'Out': result},
            attrs=op_attrs)
        return result

    return _C_ops.thresholded_relu(x, 'threshold', threshold)


1315
def log_softmax(x, axis=-1, dtype=None, name=None):
    r"""
    This operator implements the log_softmax layer. The calculation process is
    as follows:

    .. math::

        \begin{aligned}
        log\_softmax[i, j] &= log(softmax(x)) \\
        &= log(\frac{\exp(x[i, j])}{\sum_j(\exp(x[i, j]))})
        \end{aligned}

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        axis (int, optional): The axis along which to perform log_softmax
            calculations. It should be in range [-D, D), where D is the
            dimensions of ``x`` . If ``axis`` < 0, it works the same way as
            :math:`axis + D` . Default is -1.
        dtype (str|np.dtype|core.VarDesc.VarType, optional): The desired data
            type of the output tensor. If dtype is specified, ``x`` is casted
            to ``dtype`` before the operation is performed. This is useful for
            preventing data type overflows. Supported dtype: float32, float64.
            If ``dtype`` is None, the output Tensor has the same dtype as x.
            Default is None.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same shape and data type (use ``dtype`` if it is
        specified) as x.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = [[[-2.0, 3.0, -4.0, 5.0],
                  [3.0, -4.0, 5.0, -6.0],
                  [-7.0, -8.0, 8.0, 9.0]],
                 [[1.0, -2.0, -3.0, 4.0],
                  [-5.0, 6.0, 7.0, -8.0],
                  [6.0, 7.0, 8.0, 9.0]]]
            x = paddle.to_tensor(x)
            out1 = F.log_softmax(x)
            out2 = F.log_softmax(x, dtype='float64')
            # out1's data type is float32; out2's data type is float64
            # out1 and out2's value is as follows:
            # [[[ -7.1278396   -2.1278396   -9.127839    -0.12783948]
            #   [ -2.1270514   -9.127051    -0.12705144 -11.127051  ]
            #   [-16.313261   -17.313261    -1.3132617   -0.31326184]]
            #  [[ -3.0518122   -6.051812    -7.051812    -0.051812  ]
            #   [-12.313267    -1.3132664   -0.3132665  -15.313267  ]
            #   [ -3.4401896   -2.4401896   -1.4401896   -0.44018966]]]
    """

    # Normalize str / numpy dtypes to the framework's VarType enum.
    if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
        dtype = convert_np_dtype_to_dtype_(dtype)

    if in_dynamic_mode():
        # Cast first (if requested) so the op runs in the target dtype.
        if dtype is not None:
            x = _C_ops.cast(x, 'in_dtype', x.dtype, 'out_dtype', dtype)
        return _C_ops.log_softmax(x, 'axis', axis)

    # Static graph: only float32/float64 are accepted as an explicit target
    # dtype; without one, the input itself may also be float16.
    if dtype is None:
        check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                                 'log_softmax')
    else:
        check_dtype(dtype, 'dtype', ['float32', 'float64'], 'log_softmax',
                    'If dtype is not None, it only support float32 or float64.')

    # NOTE: locals() (x, axis, dtype, name) is forwarded to LayerHelper; avoid
    # binding new locals above this line.
    helper = LayerHelper("log_softmax", **locals())
    out_cast = x
    if dtype is not None:
        out_cast = helper.create_variable_for_type_inference(dtype)
        helper.append_op(
            type='cast',
            inputs={'X': x},
            outputs={'Out': out_cast},
            attrs={'in_dtype': x.dtype,
                   'out_dtype': dtype})

    out = helper.create_variable_for_type_inference(out_cast.dtype)
    helper.append_op(
        type='log_softmax',
        inputs={'X': out_cast},
        outputs={'Out': out},
        attrs={'axis': axis})

    return out
F
Feiyu Chan 已提交
1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451


def glu(x, axis=-1, name=None):
    r"""
    The gated linear unit. The input is evenly split into 2 parts along a
    given axis. The first part is used as the content, and the second part is
    passed through a sigmoid function then used as the gate. The output is an
    elementwise multiplication of the content and the gate.

    .. math::

        \mathrm{GLU}(a, b) = a \otimes \sigma(b)

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        axis (int, optional): The axis along which to split the input tensor.
            It should be in range [-D, D), where D is the dimensions of ``x`` .
            If ``axis`` < 0, it works the same way as :math:`axis + D` .
            The size of ``x`` along ``axis`` is expected to be even so that it
            can be split into two equal halves. Default is -1.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type as x. The size of the given axis is
        halved.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.nn import functional as F

            x = paddle.to_tensor(
                [[-0.22014759, -1.76358426,  0.80566144,  0.04241343],
                 [-1.94900405, -1.89956081,  0.17134808, -1.11280477]]
            )
            print(F.glu(x).numpy())
            # array([[-0.15216254, -0.9004892 ],
            #        [-1.0577879 , -0.46985325]], dtype=float32)

    """
    check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
                             "glu")
    # First half is the content ``a``; second half ``b`` becomes the gate.
    a, b = chunk(x, 2, axis=axis, name=name)
    gate = sigmoid(b, name=name)
    # Use the module-level ``multiply`` import (from ...tensor.math), which was
    # previously imported but unused.
    out = multiply(a, gate, name=name)
    return out
1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511


def gumbel_softmax(x, temperature=1.0, hard=False, axis=-1, name=None):
    r"""
    Samples from the Gumbel-Softmax distribution and optionally discretizes.
    temperature is denoted by t. The calculation process is as follows:

    First, generate gumbel noise:

    .. math::

        G_i = -log(-log(U_i)), U_i \sim U(0,1)

    Second, add noise to ``x``:

    .. math::

        v = [x_1 + G_1,...,x_n + G_n]

    Finally, calculate gumbel_softmax and generate samples:

    .. math::

        gumbel\_softmax(v_i)=\frac{e^{v_i/t}}{\sum_{j=1}^n{e^{v_j/t}}},i=1,2,3...n

    Parameters:
        x (Tensor): An N-D Tensor, the first N - 1 dimensions index into a batch
            of independent distributions and the last dimension represents
            a vector of probabilities with datatype float32, float64.
        temperature (float, optional): positive scalar temperature. It divides
            the noisy logits in the formula above, so it must not be zero.
            Default is 1.0.
        hard (bool, optional): if True, the returned samples will be discretized as
            one-hot vectors, but will be differentiated as if it is the soft sample
            in autograd. Default is False.
        axis (int, optional): The axis along which the softmax is computed.
            Default is -1.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Sampled tensor of same shape as ``x`` from the Gumbel-Softmax distribution.
        If ``hard = True``, the returned samples will be one-hot, otherwise they will be
        probability distributions that sum to 1 across ``axis``.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            logits = paddle.randn([4, 6])
            temperature = 0.01
            out = F.gumbel_softmax(logits, temperature)
            print(out)
            # out's value is as follows:
            # [[0.00000001, 1.        , 0.00000000, 0.00000000, 0.00000006, 0.00000000],
            # [0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 1.        ],
            # [0.00000062, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.99999940],
            # [0.00000000, 0.00000000, 0.00000000, 0.00001258, 0.99998736, 0.00000000]]

    """
    if in_dynamic_mode():
        # Attributes are passed as alternating name/value arguments.
        return _C_ops.gumbel_softmax(x, 'temperature', temperature, 'hard',
                                     hard, 'axis', axis)

    # NOTE: locals() (x, temperature, hard, axis, name) is forwarded to
    # LayerHelper.
    helper = LayerHelper("gumbel_softmax", **locals())
    # Unlike most activations here, float16 is not accepted in static mode.
    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'gumbel_softmax')
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='gumbel_softmax',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'temperature': temperature,
               'hard': hard,
               'axis': axis})
    return out