rng.cpp 1.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
/**
 * \file dnn/test/cuda/rng.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "megdnn/oprs.h"
#include "test/cuda/fixture.h"
#include "test/naive/rng.h"
#include "test/common/tensor.h"

namespace megdnn {

namespace test {

/*!
 * Runs the CUDA UniformRNG operator over a 1-D Float32 tensor of 200000
 * elements (passing an empty workspace argument) and forwards the host-side
 * data pointer together with the element count to assert_uniform_correct().
 */
TEST_F(CUDA, UNIFORM_RNG_F32) {
    constexpr size_t kNumSamples = 200000;

    auto uniform_rng = handle_cuda()->create_operator<UniformRNG>();
    SyncedTensor<> samples(handle_cuda(),
            {TensorShape{kNumSamples}, dtype::Float32()});

    // Generate directly into the device-side view of the tensor.
    uniform_rng->exec(samples.tensornd_dev(), {});

    // Validation is delegated to the shared helper, which receives the
    // host-side pointer and the number of generated values.
    auto host_values = samples.ptr_mutable_host();
    const auto nr_values = samples.layout().total_nr_elems();
    assert_uniform_correct(host_values, nr_values);
}

/*!
 * Runs the CUDA GaussianRNG operator (mean 0.8, std 2.3) over 1-D Float32
 * tensors of 1, 200000 and 200001 elements.
 *
 * For every size the first generated value must lie within one standard
 * deviation of the mean. For sizes of at least 1000 elements the statistics
 * returned by get_mean_var() are additionally compared against the configured
 * mean (tolerance 5e-3) and variance (tolerance 5e-2).
 */
TEST_F(CUDA, GAUSSIAN_RNG_F32) {
    const double expected_mean = 0.8;
    const double expected_std = 2.3;

    auto gaussian_rng = handle_cuda()->create_operator<GaussianRNG>();
    gaussian_rng->param().mean = expected_mean;
    gaussian_rng->param().std = expected_std;

    for (size_t nr_elems: {1, 200000, 200001}) {
        TensorLayout layout{{nr_elems}, dtype::Float32()};

        // The operator reports its scratch requirement in bytes; back it with
        // a byte tensor of exactly that size.
        const auto workspace_bytes =
                gaussian_rng->get_workspace_in_bytes(layout);
        Tensor<dt_byte> scratch(handle_cuda(),
                {TensorShape{workspace_bytes}, dtype::Byte()});

        SyncedTensor<> samples(handle_cuda(), layout);
        gaussian_rng->exec(samples.tensornd_dev(),
                {scratch.ptr(), scratch.layout().total_nr_elems()});

        auto host_values = samples.ptr_mutable_host();

        // NOTE(review): a one-sigma bound on a single sample would reject a
        // sizeable share of valid Gaussian draws; presumably this relies on
        // the operator producing a fixed sequence -- confirm.
        const auto first_deviation = std::abs(host_values[0] - expected_mean);
        ASSERT_LE(first_deviation, expected_std);

        // Aggregate statistics are only checked for the larger tensors.
        if (nr_elems < 1000) {
            continue;
        }

        auto mean_var = get_mean_var(host_values, nr_elems, 0.8f);
        const auto mean_error = std::abs(mean_var.first - expected_mean);
        const auto var_error =
                std::abs(mean_var.second - expected_std * expected_std);
        ASSERT_LE(mean_error, 5e-3);
        ASSERT_LE(var_error, 5e-2);
    }
}

} // namespace test
} // namespace megdnn

// vim: syntax=cpp.doxygen