# test_formatted_tensor.py
import numpy as np
import pytest

import megengine as mge
import megengine.functional as F
6
import megengine.module as M
7 8
from megengine import tensor
from megengine.autodiff import GradManager
9
from megengine.core._trace_option import use_symbolic_shape
10
from megengine.jit import trace
11 12 13


def test_basic():
    """Exercise format bookkeeping on construction, padding and assignment."""
    source = np.arange(0, 24).reshape((1, 2, 3, 4))

    # Build from a numpy array with an explicit format tag.
    nhwc_tensor = tensor(source, format="nhwc")
    assert nhwc_tensor.format == "nhwc"

    # Building from an existing tensor is expected to keep its format.
    b = tensor(nhwc_tensor)
    assert b.format == "nhwc"

    # Reflect padding is expected to hand back a default-format result.
    b = tensor(source, format="nchw")
    padded = F.pad(b, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="reflect")
    assert padded.format == "default"

    # TODO: init from tensor with new format
    #   c = tensor(nhwc_tensor, format="nchw")
    #   assert c.format == "nchw"

    # TODO: reset from numpy
    #   b[...] = source
    #   assert b.format == "nhwc"

    # Overwrite the contents from another formatted tensor.
    b[...] = tensor(source, format="nchw")
    assert b.format == "nchw"

    # The format attribute can be assigned directly.
    b.format = "nhwc"
    assert b.format == "nhwc"
42

43

44
def _compare_nchw_nhwc(data, func, is_symbolic=None):
    """Run ``func`` on two tensors built from ``data`` and compare the outputs.

    One tensor is created without a format argument and the other with
    ``format="nhwc"``. Both results must agree to 5 decimal places.

    Args:
        data: array-like payload used to build both tensors.
        func: callable taking one tensor and returning something
            ``np.testing.assert_almost_equal`` can compare.
        is_symbolic: when not ``None``, ``func`` is wrapped in
            ``trace(func, symbolic=is_symbolic)`` before being called.
    """
    plain_input = tensor(data)
    nhwc_input = tensor(data, format="nhwc")
    runner = func if is_symbolic is None else trace(func, symbolic=is_symbolic)
    plain_output = runner(plain_input)
    nhwc_output = runner(nhwc_input)
    np.testing.assert_almost_equal(plain_output, nhwc_output, decimal=5)
52 53


54 55
@pytest.mark.parametrize("is_symbolic", [None])
def test_dimshuffle(is_symbolic):
    """Transpose must match across formats and report the default format."""

    def shuffle(x):
        shuffled = F.transpose(x, [2, 3, 0, 1])
        assert shuffled.format == "default"
        return shuffled.numpy()

    source = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(source, shuffle, is_symbolic)
63 64


65 66
@pytest.mark.parametrize("is_symbolic", [None])
def test_reshape(is_symbolic):
    """Reshape keeps the format for a 4-d target and drops it for a 2-d one."""
    source = np.arange(0, 24).reshape((1, 2, 3, 4))

    def reshape_to_4d(x):
        # A 4-d target shape is expected to keep the input's format.
        reshaped = F.reshape(x, (1, 2, 6, 2))
        assert reshaped.format == x.format
        return reshaped.numpy()

    _compare_nchw_nhwc(source, reshape_to_4d, is_symbolic)

    def reshape_to_2d(x):
        # A 2-d target shape is expected to fall back to the default format.
        reshaped = F.reshape(x, (1, 24))
        assert reshaped.format == "default"
        return reshaped.numpy()

    _compare_nchw_nhwc(source, reshape_to_2d, is_symbolic)
83 84


85 86
@pytest.mark.parametrize("is_symbolic", [None])
def test_flatten(is_symbolic):
    """``F.flatten`` must yield the same values for both input formats."""

    def flatten_to_numpy(x):
        flat = F.flatten(x)
        return flat.numpy()

    source = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(source, flatten_to_numpy, is_symbolic)
92 93


