random.py 45.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
# TODO: define random functions
S
silingtong123 已提交
16

17 18
import paddle
from paddle import _C_ops, _legacy_C_ops
19
from paddle.common_ops_import import Variable
20 21
from paddle.fluid.framework import _current_expected_place
from paddle.framework import in_dynamic_mode
22

23 24 25
from ..fluid.data_feeder import (
    check_dtype,
    check_shape,
26 27
    check_type,
    check_variable_and_dtype,
28
)
29 30 31 32 33
from ..framework import (
    LayerHelper,
    convert_np_dtype_to_dtype_,
    core,
    dygraph_only,
34
)

# This module exports nothing via ``from ... import *``; its functions are
# re-exported explicitly by the package ``__init__``.
__all__ = []


def bernoulli(x, name=None):
    r"""

    For each element :math:`x_i` in input ``x``, take a sample from the Bernoulli distribution, also called two-point distribution, with success probability :math:`x_i`. The Bernoulli distribution with success probability :math:`x_i` is a discrete probability distribution with probability mass function

    .. math::
        p(y)=\begin{cases}
            x_i,&y=1\\
            1-x_i,&y=0
        \end{cases}.

    Args:
        x (Tensor): The input Tensor of success probabilities. Its data type should be
            float16, float32, float64 or uint16 (bfloat16).
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor: A Tensor filled with samples from the Bernoulli distribution, whose shape and dtype are the same as ``x``.

    Examples:
        .. code-block:: python

            import paddle

            paddle.set_device('cpu')  # on CPU device
            paddle.seed(100)

            x = paddle.rand([2,3])
            print(x)
            # [[0.55355281, 0.20714243, 0.01162981],
            #  [0.51577556, 0.36369765, 0.26091650]]

            out = paddle.bernoulli(x)
            print(out)
            # [[1., 0., 1.],
            #  [0., 1., 0.]]

    """

    if in_dynamic_mode():
        return _C_ops.bernoulli(x)
    else:
        check_variable_and_dtype(
            x, "x", ["float32", "float64", "float16", "uint16"], "bernoulli"
        )

        # The layer type must name the op actually being appended; it was
        # previously "randint", which mislabels this layer in the program.
        helper = LayerHelper("bernoulli", **locals())
        # The output keeps x's dtype (an int32 output has been considered).
        out = helper.create_variable_for_type_inference(dtype=x.dtype)
        helper.append_op(
            type='bernoulli', inputs={"X": x}, outputs={'Out': out}, attrs={}
        )
        # Sampling is not differentiable w.r.t. its output.
        out.stop_gradient = True
        return out


def poisson(x, name=None):
    r"""
    Returns a tensor filled with random number from a Poisson Distribution.

    .. math::

        out_i \sim Poisson (\lambda = x_i)

    Args:
        x(Tensor):  A tensor with rate parameter of poisson Distribution. The data type
            should be float32, float64.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.
    Returns:
        Tensor: A Tensor filled with random number with the same shape and dtype as ``x``.

    Examples:
        .. code-block:: python

            import paddle
            paddle.set_device('cpu')
            paddle.seed(100)

            x = paddle.uniform([2,3], min=1.0, max=5.0)
            out = paddle.poisson(x)
            #[[2., 5., 0.],
            # [5., 1., 3.]]

    """
    if in_dynamic_mode():
        return _C_ops.poisson(x)
    else:
        check_variable_and_dtype(x, "x", ["float32", "float64"], "poisson")

        helper = LayerHelper("poisson", **locals())
        # The sampled counts are stored in the same floating dtype as the rates.
        out = helper.create_variable_for_type_inference(dtype=x.dtype)
        helper.append_op(
            type='poisson', inputs={'X': x}, outputs={'Out': out}, attrs={}
        )
        return out


def multinomial(x, num_samples=1, replacement=False, name=None):
    """
    Returns a Tensor filled with random values sampled from a Multinomial
    distribution. The input ``x`` is a tensor with probabilities for generating the
    random number. Each element in ``x`` should be larger or equal to 0, but not all
    0. ``replacement`` indicates whether it is a replaceable sample. If ``replacement``
    is True, a category can be sampled more than once.

    Args:
        x(Tensor):  A tensor with probabilities for generating the random number. The data type
            should be float16, float32, float64 or uint16 (bfloat16).
        num_samples(int, optional): Number of samples, default is 1.
        replacement(bool, optional): Whether it is a replaceable sample, default is False.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name`.
    Returns:
        Tensor: A Tensor filled with sampled category index after ``num_samples`` times samples.

    Examples:
        .. code-block:: python

            import paddle

            paddle.seed(100) # on CPU device
            x = paddle.rand([2,4])
            print(x)
            # [[0.5535528  0.20714243 0.01162981 0.51577556]
            # [0.36369765 0.2609165  0.18905126 0.5621971 ]]

            paddle.seed(200) # on CPU device
            out1 = paddle.multinomial(x, num_samples=5, replacement=True)
            print(out1)
            # [[3 3 0 0 0]
            # [3 3 3 1 0]]

            # out2 = paddle.multinomial(x, num_samples=5)
            # InvalidArgumentError: When replacement is False, number of samples
            #  should be less than non-zero categories

            paddle.seed(300) # on CPU device
            out3 = paddle.multinomial(x, num_samples=3)
            print(out3)
            # [[3 0 1]
            # [3 1 0]]

    """

    # Guard against ROCm builds, where this op is not available.
    assert (
        not core.is_compiled_with_rocm()
    ), "multinomial op is not supported on ROCM yet."

    if in_dynamic_mode():
        return _C_ops.multinomial(x, num_samples, replacement)
    else:
        check_variable_and_dtype(
            x, "x", ["uint16", "float16", "float32", "float64"], "multinomial"
        )

        helper = LayerHelper("multinomial", **locals())
        # Sampled values are category indices, hence an integer output.
        out = helper.create_variable_for_type_inference(
            dtype=convert_np_dtype_to_dtype_('int64')
        )
        helper.append_op(
            type='multinomial',
            inputs={"X": x},
            outputs={'Out': out},
            attrs={'num_samples': num_samples, 'replacement': replacement},
        )
        out.stop_gradient = True
        return out


