# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, Optional

from .. import Tensor
from ..core._imperative_rt import invoke_op
from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import GaussianRNG, UniformRNG
from ..core.tensor import utils
from .rng import _random_seed_generator

# Public API of this module: the sampling functions defined below.
__all__ = ["normal", "uniform"]


def normal(
    mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
) -> Tensor:
    r"""
    Random variable with Gaussian distribution :math:`N(\mu, \sigma^2)`.

    :param mean: the mean or expectation of the distribution (:math:`\mu`).
    :param std: the standard deviation of the distribution (:math:`\sigma`;
        variance = :math:`\sigma ^ 2`).
    :param size: output tensor size. When ``None``, a single sample is drawn
        (shape ``(1,)``).
    :return: the output tensor.

    Examples:

    .. testcode::

        import megengine as mge
        import megengine.random as rand

        x = rand.normal(mean=0, std=1, size=(2, 2))
        print(x.numpy())

    Outputs:

    .. testoutput::
        :options: +SKIP

        [[-0.20235455 -0.6959438 ]
         [-1.4939808  -1.5824696 ]]

    """
    if size is None:
        size = (1,)
    op = GaussianRNG(mean, std)
    # Convert ``size`` into a 1-D int32 shape tensor, which is the single
    # input passed to the RNG op below. ``_ref`` is only handed to
    # ``astensor1d`` as its reference tensor.
    _ref = Tensor([], dtype="int32")
    shape = utils.astensor1d(size, _ref, dtype="int32")
    shape = Tensor(shape, dtype="int32")
    (output,) = apply(op, shape)
    return output


def uniform(
    low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
) -> Tensor:
    r"""
    Random variable with uniform distribution :math:`U(\text{low}, \text{high})`.

    A base sample is drawn with ``UniformRNG`` and rescaled as
    ``low + (high - low) * sample``.

    :param low: lower range.
    :param high: upper range; must be strictly greater than ``low``.
    :param size: output tensor size. When ``None``, a single sample is drawn
        (shape ``(1,)``).
    :return: the output tensor.

    Examples:

    .. testcode::

        import megengine as mge
        import megengine.random as rand

        x = rand.uniform(size=(2, 2))
        print(x.numpy())

    Outputs:

    .. testoutput::
        :options: +SKIP

        [[0.76901674 0.70496535]
         [0.09365904 0.62957656]]

    """
    # NOTE(review): ``assert`` is removed under ``python -O``; it is kept here
    # so callers continue to see AssertionError, but an explicit raise would
    # make this check unconditional.
    assert low < high, "Uniform is not defined when low >= high"

    if size is None:
        size = (1,)
    op = UniformRNG()
    # Convert ``size`` into a 1-D int32 shape tensor, which is the single
    # input passed to the RNG op below. ``_ref`` is only handed to
    # ``astensor1d`` as its reference tensor.
    _ref = Tensor([], dtype="int32")
    shape = utils.astensor1d(size, _ref, dtype="int32")
    shape = Tensor(shape, dtype="int32")
    (output,) = apply(op, shape)

    # Rescale the base sample (presumably drawn from U(0, 1)) onto the
    # requested [low, high] range.
    return low + (high - low) * output