# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine
import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import BatchNorm2d


def test_frozen_bn():
    """A frozen BatchNorm2d must not update anything during training.

    Runs one forward/backward/step cycle and checks that the running
    statistics and the affine parameters are all byte-identical to
    their initial values, and that the output is the identity
    transform of the input (zero mean, unit var, identity affine).
    """
    nchannel = 3
    m = BatchNorm2d(nchannel, freeze=True)

    # Snapshot the initial state to compare against after training.
    saved_var = m.running_var.numpy()
    saved_mean = m.running_mean.numpy()
    saved_wt = m.weight.numpy()
    saved_bias = m.bias.numpy()

    gm = ad.GradManager().attach(m.parameters())
    optim = optimizer.SGD(m.parameters(), lr=1.0)
    optim.clear_grad()

    data = np.random.random((6, nchannel, 2, 2)).astype("float32")
    # NOTE(review): wrap in tensor() — the module expects a Tensor input
    # (`tensor` is imported above for this) and `loss.numpy()` below
    # requires a Tensor result.
    with gm:
        loss = m(tensor(data)).mean()
        gm.backward(loss)
    optim.step()

    # Frozen BN: neither running stats nor affine params may change,
    # even with lr=1.0 after a full optimizer step.
    np.testing.assert_equal(m.running_var.numpy(), saved_var)
    np.testing.assert_equal(m.running_mean.numpy(), saved_mean)
    np.testing.assert_equal(m.weight.numpy(), saved_wt)
    np.testing.assert_equal(m.bias.numpy(), saved_bias)
    # With frozen identity statistics the layer is a no-op, so the
    # mean of the output equals the mean of the raw input.
    np.testing.assert_almost_equal(loss.numpy(), data.mean(), 5)


def test_bn_no_track_stat():
    """BatchNorm2d constructed with ``track_running_stats=False`` must
    train without error — it keeps no running statistics to update.

    This is a smoke test: one forward/backward/step cycle completes.
    """
    nchannel = 3
    m = BatchNorm2d(nchannel, track_running_stats=False)

    gm = ad.GradManager().attach(m.parameters())
    optim = optimizer.SGD(m.parameters(), lr=1.0)
    optim.clear_grad()

    data = np.random.random((6, nchannel, 2, 2)).astype("float32")
    # NOTE(review): wrap in tensor() — the module expects a Tensor
    # input (`tensor` is imported above for this).
    with gm:
        loss = m(tensor(data)).sum()
        gm.backward(loss)
    optim.step()


def test_bn_no_track_stat2():
    """Disabling ``track_running_stats`` after construction must leave
    the already-allocated running statistics untouched by training.

    The module is built with tracking enabled (so ``running_mean`` /
    ``running_var`` exist), then tracking is switched off; a full
    forward/backward/step cycle must not modify those buffers.
    """
    nchannel = 3
    m = BatchNorm2d(nchannel)  # Init with track_running_stats = True
    m.track_running_stats = False

    # m.running_var and m.running_mean created during init time
    saved_var = m.running_var.numpy()
    assert saved_var is not None
    saved_mean = m.running_mean.numpy()
    assert saved_mean is not None

    gm = ad.GradManager().attach(m.parameters())
    optim = optimizer.SGD(m.parameters(), lr=1.0)
    optim.clear_grad()

    data = np.random.random((6, nchannel, 2, 2)).astype("float32")
    # NOTE(review): wrap in tensor() — the module expects a Tensor
    # input (`tensor` is imported above for this).
    with gm:
        loss = m(tensor(data)).sum()
        gm.backward(loss)
    optim.step()

    # Stats buffers still exist but must be frozen at their init values.
    np.testing.assert_equal(m.running_var.numpy(), saved_var)
    np.testing.assert_equal(m.running_mean.numpy(), saved_mean)


def test_bn_no_track_stat3():
    """Turning ``track_running_stats`` back on after constructing the
    module with it disabled is unsupported: the stats buffers were
    never allocated, so the forward pass must raise."""
    num_channels = 3
    bn = BatchNorm2d(num_channels, track_running_stats=False)
    bn.track_running_stats = True
    inp = np.random.random((6, num_channels, 2, 2)).astype("float32")
    with pytest.raises(Exception):
        bn(inp)