def uniform_random_batch_size_like(
    input,
    shape,
    dtype='float32',
    input_dim_idx=0,
    output_dim_idx=0,
    min=-1.0,
    max=1.0,
    seed=0,
):
    """
    This OP initializes a variable with random values sampled from a
    uniform distribution in the range [min, max). The size of dimension
    ``input_dim_idx`` of ``input`` replaces dimension ``output_dim_idx`` of
    ``shape`` to form the output shape. This function only builds a
    static-graph op (it has no dynamic-mode branch).

    .. code-block:: text

        *Case 1:
            Given:
                input =[[0.946741  , 0.1357001 , 0.38086128]]    # input.shape=[1,3]
                shape=[2,4]
            result.shape[output_dim_idx] = input.shape[input_dim_idx],
            output_dim_idx = 0,
            input_dim_idx = 0,
            result.shape[0] = input.shape[0],
            then:
                result=[[ 0.3443427 , -0.23056602,  0.3477049 ,  0.06139076]]    # result.shape=[1,4]

        *Case 2:
            Given:
                input =[[0.946741  , 0.1357001 , 0.38086128]]     # input.shape=[1,3]
                shape=[2,4]
                input_dim_idx=1
                output_dim_idx=1
            result.shape[output_dim_idx] = input.shape[input_dim_idx],
            output_dim_idx = 1,
            input_dim_idx = 1,
            result.shape[1] = input.shape[1],
            then:
                result=[[-0.23133647, -0.84195036,  0.21441269],
                        [-0.08774924,  0.25605237, -0.09403259]]    # result.shape=[2,3]

    Args:
        input (Variable): A Tensor. Supported data types: float32, float64, uint16 (bfloat16).
        shape (tuple|list): A python list or python tuple. The shape of the output Tensor, the data type is int.
        input_dim_idx (int, optional): An index used to get the input dimension value which will be used to resize the output dimension. Default 0.
        output_dim_idx (int, optional): An index used to indicate the specific dimension that will be replaced by corresponding input dimension value. Default 0.
        min (float, optional): The lower bound on the range of random values to generate, the min is included in the range. Default -1.0.
        max (float, optional): The upper bound on the range of random values to generate, the max is excluded in the range. Default 1.0.
        seed (int, optional):  Random seed used for generating samples. 0 means use a seed generated by the system. Note that if seed is not 0, this operator will always generate the same random numbers every time.
        dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type of output Tensor. Supported data types: float32, float64, uint16 (bfloat16). Default float32.

    Returns:
        Variable: A Tensor of the specified shape filled with uniform_random values. The shape of the Tensor is determined by the shape parameter and the specified dimension of the input Tensor.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.tensor import random
            paddle.enable_static()
            # example 1:
            input = paddle.static.data(name="input", shape=[1, 3], dtype='float32')
            out_1 = random.uniform_random_batch_size_like(input, [2, 4]) # out_1.shape=[1, 4]
            # example 2:
            out_2 = random.uniform_random_batch_size_like(input, [2, 4], input_dim_idx=1, output_dim_idx=1) # out_2.shape=[2, 3]
    """
    check_variable_and_dtype(
        input,
        'Input',
        ("float32", 'float64', "uint16"),
        'uniform_random_batch_size_like',
    )
    check_type(shape, 'shape', (list, tuple), 'uniform_random_batch_size_like')
    check_dtype(
        dtype,
        'dtype',
        ('float32', 'float64', "uint16"),
        'uniform_random_batch_size_like',
    )

    helper = LayerHelper('uniform_random_batch_size_like', **locals())
    out = helper.create_variable_for_type_inference(dtype)
    # The op attribute expects a framework dtype enum, not a string.
    c_dtype = convert_np_dtype_to_dtype_(dtype)
    helper.append_op(
        type='uniform_random_batch_size_like',
        inputs={'Input': input},
        outputs={'Out': out},
        attrs={
            'shape': shape,
            'input_dim_idx': input_dim_idx,
            'output_dim_idx': output_dim_idx,
            'min': min,
            'max': max,
            'seed': seed,
            'dtype': c_dtype,
        },
    )

    return out