94 95
@pytest.mark.parametrize("is_symbolic", [None])
def test_broadcast(is_symbolic):
    """Broadcast keeps the format for a 4-d target and drops it for a 5-d one."""
    source = np.arange(0, 24).reshape((4, 3, 2, 1))

    def broadcast_to_4d(x):
        # A 4-d target shape is expected to keep the input's format.
        expanded = F.broadcast_to(x, (4, 3, 2, 3))
        assert expanded.format == x.format
        return expanded.numpy()

    _compare_nchw_nhwc(source, broadcast_to_4d, is_symbolic)

    def broadcast_to_5d(x):
        # A 5-d target shape is expected to fall back to the default format.
        expanded = F.broadcast_to(x, (3, 4, 3, 2, 1))
        assert expanded.format == "default"
        return expanded.numpy()

    _compare_nchw_nhwc(source, broadcast_to_5d, is_symbolic)
112 113 114


@pytest.mark.skip("repeat cannot maintain format yet")
@pytest.mark.parametrize("is_symbolic", [None])
def test_repeat(is_symbolic):
    """``F.repeat`` along axis 1 should keep the input's format (skipped for now)."""

    def repeat_channels(x):
        repeated = F.repeat(x, 3, axis=1)
        assert repeated.format == x.format
        return repeated.numpy()

    source = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(source, repeat_channels, is_symbolic)
124 125


126 127
@pytest.mark.parametrize("is_symbolic", [None])
def test_getshape(is_symbolic):
    """The reported shape must be the same for both input formats."""

    def read_shape(x):
        # With symbolic shapes enabled the shape is converted via ``.numpy()``
        # so that it can be compared numerically.
        return x.shape.numpy() if use_symbolic_shape() else x.shape

    source = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(source, read_shape, is_symbolic)
136 137 138


@pytest.mark.skip("symbolic shape is not supported yet")
@pytest.mark.parametrize("is_symbolic", [None])
def test_get_symbolic_shape(is_symbolic):
    """Compare symbolic shapes of both input formats (skipped for now).

    ``is_symbolic`` is supplied through ``parametrize`` like in the sibling
    tests; without it pytest would look for a fixture of that name and error
    out as soon as the skip marker is removed.
    """
    from megengine.core._trace_option import set_symbolic_shape

    def func(x):
        return x.shape.numpy()

    data = np.arange(0, 24).reshape((1, 2, 3, 4))

    # The return value is used as the option to restore afterwards
    # (presumably the previous setting -- TODO confirm).
    origin_opt = set_symbolic_shape(True)
    try:
        _compare_nchw_nhwc(data, func, is_symbolic)
    finally:
        # Restore the global option even when the comparison fails, so a
        # failure here cannot leak symbolic-shape mode into other tests.
        set_symbolic_shape(origin_opt)


152 153
@pytest.mark.parametrize("is_symbolic", [None])
def test_getvalue(is_symbolic):
    """``Tensor.numpy()`` must return the same values for both input formats."""

    def to_numpy(x):
        values = x.numpy()
        return values

    source = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(source, to_numpy, is_symbolic)
159 160


161 162
@pytest.mark.parametrize("is_symbolic", [None])
def test_get_set_subtensor(is_symbolic):
    """Slice reads and slice writes must agree across both input formats."""
    source = np.arange(0, 24).reshape((1, 2, 3, 4))

    def read_slice(x):
        window = x[:, :1, :2, :3]
        return window.numpy()

    _compare_nchw_nhwc(source, read_slice, is_symbolic)

    def write_slice(x):
        # Writes into the tensor passed in, then returns its full contents.
        x[:, :1, :2, :3] = 0
        return x.numpy()

    _compare_nchw_nhwc(source, write_slice, is_symbolic)
174 175


176 177
@pytest.mark.parametrize("is_symbolic", [None])
def test_get_set_advanced_indexing(is_symbolic):
    """Indexing with tensor slice bounds and index lists must agree across formats."""
    source = np.arange(0, 24).reshape((1, 2, 3, 4))

    def read_indexed(x):
        # Slice stops are tensors; the last axis is indexed with a list.
        return x[:, : mge.tensor(2), : mge.tensor(2), [1, 2]].numpy()

    _compare_nchw_nhwc(source, read_indexed, is_symbolic)

    def write_indexed(x):
        # Same idea for assignment; one stop is a 1-element tensor.
        x[:, : mge.tensor(2), : mge.tensor([2]), [1]] = 0
        return x.numpy()

    _compare_nchw_nhwc(source, write_indexed, is_symbolic)
