test_rng.py 2.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np

from megengine import tensor
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
from megengine.core.ops.builtin import GaussianRNG, UniformRNG
from megengine.core.tensor.core import apply


def test_gaussian_rng():
    """Check GaussianRNG output statistics and placement.

    Three configurations are exercised, each sampling an (8, 9, 11, 12)
    tensor:

    * the default comp node (expected to report ``CompNode("xpux")``),
    * an explicit comp node ``"xpu1"`` passed to the op,
    * an RNG handle created on ``"xpu2"`` with a fixed seed.

    For every case, the empirical mean and standard deviation must lie
    within an absolute tolerance of 0.1 of the first two ``GaussianRNG``
    arguments. The output must also live on the expected comp node.
    """
    tol = 1e-1
    shape = tensor((8, 9, 11, 12), dtype="int32")

    def check(output, mean, std, device):
        # Read the result back once and reuse it for both statistics.
        values = output.numpy()
        assert np.fabs(values.mean() - mean) < tol
        # The deviation must be symmetric. A one-sided ``std_hat - std < tol``
        # would accept samples whose spread is too small, even zero.
        assert np.fabs(np.sqrt(values.var()) - std) < tol
        assert str(output.device) == str(device)

    # Default comp node.
    (output,) = apply(GaussianRNG(1.0, 3.0), shape)
    check(output, 1.0, 3.0, CompNode("xpux"))

    # Explicit comp node.
    cn = CompNode("xpu1")
    (output,) = apply(GaussianRNG(-1.0, 2.0, cn), shape)
    check(output, -1.0, 2.0, cn)

    # Seeded RNG handle. Release it even if op construction or apply fails,
    # so a failing test does not leak the handle.
    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    try:
        (output,) = apply(GaussianRNG(3.0, 1.0, h), shape)
    finally:
        delete_rng_handle(h)
    check(output, 3.0, 1.0, cn)


def test_uniform_rng():
    """Check UniformRNG output mean and placement.

    Three configurations are exercised, each sampling an (8, 9, 11, 12)
    tensor:

    * the default comp node (expected to report ``CompNode("xpux")``),
    * an explicit comp node ``"xpu1"`` passed to the op,
    * an RNG handle created on ``"xpu2"`` with a fixed seed.

    For every case, the empirical mean must be within 0.1 of 0.5. The
    output must also live on the expected comp node.
    """
    tol = 1e-1
    expected_mean = 0.5
    shape = tensor((8, 9, 11, 12), dtype="int32")

    def check(output, device):
        assert np.fabs(output.numpy().mean() - expected_mean) < tol
        assert str(output.device) == str(device)

    # Default comp node.
    (output,) = apply(UniformRNG(), shape)
    check(output, CompNode("xpux"))

    # Explicit comp node.
    cn = CompNode("xpu1")
    (output,) = apply(UniformRNG(cn), shape)
    check(output, cn)

    # Seeded RNG handle. Release it even if op construction or apply fails,
    # so a failing test does not leak the handle.
    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    try:
        (output,) = apply(UniformRNG(h), shape)
    finally:
        delete_rng_handle(h)
    check(output, cn)