def gaussian(shape, mean=0.0, std=1.0, seed=0, dtype=None, name=None):
    """
    Returns a Tensor filled with random values sampled from a Gaussian
    distribution, with ``shape`` and ``dtype``.

    Args:
        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
        mean (float|int, optional): Mean of the output tensor, default is 0.0.
        std (float|int, optional): Standard deviation of the output tensor, default
            is 1.0.
        seed (int, optional): Random seed of generator. Default is 0.
        dtype (str|np.dtype, optional): The data type of the output Tensor.
            Supported data types: float16, float32, float64, uint16 (bfloat16).
            Default is None, use global default dtype (see ``get_default_dtype``
            for details).
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: A Tensor filled with random values sampled from a Gaussian
        distribution, with ``shape`` and ``dtype``.

    Raises:
        TypeError: If ``dtype`` is None and the global default dtype is not
            one of the supported data types.
    """
    # Shared by several public wrappers, so error messages name all of them.
    op_type_for_check = 'gaussian/standard_normal/randn/normal'
    supported_dtypes = ['float32', 'float64', 'float16', 'uint16']

    if dtype is None:
        dtype = paddle.framework.get_default_dtype()
        if dtype not in supported_dtypes:
            raise TypeError(
                "{} only supports {}, but the default dtype is {}".format(
                    op_type_for_check, supported_dtypes, dtype
                )
            )
    if not isinstance(dtype, core.VarDesc.VarType):
        dtype = convert_np_dtype_to_dtype_(dtype)

    if in_dynamic_mode():
        shape = paddle.utils.convert_shape_to_list(shape)
        place = _current_expected_place()
        return _C_ops.gaussian(
            shape, float(mean), float(std), seed, dtype, place
        )
    else:
        check_shape(shape, op_type_for_check)
        check_dtype(dtype, 'dtype', supported_dtypes, op_type_for_check)

        inputs = {}
        attrs = {
            'mean': mean,
            'std': std,
            'seed': seed,
            'dtype': dtype,
            'use_mkldnn': False,
        }
        # Fills `inputs`/`attrs` in place depending on whether `shape`
        # contains Tensors or only Python ints.
        paddle.utils.get_shape_tensor_inputs(
            inputs=inputs, attrs=attrs, shape=shape, op_type=op_type_for_check
        )

        helper = LayerHelper('gaussian', **locals())
        out = helper.create_variable_for_type_inference(dtype)
        helper.append_op(
            type='gaussian_random',
            inputs=inputs,
            outputs={'Out': out},
            attrs=attrs,
        )
        out.stop_gradient = True
        return out


def standard_normal(shape, dtype=None, name=None):
    """
    Returns a Tensor filled with random values sampled from a standard
    normal distribution with mean 0 and standard deviation 1, with ``shape``
    and ``dtype``.

    Args:
        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
        dtype (str|np.dtype, optional): The data type of the output Tensor.
            Supported data types: float16, float32, float64, uint16 (bfloat16).
            Default is None, use global default dtype (see ``get_default_dtype``
            for details).
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: A Tensor filled with random values sampled from a standard
        normal distribution with mean 0 and standard deviation 1, with
        ``shape`` and ``dtype``.

    Examples:
        .. code-block:: python

            import paddle

            # example 1: attr shape is a list which doesn't contain Tensor.
            out1 = paddle.standard_normal(shape=[2, 3])
            # [[-2.923464  ,  0.11934398, -0.51249987],  # random
            #  [ 0.39632758,  0.08177969,  0.2692008 ]]  # random

            # example 2: attr shape is a list which contains Tensor.
            dim1 = paddle.to_tensor(2, 'int64')
            dim2 = paddle.to_tensor(3, 'int32')
            out2 = paddle.standard_normal(shape=[dim1, dim2, 2])
            # [[[-2.8852394 , -0.25898588],  # random
            #   [-0.47420555,  0.17683524],  # random
            #   [-0.7989969 ,  0.00754541]],  # random
            #  [[ 0.85201347,  0.32320443],  # random
            #   [ 1.1399018 ,  0.48336947],  # random
            #   [ 0.8086993 ,  0.6868893 ]]]  # random

            # example 3: attr shape is a Tensor, the data type must be int64 or int32.
            shape_tensor = paddle.to_tensor([2, 3])
            out3 = paddle.standard_normal(shape_tensor)
            # [[-2.878077 ,  0.17099959,  0.05111201]  # random
            #  [-0.3761474, -1.044801  ,  1.1870178 ]]  # random

    """
    # A standard normal is just a Gaussian with mean 0 and std 1.
    return gaussian(shape=shape, mean=0.0, std=1.0, dtype=dtype, name=name)


def randn(shape, dtype=None, name=None):
    """
    Returns a Tensor filled with random values sampled from a standard
    normal distribution with mean 0 and standard deviation 1, with ``shape``
    and ``dtype``.

    Args:
        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
        dtype (str|np.dtype, optional): The data type of the output Tensor.
            Supported data types: float16, float32, float64, uint16 (bfloat16).
            Default is None, use global default dtype (see ``get_default_dtype``
            for details).
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: A Tensor filled with random values sampled from a standard
        normal distribution with mean 0 and standard deviation 1, with
        ``shape`` and ``dtype``.

    Examples:
        .. code-block:: python

            import paddle

            # example 1: attr shape is a list which doesn't contain Tensor.
            out1 = paddle.randn(shape=[2, 3])
            # [[-2.923464  ,  0.11934398, -0.51249987],  # random
            #  [ 0.39632758,  0.08177969,  0.2692008 ]]  # random

            # example 2: attr shape is a list which contains Tensor.
            dim1 = paddle.to_tensor(2, 'int64')
            dim2 = paddle.to_tensor(3, 'int32')
            out2 = paddle.randn(shape=[dim1, dim2, 2])
            # [[[-2.8852394 , -0.25898588],  # random
            #   [-0.47420555,  0.17683524],  # random
            #   [-0.7989969 ,  0.00754541]],  # random
            #  [[ 0.85201347,  0.32320443],  # random
            #   [ 1.1399018 ,  0.48336947],  # random
            #   [ 0.8086993 ,  0.6868893 ]]]  # random

            # example 3: attr shape is a Tensor, the data type must be int64 or int32.
            shape_tensor = paddle.to_tensor([2, 3])
            out3 = paddle.randn(shape_tensor)
            # [[-2.878077 ,  0.17099959,  0.05111201]  # random
            #  [-0.3761474, -1.044801  ,  1.1870178 ]]  # random
    """
    # Alias of standard_normal.
    return standard_normal(shape, dtype, name)


