# test_vision.py
import time

import numpy as np
import pytest

from megengine import Tensor
from megengine.module import (
    AdditiveGaussianNoise,
    AdditiveLaplaceNoise,
    AdditivePoissonNoise,
)


@pytest.mark.parametrize(
    "cls", [AdditiveGaussianNoise, AdditiveLaplaceNoise, AdditivePoissonNoise]
)
@pytest.mark.parametrize("per_channel", [False, True])
@pytest.mark.parametrize(
    "shape, fmt",
    [
        ((128, 3, 160, 160), "default"),
        ((128, 160, 160, 3), "nhwc"),
        ((128, 3, 160, 160), "nchw"),
    ],
)
@pytest.mark.parametrize("seed", [1024, None])
def test_AdditiveNoise(cls, per_channel, shape, fmt, seed):
    """Apply each additive-noise augmentation across layouts and seeds.

    For every (noise class, per_channel flag, tensor layout, seed)
    combination, a random float32 tensor of ``shape`` with layout ``fmt`` is
    built and passed once through ``cls(per_channel=..., seed=...)``.

    When ``seed`` is given, a second instance is built with identical
    arguments and applied to the same input. Its output must match the
    first, which checks that seeding makes the augmentation reproducible.
    With ``seed=None``, the test only checks that the forward pass runs.
    """
    # NOTE(review): this skips per_channel=False with the "default" format.
    # If the intent is to avoid per-channel noise when no channel axis is
    # declared, the condition may need to be `per_channel and ...` instead.
    # Confirm against the module implementation.
    if not per_channel and fmt == "default":
        # pytest.skip (not a bare return) so the combination is reported as
        # skipped instead of silently counted as passed.
        pytest.skip("per_channel=False is not exercised with the 'default' format")

    input_tensor = Tensor(
        np.random.random(shape), np.float32, device="xpux", format=fmt
    )

    aug = cls(per_channel=per_channel, seed=seed)
    aug_data = aug(input_tensor)

    if seed is not None:
        # Two instances created with the same seed must produce the same
        # noise for the same input.
        aug_ref = cls(per_channel=per_channel, seed=seed)
        aug_data_ref = aug_ref(input_tensor)
        # Compare host-side numpy arrays explicitly rather than relying on
        # implicit Tensor -> ndarray conversion.
        np.testing.assert_allclose(aug_data.numpy(), aug_data_ref.numpy())