# norm.py
import functools

import paddle
import paddle.nn as nn
from .nn import Spectralnorm


class Identity(nn.Layer):
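    """A pass-through layer, used when norm_type is 'none'."""
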
    def forward(self, x):
        return x


def build_norm_layer(norm_type='instance'):
    """Return a normalization layer

    Args:
        norm_type (str) -- the name of the normalization layer: batch | instance | spectral | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    """
    if norm_type == 'batch':
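        # Learnable affine parameters: weight ~ Normal(mean=1.0, std=0.02),
        # bias = 0. trainable_statistics=True recomputes batch statistics
        # even in eval mode.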
        norm_layer = functools.partial(
            nn.BatchNorm,
            param_attr=paddle.ParamAttr(
                initializer=nn.initializer.Normal(1.0, 0.02)),
            bias_attr=paddle.ParamAttr(
                initializer=nn.initializer.Constant(0.0)),
            trainable_statistics=True)
    elif norm_type == 'instance':
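        # Affine scale/shift are created but frozen at 1.0/0.0
        # (learning_rate=0.0, trainable=False), i.e. effectively no learnable
        # affine parameters, matching the docstring.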
        norm_layer = functools.partial(
            nn.InstanceNorm,
            param_attr=paddle.ParamAttr(
                initializer=nn.initializer.Constant(1.0),
                learning_rate=0.0,
                trainable=False),
            bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(0.0),
                                       learning_rate=0.0,
                                       trainable=False))
    elif norm_type == 'spectral':
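        # functools.partial with no bound arguments simply defers construction
        # of Spectralnorm, keeping the same factory convention as the other
        # branches.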
        norm_layer = functools.partial(Spectralnorm)
    elif norm_type == 'none':
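        # The argument (typically a channel count) is accepted and ignored so
        # the call signature matches the other branches.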

        def norm_layer(x):
            return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not supported' %
                                  norm_type)
    return norm_layer
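
# Usage sketch (illustrative, not part of the original module); the channel
# count 64 is a hypothetical example value:
#
#   norm_layer = build_norm_layer('instance')
#   norm = norm_layer(64)                     # frozen-affine InstanceNorm
#   bn = build_norm_layer('batch')(64)        # BatchNorm with learnable affine
#   identity = build_norm_layer('none')(64)   # Identity; argument is ignored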