def normal(mean=0.0, std=1.0, shape=None, name=None):
    """
    Returns a Tensor filled with random values sampled from a normal
    distribution with ``mean`` and ``std`` (standard deviation).

    If ``mean`` is a Tensor, the output Tensor has the same shape and data type as ``mean``.
    If ``mean`` is not a Tensor and ``std`` is a Tensor, the output Tensor has the same shape and data type as ``std``.
    If neither ``mean`` nor ``std`` is a Tensor, the output Tensor has the shape ``shape`` and the
    global default data type (see ``get_default_dtype``).

    If ``mean`` and ``std`` are both Tensors, they should have the same number of elements.

    Args:
        mean (float|Tensor, optional): The mean of the output Tensor's normal distribution.
            If ``mean`` is float, all elements of the output Tensor share the same mean.
            If ``mean`` is a Tensor(data type supports float32, float64), it has per-element means.
            Default is 0.0
        std (float|Tensor, optional): The standard deviation of the output Tensor's normal distribution.
            If ``std`` is float, all elements of the output Tensor share the same standard deviation.
            If ``std`` is a Tensor(data type supports float32, float64), it has per-element standard deviations.
            Default is 1.0
        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list. If ``mean`` or ``std``
            is a Tensor, the shape of the output Tensor is the same as ``mean`` or ``std`` , attr ``shape`` is ignored.
            Default is None
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: A Tensor filled with random values sampled from a normal distribution with ``mean`` and ``std`` .

    Examples:
        .. code-block:: python

            import paddle

            out1 = paddle.normal(shape=[2, 3])
            # [[ 0.17501129  0.32364586  1.561118  ]  # random
            #  [-1.7232178   1.1545963  -0.76156676]]  # random

            mean_tensor = paddle.to_tensor([1.0, 2.0, 3.0])
            out2 = paddle.normal(mean=mean_tensor)
            # [ 0.18644847 -1.19434458  3.93694787]  # random

            std_tensor = paddle.to_tensor([1.0, 2.0, 3.0])
            out3 = paddle.normal(mean=mean_tensor, std=std_tensor)
            # [1.00780561 3.78457445 5.81058198]  # random

    """
    if not in_dynamic_mode():
        check_type(mean, 'mean', (int, float, Variable), 'normal')
        check_type(std, 'std', (int, float, Variable), 'normal')
        if isinstance(mean, Variable):
            check_dtype(
                mean.dtype,
                'mean',
                ['float32', 'float64'],
                'normal',
                "If mean is Tensor, it's data type only support float32, float64.",
            )
        if isinstance(std, Variable):
            check_dtype(
                std.dtype,
                'std',
                ['float32', 'float64'],
                'normal',
                "If std is Tensor, it's data type only support float32, float64.",
            )
        if shape is not None:
            check_shape(shape, 'normal')

    if isinstance(mean, Variable):
        if isinstance(std, Variable):
            # Align std with mean (dtype and shape) so the affine transform
            # below is purely elementwise.
            if std.dtype != mean.dtype:
                std = paddle.cast(std, mean.dtype)
            mean_shape = paddle.shape(mean)
            std = paddle.reshape(std, mean_shape)
        else:
            std = float(std)
        out = standard_normal(paddle.shape(mean), mean.dtype, name)
    elif isinstance(std, Variable):
        mean = float(mean)
        out = standard_normal(paddle.shape(std), std.dtype, name)
    else:
        # Both parameters are scalars: sample directly.
        return gaussian(shape=shape, mean=mean, std=std, name=name)

    # Shift and scale standard-normal samples: X = Z * std + mean.
    out = out * std + mean
    if not in_dynamic_mode():
        # Was misspelled `stop_grediant`, which only created an unused
        # attribute and never actually set the flag.
        # NOTE(review): this stops gradients to Tensor mean/std in static
        # mode, while dynamic mode leaves `out` differentiable — confirm
        # this asymmetry is intended.
        out.stop_gradient = True
    return out


