From 7591718d206aaff96802458611ddaf85f0d4dfff Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 9 Nov 2021 11:37:10 +0800 Subject: [PATCH] feat(mge): add functional test GitOrigin-RevId: aa0be626862782cf877adc4321f4eeda63d249f6 --- .../test/unit/functional/test_functional.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 364cffd29..119a75d6e 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -15,6 +15,7 @@ import pytest from utils import opr_test import megengine.amp as amp +import megengine.config as config import megengine.core.ops.builtin as builtin import megengine.core.tensor.dtype as dtype import megengine.functional as F @@ -1258,3 +1259,34 @@ def test_pixel_shuffle_symbolic(is_symbolic): np.testing.assert_equal(out.numpy(), golden) if is_symbolic is None: break + + +def test_set_conv2d_config(): + """check setting config by contextmanager is equal to manually converted result""" + config._compute_mode = "float32" + inp = tensor(np.random.randn(1, 3, 224, 224), dtype=np.float16) + weight = tensor(np.random.randn(64, 3, 7, 7), dtype=np.float16) + config_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1) + config._compute_mode = "default" + with config._override(compute_mode="float32"): + context_out = F.conv2d(inp, weight, None, (2, 2), (3, 3), (1, 1), 1) + expected = F.conv2d( + inp, weight, None, (2, 2), (3, 3), (1, 1), 1, compute_mode="float32", + ) + np.testing.assert_allclose(config_out.numpy(), expected.numpy()) + np.testing.assert_allclose(context_out.numpy(), expected.numpy()) + + +def test_set_warp_perspective_config(): + config._conv_format = "NHWC" + inp_shape = (1, 1, 4, 4) + inp = Tensor(np.arange(16, dtype=np.float32).reshape(inp_shape)) + M_shape = (1, 3, 3) + M = Tensor(np.random.randn(3, 3), dtype=np.float32).reshape(M_shape) + config_out = F.vision.warp_perspective(inp, M, (2, 2)) + config._conv_format = "default" + with config._override(conv_format="NHWC"): + context_out = F.vision.warp_perspective(inp, M, (2, 2)) + expected = F.vision.warp_perspective(inp, M, (2, 2), format="NHWC") + np.testing.assert_allclose(config_out.numpy(), expected.numpy()) + np.testing.assert_allclose(context_out.numpy(), expected.numpy()) -- GitLab