stat.py 11.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: define statistical functions of a tensor  
16 17
from ..fluid.layers import reduce_mean  #DEFINE_ALIAS

18
__all__ = ['mean', 'reduce_mean', 'std', 'var', 'numel']
19

20
import numpy as np
21
from ..fluid.framework import Variable
22
from ..fluid.layer_helper import LayerHelper
23
from ..fluid.framework import core, in_dygraph_mode
24 25
from ..fluid import layers
from .search import where
L
Liufang Sang 已提交
26
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
27 28 29 30 31 32 33 34
import paddle


def mean(x, axis=None, keepdim=False, name=None):
    """
    Computes the mean of the input tensor's elements along ``axis``.

    Args:
        x (Tensor): The input Tensor with data type float32, float64.
        axis (int|list|tuple, optional): The axis along which to perform mean
            calculations. ``axis`` should be int, list(int) or tuple(int). If
            ``axis`` is a list/tuple of dimension(s), mean is calculated along
            all element(s) of ``axis`` . ``axis`` or element(s) of ``axis``
            should be in range [-D, D), where D is the dimensions of ``x`` . If
            ``axis`` or element(s) of ``axis`` is less than 0, it works the
            same way as :math:`axis + D` . If ``axis`` is None or empty, mean
            is calculated along all elements of ``x``. Default is None.
        keepdim (bool, optional): Whether to reserve the reduced dimension(s)
            in the output Tensor. If ``keepdim`` is True, the dimensions of
            the output Tensor are the same as ``x`` except in the reduced
            dimensions (they are of size 1 in this case). Otherwise, the shape
            of the output Tensor is squeezed in ``axis`` . Default is False.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, the mean of ``x`` along ``axis``, with the same data type as
        ``x``.

    Raises:
        TypeError: In static graph mode only, if ``x`` is not a float32/float64
            Variable, or if ``axis`` (or any of its elements) is not an int.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            paddle.disable_static()

            x = np.array([[[1, 2, 3, 4],
                           [5, 6, 7, 8],
                           [9, 10, 11, 12]],
                          [[13, 14, 15, 16],
                           [17, 18, 19, 20],
                           [21, 22, 23, 24]]], 'float32')
            x = paddle.to_variable(x)
            out1 = paddle.mean(x)
            # [12.5]
            out2 = paddle.mean(x, axis=-1)
            # [[ 2.5  6.5 10.5]
            #  [14.5 18.5 22.5]]
            out3 = paddle.mean(x, axis=-1, keepdim=True)
            # [[[ 2.5]
            #   [ 6.5]
            #   [10.5]]
            #  [[14.5]
            #   [18.5]
            #   [22.5]]]
            out4 = paddle.mean(x, axis=[0, 2])
            # [ 8.5 12.5 16.5]
    """

    # Normalize a single int axis to a list so its length can be inspected.
    if isinstance(axis, int):
        axis = [axis]
    # Reduce over every dimension when no axis is given, when an empty
    # sequence is given, or when as many axes are listed as ``x`` has dims.
    reduce_all = True if axis is None \
        or len(axis) == 0 \
        or len(axis) == len(x.shape) else False
    # The reduce_mean op still receives a concrete 'dim' attribute; [0] acts
    # as a placeholder here (presumably ignored when reduce_all is True).
    if axis is None or len(axis) == 0:
        axis = [0]

    if in_dygraph_mode():
        return core.ops.reduce_mean(x, 'dim', axis, 'keep_dim', keepdim,
                                    'reduce_all', reduce_all)

    # Input validation is only performed when building a static graph.
    check_variable_and_dtype(x, 'x/input', ['float32', 'float64'],
                             'mean/reduce_mean')
    check_type(axis, 'axis/dim', (int, list, tuple), 'mean/reduce_mean')
    if isinstance(axis, (list, tuple)):
        for item in axis:
            check_type(item, 'elements of axis/dim', (int), 'mean/reduce_mean')

    # NOTE: ``**locals()`` forwards every local name defined above to
    # LayerHelper, so avoid introducing new locals before this call.
    helper = LayerHelper('mean', **locals())
    attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all}
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='reduce_mean', inputs={'X': x}, outputs={'Out': out}, attrs=attrs)
    return out


