norm.py 2.6 KB
Newer Older
M
MegEngine Team 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
# -*- coding: utf-8 -*-
# Copyright 2019 - present, Facebook, Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ---------------------------------------------------------------------
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (C) 2014-2019 Megvii Inc. All rights reserved.
# ---------------------------------------------------------------------
import megengine.module as M
import numpy as np

from megengine.core import Buffer


class FrozenBatchNorm2d(M.Module):
    """
    2D batch normalization whose statistics and affine terms are fixed.

    ``weight``, ``bias``, ``running_mean`` and ``running_var`` are all stored
    as ``Buffer`` objects and are only read in :meth:`forward`, so the layer
    reduces to a per-channel ``x * scale + bias``.
    """

    def __init__(self, num_features, eps=1e-5):
        """
        Args:
            num_features (int): number of channels the buffers are sized for.
            eps (float): constant added to ``running_var`` before the
                square root is taken in :meth:`forward`.
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps

        # Affine terms are stored flat and reshaped on use; the defaults
        # (ones / zeros) make the affine part an identity transform.
        self.weight = Buffer(np.ones(num_features, dtype=np.float32))
        self.bias = Buffer(np.zeros(num_features, dtype=np.float32))

        # Statistics are stored already shaped (1, C, 1, 1).
        stats_shape = (1, num_features, 1, 1)
        self.running_mean = Buffer(np.zeros(stats_shape, dtype=np.float32))
        self.running_var = Buffer(np.ones(stats_shape, dtype=np.float32))

    def forward(self, x):
        """
        Normalize ``x`` with the stored statistics and affine terms.

        Computes ``x * scale + shift`` where
        ``scale = weight / sqrt(running_var + eps)`` and
        ``shift = bias - running_mean * scale``.
        Assumes ``x`` is laid out as (N, C, H, W) so the (1, C, 1, 1) terms
        broadcast over it -- TODO confirm against callers.
        """
        gamma = self.weight.reshape(1, -1, 1, 1)
        inv_std = 1.0 / (self.running_var + self.eps).sqrt()
        scale = gamma * inv_std

        beta = self.bias.reshape(1, -1, 1, 1)
        shift = beta - self.running_mean * scale
        return x * scale + shift


def get_norm(norm, out_channels=None):
    """
    Resolve a normalization spec into a layer class or a layer instance.

    Args:
        norm (str, callable or None): a supported name ("BN" or "FrozenBN"),
            or a callable that builds a layer from a channel count.
            ``None`` and the empty string both mean "no normalization".
        out_channels (int, optional): if given, the resolved callable is
            invoked with it and the result is returned; otherwise the
            callable itself is returned.

    Returns:
        M.Module or None: the normalization layer (or the layer class when
        ``out_channels`` is omitted), or ``None`` when normalization is
        disabled.

    Raises:
        KeyError: if ``norm`` is a non-empty string that is not a supported name.
    """
    # Treat None like the empty string instead of trying to call it below.
    if norm is None:
        return None

    if isinstance(norm, str):
        if not norm:
            return None
        supported = {"BN": M.BatchNorm2d, "FrozenBN": FrozenBatchNorm2d}
        try:
            norm = supported[norm]
        except KeyError:
            # Same exception type as before, but say what would have worked.
            raise KeyError(
                "unsupported norm {!r}; expected one of {}".format(
                    norm, sorted(supported)
                )
            ) from None

    if out_channels is None:
        return norm
    return norm(out_channels)