190 191


192 193
@pytest.mark.parametrize("is_symbolic", [None])
def test_typecvt(is_symbolic):
    """Casting to float16 must give the same values for both input formats."""

    def cast_to_half(x):
        converted = x.astype("float16")
        return converted.numpy()

    source = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(source, cast_to_half, is_symbolic)
199 200


201 202
@pytest.mark.parametrize("is_symbolic", [None])
def test_elemwise(is_symbolic):
    """Element-wise arithmetic must keep the input's format and values."""

    def arithmetic(x):
        ones = F.ones((1, 2, 3, 4))
        combined = x * ones + x / 2
        assert combined.format == x.format
        return combined.numpy()

    source = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(source, arithmetic, is_symbolic)
211 212


213 214
@pytest.mark.parametrize("is_symbolic", [None])
def test_concat(is_symbolic):
    """Concatenating along axis 1 must keep the input's format and values."""

    def concat_with_ones(x):
        ones = F.ones((1, 2, 3, 4))
        joined = F.concat([x / 2, ones], axis=1)
        assert joined.format == x.format
        return joined.numpy()

    source = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(source, concat_with_ones, is_symbolic)
223 224 225 226 227


@pytest.mark.parametrize("mode", ["bilinear", "nearest"])
@pytest.mark.parametrize("is_symbolic", [None])
def test_interpolate(mode, is_symbolic):
    """Upscaling by 3 must keep the input's format for each tested mode."""

    def upscale(x):
        scaled = F.vision.interpolate(x, scale_factor=3, mode=mode)
        assert scaled.format == x.format
        return scaled.numpy()

    source = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32")
    _compare_nchw_nhwc(source, upscale, is_symbolic)
237 238


239 240 241 242 243 244 245 246 247 248 249 250 251
@pytest.mark.skip("not implemented")
@pytest.mark.parametrize("is_symbolic", [None])
def test_warp_perspective(is_symbolic):
    """Warp with a random 3x3 matrix in NHWC mode (skipped for now)."""

    def warp(x):
        # Random matrix, reshaped to (1, 3, 3) before being passed on.
        matrix = tensor(np.random.randn(3, 3), dtype=np.float32).reshape((1, 3, 3))
        warped = F.vision.warp_perspective(x, matrix, (2, 2), format="NHWC")
        return warped.numpy()

    source = np.arange(0, 48).reshape((1, 3, 4, 4)).astype("float32")
    _compare_nchw_nhwc(source, warp, is_symbolic)


252 253
@pytest.mark.parametrize("is_symbolic", [None])
def test_conv2d(is_symbolic):
    """1x1 convolution with all-ones weight and bias must agree across formats."""

    def conv2d(x):
        # Inputs not tagged NHWC take the plain ``F.ones`` weight/bias path.
        if x.format != "nhwc":
            return F.conv2d(x, F.ones((3, 2, 1, 1)), F.ones((1, 3, 1, 1))).numpy()

        # NHWC input: weight and bias carry the NHWC tag too, and the output
        # is expected to stay NHWC.
        out = F.conv2d(
            x,
            weight=mge.tensor(np.ones((3, 2, 1, 1)), format="nhwc"),
            bias=mge.tensor(np.ones((1, 3, 1, 1)), format="nhwc"),
        )
        assert out.format == "nhwc"
        return out.numpy()

    source = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(source, conv2d, is_symbolic)
268 269


270 271
@pytest.mark.parametrize("is_symbolic", [None])
def test_group_conv2d(is_symbolic):
    """Grouped (groups=2) 1x1 convolution must agree across both formats."""

    def conv2d(x):
        # Inputs not tagged NHWC take the plain ``F.ones`` weight/bias path.
        if x.format != "nhwc":
            return F.conv2d(
                x, F.ones((2, 2, 2, 1, 1)), F.ones((1, 4, 1, 1)), groups=2
            ).numpy()

        # NHWC input: weight and bias carry the NHWC tag too, and the output
        # is expected to stay NHWC.
        out = F.conv2d(
            x,
            weight=mge.tensor(np.ones((2, 2, 2, 1, 1)), format="nhwc"),
            bias=mge.tensor(np.ones((1, 4, 1, 1)), format="nhwc"),
            groups=2,
        )
        assert out.format == "nhwc"
        return out.numpy()

    source = np.arange(0, 48).reshape((1, 4, 3, 4))
    _compare_nchw_nhwc(source, conv2d, is_symbolic)