def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
    """
    Returns a Tensor filled with random values sampled from a uniform
    distribution in the range [``min``, ``max``), with ``shape`` and ``dtype``.

    Examples:

    .. code-block:: text

        Input:
          shape = [1, 2]
        Output:
          result=[[0.8505902, 0.8397286]]

    Args:
        shape (tuple|list|Tensor): Shape of the Tensor to be created. The data type is ``int32`` or ``int64`` .
            If ``shape`` is a list or tuple, each element of it should be integer or 0-D Tensor with shape [].
            If ``shape`` is an Tensor, it should be an 1-D Tensor which represents a list.
        dtype(str|np.dtype, optional): The data type of the output Tensor.
            Supported data types: float16, float32, float64, uint16 (bfloat16).
            Default is None, use global default dtype (see ``get_default_dtype``
            for details).
        min(float|int, optional): The lower bound on the range of random values
            to generate, ``min`` is included in the range. Default is -1.0.
        max(float|int, optional): The upper bound on the range of random values
            to generate, ``max`` is excluded in the range. Default is 1.0.
        seed(int, optional): Random seed used for generating samples. If seed is 0,
            it will use the seed of the global default generator (which can be set by paddle.seed).
            Note that if seed is not 0, this operator will always generate the same random numbers every
            time. Default is 0.
        name(str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: A Tensor filled with random values sampled from a uniform
        distribution in the range [``min``, ``max``), with ``shape`` and ``dtype``.

    Raises:
        TypeError: If ``dtype`` is None and the global default dtype is not
            one of the supported data types.

    Examples:
        .. code-block:: python
          :name: code-example1

            import paddle

            # example 1:
            # attr shape is a list which doesn't contain Tensor.
            out1 = paddle.uniform(shape=[3, 4])
            # [[ 0.84524226,  0.6921872,   0.56528175,  0.71690357], # random
            #  [-0.34646994, -0.45116323, -0.09902662, -0.11397249], # random
            #  [ 0.433519,    0.39483607, -0.8660099,   0.83664286]] # random

            # example 2:
            # attr shape is a list which contains Tensor.
            dim1 = paddle.to_tensor(2, 'int64')
            dim2 = paddle.to_tensor(3, 'int32')
            out2 = paddle.uniform(shape=[dim1, dim2])
            # [[-0.9951253,   0.30757582, 0.9899647 ], # random
            #  [ 0.5864527,   0.6607096,  -0.8886161]] # random

            # example 3:
            # attr shape is a Tensor, the data type must be int64 or int32.
            shape_tensor = paddle.to_tensor([2, 3])
            out3 = paddle.uniform(shape_tensor)
            # [[-0.8517412,  -0.4006908,   0.2551912 ], # random
            #  [ 0.3364414,   0.36278176, -0.16085452]] # random
    """
    supported_dtypes = ['float32', 'float64', 'float16', 'uint16']
    if dtype is None:
        dtype = paddle.framework.get_default_dtype()
        if dtype not in supported_dtypes:
            raise TypeError(
                "uniform/rand only supports {}, but the default dtype is {}".format(
                    supported_dtypes, dtype
                )
            )

    if not isinstance(dtype, core.VarDesc.VarType):
        dtype = convert_np_dtype_to_dtype_(dtype)

    if in_dynamic_mode():
        shape = paddle.utils.convert_shape_to_list(shape)
        return _C_ops.uniform(
            shape,
            dtype,
            float(min),
            float(max),
            seed,
            _current_expected_place(),
        )
    else:
        check_type(shape, 'shape', (list, tuple, Variable), 'uniform/rand')
        check_dtype(dtype, 'dtype', supported_dtypes, 'uniform/rand')
        check_type(min, 'min', (float, int, Variable), 'uniform/rand')
        check_type(max, 'max', (float, int, Variable), 'uniform/rand')

        inputs = {}
        attrs = {'seed': seed, 'min': min, 'max': max, 'dtype': dtype}
        # Fills `inputs`/`attrs` in place depending on whether `shape`
        # contains Tensors or only Python ints.
        paddle.utils.get_shape_tensor_inputs(
            inputs=inputs, attrs=attrs, shape=shape, op_type='uniform/rand'
        )

        helper = LayerHelper("uniform", **locals())
        out = helper.create_variable_for_type_inference(dtype)
        helper.append_op(
            type="uniform_random",
            inputs=inputs,
            attrs=attrs,
            outputs={"Out": out},
        )
        out.stop_gradient = True
        return out


@dygraph_only
def uniform_(x, min=-1.0, max=1.0, seed=0, name=None):
    """
    In-place counterpart of ``uniform``: overwrite every element of ``x`` with
    a value drawn uniformly from [``min``, ``max``) and hand ``x`` back.
    Only usable in dynamic graph mode (the function is wrapped with
    ``dygraph_only``). See :ref:`api_tensor_uniform` for the out-of-place API.

    Args:
        x (Tensor): Tensor whose contents are replaced by random values.
        min (float|int, optional): Inclusive lower bound of the sampling
            range. Default is -1.0.
        max (float|int, optional): Exclusive upper bound of the sampling
            range. Default is 1.0.
        seed (int, optional): Random seed. With the default of 0 the global
            default generator (configurable through ``paddle.seed``) is used;
            any non-zero value makes every call produce the same numbers.
            Default is 0.
        name (str, optional): Usually left unset. See :ref:`api_guide_Name`
            for details. Default is None.

    Returns:
        Tensor: ``x`` itself, now holding values sampled uniformly from
        [``min``, ``max``).

    Examples:
        .. code-block:: python

            import paddle
            # example:
            x = paddle.ones(shape=[3, 4])
            x.uniform_()
            print(x)
            # [[ 0.84524226,  0.6921872,   0.56528175,  0.71690357], # random
            #  [-0.34646994, -0.45116323, -0.09902662, -0.11397249], # random
            #  [ 0.433519,    0.39483607, -0.8660099,   0.83664286]] # random
    """
    # The last three positional arguments of the in-place kernel are
    # presumably the op's diagonal-initialisation attributes (diag_num,
    # diag_step, diag_val); (0, 0, 1.0) is forwarded exactly as before.
    # NOTE(review): confirm the attribute meaning against the op definition.
    diag_num, diag_step, diag_val = 0, 0, 1.0
    return _C_ops.uniform_inplace_(
        x, min, max, seed, diag_num, diag_step, diag_val
    )
J
JYChen 已提交
722 723