def var(input, axis=None, keepdim=False, unbiased=True, out=None, name=None):
    """
    :alias_main: paddle.var
    :alias: paddle.var,paddle.tensor.var,paddle.tensor.stat.var

    Computes the variance of the input Variable's elements along the specified
    axis.

    Args:
        input (Variable): The input Variable to be computed variance, with data
            type float32 and float64 supported.
        axis (int|list|tuple, optional): The axis along which the variance is
            computed. If `None` or empty, compute the variance over all
            elements of :attr:`input` and return a Variable with a single
            element, otherwise it must be in the range
            :math:`[-rank(input), rank(input))`. If :math:`axis[i] < 0`, the
            axis to compute is :math:`rank(input) + axis[i]`.
        keepdim (bool, optional): Whether to reserve the reduced dimensions in
            the output Variable. The dimensions in :attr:`axis` will be squeezed
            and the result Variable will have :attr:`len(axis)` fewer dimensions
            than the :attr:`input` unless :attr:`keepdim` is true, default False.
        unbiased (bool, optional): Whether to compute variance via the unbiased
            estimator, in which the divisor used in the computation is
            :math:`N - 1`, where :math:`N` represents the number of elements
            along :attr:`axis`, otherwise the divisor is :math:`N`. When N <= 1
            the correction factor is set to 0, so the result is 0. Default True.
        out (Variable, optional): Alternate output Variable to store the result
            variance. Default None.
        name (str, optional): The name for this layer. Normally there is no
            need for user to set this property.  For more information, please
            refer to :ref:`api_guide_Name`. Default None.

    Returns:
        Variable: The result variance with the same dtype as :attr:`input`.
            If :attr:`out = None`, returns a new Variable containing the
            variance, otherwise returns a reference to the output Variable.

    Raises:
        ValueError: If the dtype of :attr:`input` is not float32 or float64.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle
            import paddle.fluid.dygraph as dg

            a = np.array([[1.0, 2.0], [3.0, 4.0]]).astype("float32")
            with dg.guard():
                data = dg.to_variable(a)
                variance = paddle.var(data, axis=[1])
                print(variance.numpy())
                # [0.5 0.5]
    """
    dtype = convert_dtype(input.dtype)
    if dtype not in ["float32", "float64"]:
        raise ValueError("Layer tensor.var() only supports floating-point "
                         "dtypes, but received {}.".format(dtype))

    # ``axis`` is documented to accept an int, but the normalization below
    # iterates over it; wrap ints and convert tuples so both work and so the
    # same list is handed to reduce_mean.
    if isinstance(axis, int):
        axis = [axis]
    elif isinstance(axis, tuple):
        axis = list(axis)

    rank = len(input.shape)
    # None or an empty list means "reduce over every dimension".
    axes = axis if axis is not None and axis != [] else range(rank)
    axes = [e if e >= 0 else e + rank for e in axes]
    # In static mode dimensions may be unknown at build time, so the runtime
    # shape tensor is used instead of the static ``input.shape``.
    inp_shape = input.shape if in_dygraph_mode() else layers.shape(input)
    mean_val = layers.reduce_mean(input, dim=axis, keep_dim=True, name=name)
    tmp = layers.reduce_mean(
        (input - mean_val)**2, dim=axis, keep_dim=keepdim, name=name)

    if unbiased:
        # N = number of elements reduced over.
        n = 1
        for i in axes:
            n *= inp_shape[i]
        if not in_dygraph_mode():
            n = layers.cast(n, dtype)
            zero_const = layers.fill_constant(shape=[1], dtype=dtype, value=0.0)
            factor = where(n > 1.0, n / (n - 1.0), zero_const)
        else:
            factor = n / (n - 1.0) if n > 1.0 else 0.0
        tmp *= factor

    # Compare against None explicitly: truth-testing a tensor would
    # evaluate its value rather than check whether one was supplied.
    if out is not None:
        layers.assign(input=tmp, output=out)
        return out
    return tmp


