softmax_cn.rst 4.6 KB
Newer Older
S
swtkiwi 已提交
1 2 3 4
.. _cn_api_nn_cn_softmax:

softmax
-------------------------------
5
.. py:function:: paddle.nn.functional.softmax(x, axis=-1, dtype=None, name=None)
6

S
swtkiwi 已提交
7

8
该OP实现了softmax层。OP的计算过程如下:
S
swtkiwi 已提交
9

10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
步骤1:输入 ``x`` 的 ``axis`` 维会被置换到最后一维;

步骤2:将输入 ``x`` 在逻辑上变换为二维矩阵。二维矩阵第一维(列长度)是输入除最后一维之外的其他维度值的乘积,第二维(行长度)和输入 ``axis`` 维的长度相同;对于矩阵的每一行,softmax操作对其进行重新缩放,使得该行的每个元素在 \[0,1\] 范围内,并且总和为1;

步骤3:softmax操作执行完成后,执行步骤1和步骤2的逆运算,将二维矩阵恢复至和输入 ``x`` 相同的维度。

上述步骤2中softmax操作计算过程如下:

    - 对于二维矩阵的每一行,计算K维向量(K是输入第 ``axis`` 维的长度)中指定位置的指数值和全部位置指数值的和。

    - 指定位置指数值与全部位置指数值之和的比值就是softmax操作的输出。

对于二维矩阵中的第i行和第j列有:

.. math::

26
    softmax[i, j] = \frac{\exp(x[i, j])}{\sum_j(exp(x[i, j])}
27 28 29

- 示例1(矩阵一共有三维。axis = -1,表示沿着最后一维(即第三维)做softmax操作)

30
.. code-block:: text
31

32
  # input
33 34 35 36 37 38 39 40 41 42 43 44

    x.shape = [2, 3, 4] 

    x.data = [[[2.0, 3.0, 4.0, 5.0],
               [3.0, 4.0, 5.0, 6.0],
               [7.0, 8.0, 8.0, 9.0]],
              [[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0],
               [6.0, 7.0, 8.0, 9.0]]]

    axis = -1

45
  # output
46 47 48 49 50 51 52 53 54 55 56 57

    out.shape = [2, 3, 4]

    out.data = [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
                 [0.0320586 , 0.08714432, 0.23688282, 0.64391426],
                 [0.07232949, 0.19661193, 0.19661193, 0.53444665]],
                [[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
                 [0.0320586 , 0.08714432, 0.23688282, 0.64391426],
                 [0.0320586 , 0.08714432, 0.23688282, 0.64391426]]]

- 示例2(矩阵一共有三维。axis = 1,表示沿着第二维做softmax操作)

58
.. code-block:: text
59

60
  # input
61 62 63 64 65 66 67 68 69 70 71 72

    x.shape = [2, 3, 4] 

    x.data = [[[2.0, 3.0, 4.0, 5.0],
               [3.0, 4.0, 5.0, 6.0],
               [7.0, 8.0, 8.0, 9.0]],
              [[1.0, 2.0, 3.0, 4.0],
               [5.0, 6.0, 7.0, 8.0],
               [6.0, 7.0, 8.0, 9.0]]]

    axis = 1

73
  # output
74 75 76 77 78 79 80 81 82 83 84 85 86

    out.shape = [2, 3, 4]

    out.data = [[[0.00657326, 0.00657326, 0.01714783, 0.01714783],
                 [0.01786798, 0.01786798, 0.04661262, 0.04661262],
                 [0.97555875, 0.97555875, 0.93623955, 0.93623955]],
                [[0.00490169, 0.00490169, 0.00490169, 0.00490169],
                 [0.26762315, 0.26762315, 0.26762315, 0.26762315],
                 [0.72747516, 0.72747516, 0.72747516, 0.72747516]]] 


参数
::::::::::
87
    - x (Tensor) - 输入的 ``Tensor`` ,数据类型为:float32、float64。
88
    - axis (int, 可选) - 指定对输入 ``x`` 进行运算的轴。``axis`` 的有效范围是[-D, D),D是输入 ``x`` 的维度, ``axis`` 为负值时与 :math:`axis + D` 等价。默认值为-1。
89
    - dtype (str|np.dtype|core.VarDesc.VarType, 可选) - 输入Tensor的数据类型。如果指定了 ``dtype`` ,则输入Tensor的数据类型会在计算前转换到 ``dtype`` 。``dtype``可以用来避免数据溢出。如果 ``dtype`` 为None,则输出Tensor的数据类型和 ``x`` 相同。默认值为None。
90 91 92 93
    - name (str, 可选) - 操作的名称(可选,默认值为None)。更多信息请参见 :ref:`api_guide_Name`。

返回
::::::::::
94
    ``Tensor`` ,形状和 ``x`` 相同,数据类型为 ``dtype`` 或者和 ``x`` 相同。
95 96 97 98 99 100 101 102 103 104

代码示例
::::::::::

.. code-block:: python

    import paddle
    import paddle.nn.functional as F
    import numpy as np

105
    paddle.disable_static()
106 107 108 109 110 111 112

    x = np.array([[[2.0, 3.0, 4.0, 5.0],
                    [3.0, 4.0, 5.0, 6.0],
                    [7.0, 8.0, 8.0, 9.0]],
                    [[1.0, 2.0, 3.0, 4.0],
                    [5.0, 6.0, 7.0, 8.0],
                    [6.0, 7.0, 8.0, 9.0]]], 'float32')
113 114 115 116 117
    x = paddle.to_tensor(x)
    out1 = F.softmax(x)
    out2 = F.softmax(x, dtype='float64')
    # out1's data type is float32; out2's data type is float64
    # out1 and out2's value is as follows:
118 119 120 121 122 123
    # [[[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
    #   [0.0320586 , 0.08714432, 0.23688282, 0.64391426],
    #   [0.07232949, 0.19661193, 0.19661193, 0.53444665]],
    # [[0.0320586 , 0.08714432, 0.23688282, 0.64391426],
    #   [0.0320586 , 0.08714432, 0.23688282, 0.64391426],
    #   [0.0320586 , 0.08714432, 0.23688282, 0.64391426]]]