#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import functools
import paddle.nn as nn
from .nn import Spectralnorm


class Identity(nn.Layer):
    """A pass-through layer that returns its input unchanged (used when norm_type is 'none')."""
    def forward(self, x):
        return x


def build_norm_layer(norm_type='instance'):
    """Return a normalization layer

    Args:
        norm_type (str) -- the name of the normalization layer: batch | instance | spectral | none

    For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
    For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
    """
    if norm_type == 'batch':
        norm_layer = functools.partial(
            nn.BatchNorm,
            param_attr=paddle.ParamAttr(
                initializer=nn.initializer.Normal(1.0, 0.02)),
            bias_attr=paddle.ParamAttr(
                initializer=nn.initializer.Constant(0.0)),
            trainable_statistics=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(
            nn.InstanceNorm2D,
            weight_attr=paddle.ParamAttr(
                initializer=nn.initializer.Constant(1.0),
                learning_rate=0.0,
                trainable=False),
            bias_attr=paddle.ParamAttr(initializer=nn.initializer.Constant(0.0),
                                       learning_rate=0.0,
                                       trainable=False))
    elif norm_type == 'spectral':
        norm_layer = functools.partial(Spectralnorm)
    elif norm_type == 'none':

        # Keep the same call signature as the other branches: the argument
        # (channel count) is ignored and an Identity layer is returned.
        def norm_layer(x):
            return Identity()
    else:
        raise NotImplementedError('normalization layer [%s] is not found' %
                                  norm_type)
    return norm_layer
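

# Usage sketch (illustrative, not part of the original module). build_norm_layer
# returns a layer class (or factory) that the calling network instantiates with
# the channel count; 64 below is an arbitrary example value:
#
#   norm_cls = build_norm_layer('instance')
#   norm = norm_cls(64)                      # nn.InstanceNorm2D over 64 channels
#   y = norm(paddle.rand([1, 64, 32, 32]))   # normalized per sample and channel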