def std(input, axis=None, keepdim=False, unbiased=True, out=None, name=None):
    """
    :alias_main: paddle.std
    :alias: paddle.std,paddle.tensor.std,paddle.tensor.stat.std

    Computes the standard-deviation of the input Variable's elements along the
    specified axis.

    Args:
        input (Variable): The input Variable to be computed standard-deviation,
            with data type float32 and float64 supported.
        axis (list|int, optional): The axis along which the standard-deviation
            is computed. If `None`, compute the standard-deviation over all
            elements of :attr:`input` and return a Variable with a single
            element, otherwise it must be in the range
            :math:`[-rank(input), rank(input))`. If :math:`axis[i] < 0`, the
            axis to compute is :math:`rank(input) + axis[i]`.
        keepdim (bool, optional): Whether to reserve the reduced dimensions in
            the output Variable. The dimensions in :attr:`axis` will be squeezed
            and the result Variable will have :attr:`len(axis)` fewer dimensions
            than the :attr:`input` unless :attr:`keepdim` is true, default False.
        unbiased (bool, optional): Whether to compute standard-deviation via the
            unbiased estimator, in which the divisor used in the computation is
            :math:`N - 1`, where :math:`N` represents the number of elements
            along :attr:`axis`, otherwise the divisor is :math:`N`. Default True.
        out (Variable, optional): Alternate output Variable to store the result
            standard-deviation. Default None.
        name (str, optional): The name for this layer. Normally there is no
            need for user to set this property.  For more information, please
            refer to :ref:`api_guide_Name`. Default None.

    Returns:
        Variable: The result standard-deviation with the same dtype as
            :attr:`input`. If :attr:`out = None`, returns a new Variable
            containing the standard-deviation, otherwise returns a reference
            to the output Variable.

    Raises:
        TypeError: If :attr:`input` is not a Variable of dtype float32 or
            float64 (raised by ``check_variable_and_dtype``).

    Examples:
        .. code-block:: python

            import paddle
            import paddle.fluid as fluid
            # x is a Tensor variable with following elements:
            #    [[0.2, 0.3, 0.5, 0.9]
            #     [0.1, 0.2, 0.6, 0.7]]
            # Each example is followed by the corresponding output tensor.
            x = fluid.data(name='x', shape=[2, 4], dtype='float32')
            paddle.std(x)  # [0.28252685]
            paddle.std(x, axis=[0])  # [0.0707107, 0.07071075, 0.07071064, 0.1414217]
            paddle.std(x, axis=[-1])  # [0.30956957, 0.29439208]
    """
    check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'std')

    # Standard deviation is the square root of the (optionally unbiased)
    # variance computed by ``var``.
    tmp = var(input, axis=axis, keepdim=keepdim, unbiased=unbiased, name=name)
    tmp = layers.sqrt(tmp)
    if out is not None:
        layers.assign(input=tmp, output=out)
        return out
    else:
        return tmp


def numel(x, name=None):
    """
    Returns the number of elements in a tensor.

    In static graph mode a ``size`` op is appended to the program and its
    int64 output Variable is returned. In dynamic graph mode the result of
    ``core.ops.size`` is returned directly.

    Args:
        x (Tensor): The input Tensor. Its data type can be bool, float16,
            float32, float64, int32 or int64.
        name (str, optional): Name for the operation. Default is None. It is
            only forwarded to ``LayerHelper`` in static graph mode.

    Returns:
        Tensor: The number of elements of ``x``.

    Raises:
        TypeError: In static graph mode, if ``x`` is not a Variable.

    Examples:
        .. code-block:: python

            import paddle

            paddle.disable_static()
            x = paddle.full(shape=[4, 5, 7], fill_value=0, dtype='int32')
            numel = paddle.numel(x) # 140
    """
    if not in_dygraph_mode():
        if not isinstance(x, Variable):
            raise TypeError("x must be a Tensor in numel")
        # Only ``x`` and ``name`` are local here, so exactly those two are
        # forwarded to LayerHelper.
        helper = LayerHelper('numel', **locals())
        result = helper.create_variable_for_type_inference(
            dtype=core.VarDesc.VarType.INT64)
        helper.append_op(
            type='size', inputs={'Input': x}, outputs={'Out': result})
        return result
    return core.ops.size(x)