# normalize.py -- registers the converter callbacks for the ``Normalize`` op.
# Author (per commit history): SunAhong1993
from .register import register
from x2paddle.core.util import *


def normalize_shape(input_shape):
    """Return the output shape(s) of the Normalize layer.

    The layer neither adds nor removes dimensions, so the argument is
    handed back unchanged (the very same object, not a copy).
    """
    output_shape = input_shape
    return output_shape


def normalize_layer(inputs,
                    across_spatial=None,
                    channel_shared=None,
                    input_shape=None,
                    name=None):
    """Build the ``Normalize`` layer: L2-normalize along axis 1, then scale.

    Args:
        inputs: list of input variables; only ``inputs[0]`` is used.
        across_spatial: must compare equal to ``False``. Normalizing across
            the spatial extent is not implemented, so any other value,
            including the default ``None``, is rejected.
        channel_shared: if truthy, a single scale parameter of shape ``[1]``
            is created and broadcast over the whole tensor; otherwise the
            scale parameter has one entry per element of dimension 1.
        input_shape: list of input shapes; ``input_shape[0][1]`` is the size
            of the scale parameter when ``channel_shared`` is falsy.
        name: string prefix for the created ops and the ``<name>_scale``
            parameter. Must be a ``str`` (``None`` fails on concatenation).

    Returns:
        The L2-normalized input multiplied by the scale parameter.

    Raises:
        AssertionError: if ``across_spatial`` is not ``False``.
    """
    # Raised explicitly instead of via ``assert`` so the check survives
    # ``python -O``; AssertionError and its message are kept unchanged for
    # existing callers. ``!= False`` deliberately also rejects ``None``.
    if across_spatial != False:  # noqa: E712
        raise AssertionError(
            "Only support across_spatial == False for Normalize")

    x = inputs[0]
    # Normalize along axis 1 (presumably the channel axis of an NCHW
    # tensor -- TODO confirm against the front end that calls this).
    l2_norm = fluid.layers.l2_normalize(x, axis=1, name=name + '_l2')

    # Either one shared scale, or one scale per element of dimension 1.
    scale_shape = [1] if channel_shared else [input_shape[0][1]]
    scale_param = fluid.layers.create_parameter(
        shape=scale_shape,
        dtype=x.dtype,
        attr=name + '_scale')
    # NOTE(review): this reshapes to the shape the parameter was created
    # with; kept from the original in case the converter relies on the op.
    scale_param = fluid.layers.reshape(x=scale_param, shape=scale_shape)

    # axis=-1 broadcasts the single shared scale over the whole tensor;
    # axis=1 aligns the per-entry scale with dimension 1 of ``l2_norm``.
    out = fluid.layers.elementwise_mul(x=l2_norm,
                                       y=scale_param,
                                       axis=-1 if channel_shared else 1)
    return out

def normalize_weights(name, data=None):
    """List the parameter names that belong to the Normalize layer *name*.

    Only the ``<name>_scale`` parameter is reported; *data* is not used.
    """
    scale_weight = name + '_scale'
    return [scale_weight]


# Register this module's layer, shape and weights callbacks for 'Normalize'.
register(kind='Normalize',
         layer=normalize_layer,
         shape=normalize_shape,
         weights=normalize_weights)