# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""Statistical smoke tests for the imperative GaussianRNG / UniformRNG ops."""
import numpy as np

from megengine import tensor
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
from megengine.core.ops.builtin import GaussianRNG, UniformRNG
from megengine.core.tensor.core import apply

# Shape of every sampled tensor (8 * 9 * 11 * 12 = 9504 samples).
_SHAPE = (8, 9, 11, 12)
# Absolute tolerance applied to the sample mean and sample standard deviation.
_TOL = 1e-1


def _check_samples(output, mean, device, std=None):
    """Assert that ``output`` has the expected sample statistics and device.

    Args:
        output: tensor produced by applying an RNG op.
        mean: expected mean; checked two-sided against ``_TOL``.
        device: comp node the output is expected to live on.
        std: expected standard deviation; checked two-sided against
            ``_TOL`` when given, skipped when ``None``.
    """
    # Read the samples back once and reuse them for every statistic.
    values = output.numpy()
    assert np.fabs(values.mean() - mean) < _TOL
    if std is not None:
        # Two-sided: a one-sided ``std_hat - std < tol`` would also accept a
        # collapsed (e.g. constant) output whose spread is far too small.
        assert np.fabs(np.sqrt(values.var()) - std) < _TOL
    assert str(output.device) == str(device)


def test_gaussian_rng():
    """GaussianRNG on the default node, an explicit node, and a seeded handle."""
    shape = tensor(_SHAPE, dtype="int32")

    # Default comp node.
    op = GaussianRNG(1.0, 3.0)
    (output,) = apply(op, shape)
    _check_samples(output, mean=1.0, std=3.0, device=CompNode("xpux"))

    # Explicit comp node.
    cn = CompNode("xpu1")
    op = GaussianRNG(-1.0, 2.0, cn)
    (output,) = apply(op, shape)
    _check_samples(output, mean=-1.0, std=2.0, device=cn)

    # Explicit seeded RNG handle. The output is read back before the handle
    # is released (NOTE(review): presumably execution may be asynchronous),
    # and ``finally`` releases the handle even when an assertion fails.
    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    try:
        op = GaussianRNG(3.0, 1.0, h)
        (output,) = apply(op, shape)
        _check_samples(output, mean=3.0, std=1.0, device=cn)
    finally:
        delete_rng_handle(h)


def test_uniform_rng():
    """UniformRNG on the default node, an explicit node, and a seeded handle."""
    shape = tensor(_SHAPE, dtype="int32")

    # Default comp node.
    op = UniformRNG()
    (output,) = apply(op, shape)
    _check_samples(output, mean=0.5, device=CompNode("xpux"))

    # Explicit comp node.
    cn = CompNode("xpu1")
    op = UniformRNG(cn)
    (output,) = apply(op, shape)
    _check_samples(output, mean=0.5, device=cn)

    # Explicit seeded RNG handle; read results first, always release the handle.
    cn = CompNode("xpu2")
    seed = 233333
    h = new_rng_handle(cn, seed)
    try:
        op = UniformRNG(h)
        (output,) = apply(op, shape)
        _check_samples(output, mean=0.5, device=cn)
    finally:
        delete_rng_handle(h)