289 290


291 292
@pytest.mark.parametrize("is_symbolic", [None])
def test_bn(is_symbolic):
    """Training-mode batch norm must agree across formats and keep NHWC tags."""

    def func(x):
        # Inputs not tagged NHWC use untagged statistics and parameters.
        if x.format != "nhwc":
            return F.batch_norm(
                x.astype("float32"),
                running_mean=mge.tensor(np.ones((1, 2, 1, 1))),
                running_var=mge.tensor(np.ones((1, 2, 1, 1))),
                weight=mge.tensor(np.ones((1, 2, 1, 1))),
                bias=mge.tensor(np.ones((1, 2, 1, 1))),
                training=True,
                inplace=False,
            )[0].numpy()

        # NHWC input: every statistic/parameter tensor is tagged NHWC as well.
        oups = F.batch_norm(
            x.astype("float32"),
            running_mean=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
            running_var=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
            weight=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
            bias=mge.tensor(np.ones((1, 2, 1, 1)), format="nhwc"),
            training=True,
            inplace=False,
        )
        # The first three outputs are each expected to be tagged NHWC.
        for idx, message in (
            (0, "y's format is wrong"),
            (1, "running_mean's format is wrong"),
            (2, "running_var's format is wrong"),
        ):
            assert oups[idx].format == "nhwc", message
        return oups[0].numpy()

    source = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(source, func, is_symbolic)
321 322 323 324 325 326


@pytest.mark.parametrize(
    "pooling",
    [F.max_pool2d, F.avg_pool2d, F.adaptive_avg_pool2d, F.adaptive_max_pool2d],
)
@pytest.mark.parametrize("is_symbolic", [None])
def test_pooling2d(pooling, is_symbolic):
    """Each pooling op must agree across formats and keep an NHWC tag."""

    def func(x):
        expect_nhwc = x.format == "nhwc"
        pooled = pooling(x.astype("float32"), 2)
        # Only the NHWC-tagged input has its output format checked.
        if expect_nhwc:
            assert pooled.format == "nhwc"
        return pooled.numpy()

    source = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(source, func, is_symbolic)
339 340


341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369
@pytest.mark.skip("not implemented")
def test_where():
    """``F.where`` on formatted input should return a default-format tensor.

    Skipped for now. The comparison runs untraced because no ``is_symbolic``
    value is passed to ``_compare_nchw_nhwc``.
    """

    def func(x):
        # ``np.bool_`` replaces the ``np.bool`` alias, which newer NumPy
        # releases no longer provide.
        mask = tensor(
            np.array([True, False, False, True] * 6, dtype=np.bool_).reshape(
                (1, 2, 3, 4)
            )
        )
        # The fallback operand deliberately contains inf and nan entries.
        y = tensor(
            np.array([1, np.inf, np.nan, 4] * 6, dtype=np.float32).reshape((1, 2, 3, 4))
        )
        out = F.where(mask, x, y)
        assert out.format == "default"
        return out.numpy()

    data = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(data, func)


def test_unsupported_op():
    """Constant-mode ``F.nn.pad`` is expected to return a default-format tensor."""

    def pad_constant(x):
        padded = F.nn.pad(x, pad_width=((1, 1),), mode="constant")
        assert padded.format == "default"
        return padded.numpy()

    source = np.arange(0, 24).reshape((1, 2, 3, 4))
    _compare_nchw_nhwc(source, pad_constant)