724
def randint(low=0, high=None, shape=[1], dtype=None, name=None):
    """
    Draw random integers uniformly from the half-open interval
    [``low``, ``high``) and return them in a Tensor of the given ``shape``
    and ``dtype``. When ``high`` is omitted, ``low`` acts as the exclusive
    upper bound and the interval becomes [0, ``low``).

    Args:
        low (int, optional): Inclusive lower bound of the interval, or the
            exclusive upper bound when ``high`` is None (it must then be
            greater than 0). Default is 0.
        high (int, optional): Exclusive upper bound of the interval. If None,
            the interval is [0, ``low``). Default is None.
        shape (tuple|list|Tensor, optional): Shape of the output Tensor. A list
            or tuple may mix Python ints and 0-D int32/int64 Tensors with
            shape []; a Tensor must be a 1-D int32/int64 Tensor listing the
            dimensions. Default is [1].
        dtype (str|np.dtype, optional): Output data type. Supported data
            types: int32, int64. If None, int64 is used. Default is None.
        name (str, optional): Usually left unset. See :ref:`api_guide_Name`
            for details. Default is None.

    Returns:
        Tensor: Random integers from [``low``, ``high``) with the requested
        ``shape`` and ``dtype``.

    Raises:
        ValueError: If ``high`` is None and ``low`` is not greater than 0, or,
            in static graph mode, if ``low`` is not less than ``high``.

    Examples:
        .. code-block:: python

            import paddle

            # example 1:
            # attr shape is a list which doesn't contain Tensor.
            out1 = paddle.randint(low=-5, high=5, shape=[2, 3])
            # [[0, -3,  2],  # random
            #  [4, -1,  1]]  # random

            # example 2:
            # attr shape is a list which contains Tensor.
            dim1 = paddle.to_tensor(2, 'int64')
            dim2 = paddle.to_tensor(3, 'int32')
            out2 = paddle.randint(low=-5, high=5, shape=[dim1, dim2])
            # [[0, -1, -3],  # random
            #  [4, -2,  0]]  # random

            # example 3:
            # attr shape is a Tensor
            shape_tensor = paddle.to_tensor([2, 3])
            out3 = paddle.randint(low=-5, high=5, shape=shape_tensor)
            # [[ 2, -3, -1],   # random
            #  [-3, -2,  1]]   # random

            # example 4:
            # data type is int32
            out4 = paddle.randint(low=-5, high=5, shape=[3], dtype='int32')
            # [-5, 4, -4]  # random

            # example 5:
            # Input only one parameter
            # low=0, high=10, shape=[1], dtype='int64'
            out5 = paddle.randint(10)
            # [7]  # random

    """
    # A single bound means "sample from [0, low)".
    if high is None:
        if low <= 0:
            raise ValueError(
                "If high is None, low must be greater than 0, but received low = {}.".format(
                    low
                )
            )
        low, high = 0, low

    if dtype is None:
        dtype = core.VarDesc.VarType.INT64
    elif not isinstance(dtype, core.VarDesc.VarType):
        dtype = convert_np_dtype_to_dtype_(dtype)

    if in_dynamic_mode():
        return _C_ops.randint(
            low,
            high,
            paddle.utils.convert_shape_to_list(shape),
            dtype,
            _current_expected_place(),
        )

    # Static graph: validate arguments, then append a ``randint`` op.
    check_shape(shape, 'randint')
    check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint')
    if low >= high:
        raise ValueError(
            f"randint's low must less then high, but received low = {low}, "
            f"high = {high}"
        )

    # ``get_shape_tensor_inputs`` fills these in place from ``shape``.
    inputs = {}
    attrs = dict(low=low, high=high, seed=0, dtype=dtype)
    paddle.utils.get_shape_tensor_inputs(
        inputs=inputs, attrs=attrs, shape=shape, op_type='randint'
    )

    helper = LayerHelper("randint", **locals())
    out = helper.create_variable_for_type_inference(dtype=dtype)
    helper.append_op(
        type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs
    )
    # Random samples are not differentiable.
    out.stop_gradient = True
    return out
C
cc 已提交
828 829


830 831
def randint_like(x, low=0, high=None, dtype=None, name=None):
    """
    Returns a Tensor filled with random integers from a discrete uniform
    distribution in the range [``low``, ``high``), with the same shape as ``x``
    and data type ``dtype`` (the data type of ``x`` when ``dtype`` is None).
    If ``high`` is None (the default), the range is [0, ``low``).

    Args:
        x (Tensor): The input multi-dimensional tensor which specifies shape. The dtype of ``x``
            can be bool, int32, int64, float16, float32, float64.
        low (int, optional): The lower bound on the range of random values to generate.
            The ``low`` is included in the range. If ``high`` is None, the
            range is [0, ``low``) and ``low`` must be greater than 0. Default is 0.
        high (int, optional): The upper bound on the range of random values to
            generate, the ``high`` is excluded in the range. Default is None.
            If ``high`` is None, the range is [0, ``low``).
        dtype (str|np.dtype, optional): The data type of the
            output tensor. Supported data types: bool, int32, int64, float16,
            float32, float64. If ``dtype`` is None, the data type is the
            same as x's data type. Default is None.
        name (str, optional): The default value is None.  Normally there is no
            need for user to set this property.  For more information, please
            refer to :ref:`api_guide_Name`.

    Returns:
        Tensor: A Tensor with the same shape as ``x``, filled with random
        integers from a discrete uniform distribution in the range
        [``low``, ``high``) and cast to ``dtype``.

    Raises:
        ValueError: If ``high`` is None and ``low`` is not greater than 0,
            or if ``low`` is not less than ``high``.

    Examples:
        .. code-block:: python

            import paddle

            # example 1:
            # dtype is None and the dtype of x is float16
            x = paddle.zeros((1,2)).astype("float16")
            out1 = paddle.randint_like(x, low=-5, high=5)
            print(out1)
            print(out1.dtype)
            # [[0, -3]]  # random
            # paddle.float16

            # example 2:
            # dtype is None and the dtype of x is float32
            x = paddle.zeros((1,2)).astype("float32")
            out2 = paddle.randint_like(x, low=-5, high=5)
            print(out2)
            print(out2.dtype)
            # [[0, -3]]  # random
            # paddle.float32

            # example 3:
            # dtype is None and the dtype of x is float64
            x = paddle.zeros((1,2)).astype("float64")
            out3 = paddle.randint_like(x, low=-5, high=5)
            print(out3)
            print(out3.dtype)
            # [[0, -3]]  # random
            # paddle.float64

            # example 4:
            # dtype is None and the dtype of x is int32
            x = paddle.zeros((1,2)).astype("int32")
            out4 = paddle.randint_like(x, low=-5, high=5)
            print(out4)
            print(out4.dtype)
            # [[0, -3]]  # random
            # paddle.int32

            # example 5:
            # dtype is None and the dtype of x is int64
            x = paddle.zeros((1,2)).astype("int64")
            out5 = paddle.randint_like(x, low=-5, high=5)
            print(out5)
            print(out5.dtype)
            # [[0, -3]]  # random
            # paddle.int64

            # example 6:
            # dtype is float64 and the dtype of x is float32
            x = paddle.zeros((1,2)).astype("float32")
            out6 = paddle.randint_like(x, low=-5, high=5, dtype="float64")
            print(out6)
            print(out6.dtype)
            # [[0, -1]]  # random
            # paddle.float64

            # example 7:
            # dtype is bool and the dtype of x is float32
            x = paddle.zeros((1,2)).astype("float32")
            out7 = paddle.randint_like(x, low=-5, high=5, dtype="bool")
            print(out7)
            print(out7.dtype)
            # [[False, True]]  # random
            # paddle.bool

            # example 8:
            # dtype is int32 and the dtype of x is float32
            x = paddle.zeros((1,2)).astype("float32")
            out8 = paddle.randint_like(x, low=-5, high=5, dtype="int32")
            print(out8)
            print(out8.dtype)
            # [[0, -1]]  # random
            # paddle.int32

            # example 9:
            # dtype is int64 and the dtype of x is float32
            x = paddle.zeros((1,2)).astype("float32")
            out9 = paddle.randint_like(x, low=-5, high=5, dtype="int64")
            print(out9)
            print(out9.dtype)
            # [[0, -1]]  # random
            # paddle.int64

            # example 10:
            # dtype is int64 and the dtype of x is bool
            x = paddle.zeros((1,2)).astype("bool")
            out10 = paddle.randint_like(x, low=-5, high=5, dtype="int64")
            print(out10)
            print(out10.dtype)
            # [[0, -1]]  # random
            # paddle.int64

    """
    if high is None:
        if low <= 0:
            raise ValueError(
                "If high is None, low must be greater than 0, but received low = {}.".format(
                    low
                )
            )
        high = low
        low = 0
    if dtype is None:
        dtype = x.dtype
    if not isinstance(dtype, core.VarDesc.VarType):
        dtype = convert_np_dtype_to_dtype_(dtype)

    # Validate the range before touching ``x``, so a call that is going to
    # fail does not first append a ``shape`` op to the static program.
    if low >= high:
        raise ValueError(
            f"randint_like's low must less then high, but received low = {low}, "
            f"high = {high}"
        )

    shape = paddle.shape(x)

    # Values are always sampled as int64 and then cast to ``dtype``, which
    # is how non-integer output dtypes (bool/float*) are supported.
    if in_dynamic_mode():
        shape = paddle.utils.convert_shape_to_list(shape)
        # Same op and calling convention as ``randint`` (replaces the
        # deprecated ``_legacy_C_ops.randint`` attribute-list call).
        out = _C_ops.randint(
            low,
            high,
            shape,
            core.VarDesc.VarType.INT64,
            _current_expected_place(),
        )
        out = paddle.cast(out, dtype)
        return out
    else:
        check_shape(shape, 'randint_like')
        check_dtype(
            dtype,
            'dtype',
            ['bool', 'float16', 'float32', 'float64', 'int32', 'int64'],
            'randint_like',
        )

        # ``shape`` is a Tensor here, so it is fed through ShapeTensor.
        inputs = {"ShapeTensor": shape}
        attrs = {
            'low': low,
            'high': high,
            'seed': 0,
            'dtype': core.VarDesc.VarType.INT64,
        }

        helper = LayerHelper("randint", **locals())
        out = helper.create_variable_for_type_inference(
            dtype=core.VarDesc.VarType.INT64
        )
        helper.append_op(
            type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs
        )
        out.stop_gradient = True
        out = paddle.cast(out, dtype)
        return out
1018 1019