370 371 372
def _compare_backward(inps, model, is_symbolic=None):
    """Compare parameter gradients before and after format conversion.

    The model is run forward and backward under ``mge.amp.autocast()`` twice:
    once as given, and once after the inputs went through
    ``mge.amp.convert_tensor_format`` and the model through
    ``mge.amp.convert_module_format``. The gradients of the attached
    parameters from both runs must agree to 5 decimal places.

    Args:
        inps: sequence of input tensors, passed to ``model`` positionally.
        model: module under test; its ``parameters()`` are attached to a
            ``GradManager`` for each run.
        is_symbolic: when not ``None``, the forward call is wrapped in
            ``trace(..., symbolic=is_symbolic)``.

    Raises:
        AssertionError: if a parameter has no gradient after a run, if the
            two runs produce different numbers of gradients, or if the
            gradient values differ.
    """

    def func(*inps):
        # ``model`` is looked up at call time, so the second run below uses
        # whatever ``model`` has been rebound to by then.
        return model(*inps)

    def collect_grads(manager, stage):
        # Check for a missing gradient *before* calling ``.numpy()`` so the
        # failure is a readable assertion instead of an AttributeError.
        grads = []
        for index, param in enumerate(manager.attached_tensors()):
            assert param.grad is not None, "{} run: parameter #{} has no grad".format(
                stage, index
            )
            grads.append(param.grad.numpy())
        return grads

    if is_symbolic is not None:
        func = trace(func, symbolic=is_symbolic)

    gm = GradManager().attach(model.parameters())
    with gm:
        with mge.amp.autocast():
            rst = func(*inps)
            gm.backward(rst)
    expected_grads = collect_grads(gm, "reference")

    # Drop the gradients of the first run before starting the second one
    # (matters if conversion reuses the same parameter objects -- TODO confirm).
    for param in gm.attached_tensors():
        param.grad = None

    inps = [mge.amp.convert_tensor_format(inp) for inp in inps]
    model = mge.amp.convert_module_format(model)
    gm = GradManager().attach(model.parameters())
    with gm:
        with mge.amp.autocast():
            rst = func(*inps)
            gm.backward(rst)
    actual_grads = collect_grads(gm, "converted")

    # ``zip`` would silently ignore extra entries, so compare lengths first.
    assert len(expected_grads) == len(actual_grads), (
        "gradient count differs between runs: {} vs {}".format(
            len(expected_grads), len(actual_grads)
        )
    )
    for expected, actual in zip(expected_grads, actual_grads):
        np.testing.assert_almost_equal(expected, actual, decimal=5)


@pytest.mark.parametrize("is_symbolic", [None])
def test_backward_basic(is_symbolic):
    """Gradients of a single matmul-plus-bias layer must survive conversion."""

    class AffineNet(M.Module):
        """Computes ``inp @ w + b`` with fixed initial parameters."""

        def __init__(self):
            super().__init__()
            self.w = mge.Parameter([[2.0], [4.0], [6.0]])
            self.b = mge.Parameter(-1.0)

        def forward(self, features):
            projected = F.matmul(features, self.w)
            return projected + self.b

    sample = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
    _compare_backward([sample], AffineNet(), is_symbolic)
415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436


@pytest.mark.parametrize("is_symbolic", [None])
def test_backward_conv2d_dimshuffle(is_symbolic):
    """Gradients of conv followed by a manual transpose must survive conversion."""

    class ConvTransposeNet(M.Module):
        """1x1 convolution whose output is transposed and reshaped by hand."""

        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(2, 3, 1)

        def forward(self, image):
            # test manually convert to NHWC, usually used in detection head
            features = self.conv(image)
            return F.transpose(features, (0, 2, 3, 1)).reshape(1, 18, 2)

    sample = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
    _compare_backward([sample], ConvTransposeNet(), is_symbolic)


@pytest.mark.parametrize("is_symbolic", [None])
def test_backward_groupconv2d_bn(is_symbolic):
    """Gradients of two grouped convs plus batch norm must survive conversion."""

    class GroupConvBnNet(M.Module):
        """Two stride-2 grouped 3x3 convolutions followed by ``BatchNorm2d``."""

        def __init__(self):
            super().__init__()
            self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2)
            self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2)
            self.bn = M.BatchNorm2d(2048)

        def forward(self, image):
            hidden = self.conv0(image)
            hidden = self.conv1(hidden)
            return self.bn(hidden)

    sample = mge.tensor(np.ones(shape=(32, 32, 56, 56)).astype("float32"))
    _compare_backward([sample], GroupConvBnNet(), is_symbolic)