1020
def randperm(n, dtype="int64", name=None):
    """
    Produce a 1-D Tensor holding a random permutation of the integers
    0, 1, ..., n-1, stored with the requested ``dtype``.

    Args:
        n (int): Number of elements to permute (exclusive upper bound of the
            values). It should be greater than 0. In static graph mode a value
            below 1 raises ``ValueError``; in dynamic mode ``n`` is forwarded
            to the kernel without a Python-side check.
        dtype (str|np.dtype, optional): Data type of the result. Supported
            data types: int32, int64, float32, float64. Default is int64.
        name (str, optional): Usually left unset. See :ref:`api_guide_Name`
            for details. Default is None.

    Returns:
        Tensor: A 1-D Tensor of length ``n`` containing a random permutation
        of 0 .. n-1, with data type ``dtype``.

    Examples:
        .. code-block:: python

            import paddle

            out1 = paddle.randperm(5)
            # [4, 1, 2, 3, 0]  # random

            out2 = paddle.randperm(7, 'int32')
            # [1, 6, 2, 0, 4, 3, 5]  # random

    """
    if not isinstance(dtype, core.VarDesc.VarType):
        dtype = convert_np_dtype_to_dtype_(dtype)

    if in_dynamic_mode():
        return _C_ops.randperm(n, dtype, _current_expected_place())

    # Static graph: validate eagerly, then append a ``randperm`` op.
    if n < 1:
        raise ValueError(
            "The input n should be greater than 0 in randperm op."
        )
    check_dtype(
        dtype, 'dtype', ['int64', 'int32', 'float32', 'float64'], 'randperm'
    )

    helper = LayerHelper("randperm", **locals())
    out = helper.create_variable_for_type_inference(dtype)
    helper.append_op(
        type='randperm',
        inputs={},
        outputs={'Out': out},
        attrs={'n': n, 'dtype': dtype, 'seed': 0},
    )
    # Random samples are not differentiable.
    out.stop_gradient = True
    return out
X
Xing Wu 已提交
1072 1073


1074
def rand(shape, dtype=None, name=None):
    """
    Create a Tensor of the given ``shape`` and ``dtype`` whose entries are
    drawn uniformly from the half-open interval [0, 1). This is a thin
    wrapper around ``uniform`` with ``min=0.0`` and ``max=1.0``.

    Args:
        shape (tuple|list|Tensor): Shape of the output Tensor. A list or tuple
            may mix Python ints and 0-D int32/int64 Tensors with shape [];
            a Tensor must be a 1-D int32/int64 Tensor listing the dimensions.
        dtype (str|np.dtype, optional): Output data type. Supported data
            types: float32, float64 (the underlying ``uniform`` also accepts
            float16 and bfloat16). Default is None, meaning the global
            default dtype is used (see ``get_default_dtype``).
        name (str, optional): Usually left unset. See :ref:`api_guide_Name`
            for details. Default is None.

    Returns:
        Tensor: Values sampled uniformly from [0, 1), with the requested
        ``shape`` and ``dtype``.

    Examples:
        .. code-block:: python

            import paddle

            # example 1: attr shape is a list which doesn't contain Tensor.
            out1 = paddle.rand(shape=[2, 3])
            # [[0.451152  , 0.55825245, 0.403311  ],  # random
            #  [0.22550228, 0.22106001, 0.7877319 ]]  # random

            # example 2: attr shape is a list which contains Tensor.
            dim1 = paddle.to_tensor(2, 'int64')
            dim2 = paddle.to_tensor(3, 'int32')
            out2 = paddle.rand(shape=[dim1, dim2, 2])
            # [[[0.8879919 , 0.25788337],  # random
            #   [0.28826773, 0.9712097 ],  # random
            #   [0.26438272, 0.01796806]],  # random
            #  [[0.33633623, 0.28654453],  # random
            #   [0.79109055, 0.7305809 ],  # random
            #   [0.870881  , 0.2984597 ]]]  # random

            # example 3: attr shape is a Tensor, the data type must be int64 or int32.
            shape_tensor = paddle.to_tensor([2, 3])
            out3 = paddle.rand(shape_tensor)
            # [[0.22920267, 0.841956  , 0.05981819],  # random
            #  [0.4836288 , 0.24573246, 0.7516129 ]]  # random
    """
    # ``rand`` is ``uniform`` pinned to the unit interval [0, 1).
    lower, upper = 0.0, 1.0
    return uniform(shape, dtype, min=lower, max=upper, name=name)
1123 1124 1125


def exponential_(x, lam=1.0, name=None):
    r"""
    Overwrite ``x`` in place with samples from an exponential distribution
    whose rate parameter :math:`\lambda` is given by ``lam``, and return it.

    The density being sampled is

    .. math::

        f(x) = \lambda e^{-\lambda x}

    Args:
        x (Tensor): Tensor to fill. The data type should be float32 or float64
            (the static graph dtype check additionally accepts float16 and
            bfloat16).
        lam (float, optional): The :math:`\lambda` (rate) parameter of the
            exponential distribution. Default is 1.0.
        name (str, optional): Usually left unset. See :ref:`api_guide_Name`
            for details. Default is None.

    Returns:
        Tensor: The input Tensor ``x``, now holding the samples.

    Examples:
        .. code-block:: python

            import paddle
            paddle.set_device('cpu')
            paddle.seed(100)

            x = paddle.empty([2,3])
            x.exponential_()
            # [[0.80643415, 0.23211166, 0.01169797],
            #  [0.72520673, 0.45208144, 0.30234432]]

    """
    if not in_dynamic_mode():
        # Static graph: append an ``exponential`` op whose output aliases
        # its input, so ``x`` is updated in place.
        check_variable_and_dtype(
            x, "x", ["float16", "float32", "float64", "uint16"], "exponential"
        )
        helper = LayerHelper("exponential", **locals())
        helper.append_op(
            type='exponential',
            inputs={"X": x},
            outputs={'Out': x},
            attrs={"lambda": lam},
        )
        return x

    return _C_ops.exponential_(x, lam)