# normalize.py (1.1 KB) -- blame: lines 1-17 committed by SunAhong1993
from .register import register
from x2paddle.core.util import *


def normalize_shape(input_shape):
    """Infer the output shape of the custom Normalize layer.

    The layer's ops (an L2 normalisation followed by a broadcast multiply
    by a scale) keep the input's shape, so the given shape list is handed
    back unchanged -- the very same object, not a copy.
    """
    output_shape = input_shape
    return output_shape


def normalize_layer(inputs,
                    across_spatial=None,
                    channel_shared=None,
                    input_shape=None,
                    name=None):
    """Build the Paddle Fluid ops for a Caffe ``Normalize`` layer.

    The input is L2-normalised along axis 1 (presumably the channel axis
    of an NCHW tensor; ``input_shape[0][1]`` is read as its size). The
    result is then multiplied by a learnable scale parameter named
    ``name + '_scale'``.

    Args:
        inputs: Sequence whose first element is the input variable.
        across_spatial: Must be ``False``. Normalising across spatial
            positions is not supported, and ``None`` (unset) is rejected
            as well.
        channel_shared: If truthy, one scalar scale is shared by every
            channel. Otherwise there is one scale value per channel.
        input_shape: List of input shapes. ``input_shape[0][1]`` is used as
            the channel count when ``channel_shared`` is falsy.
        name: Layer name (a str) used to derive the op and parameter names.

    Returns:
        The scaled, L2-normalised output variable.

    Raises:
        NotImplementedError: If ``across_spatial`` is not ``False``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently emit the wrong normalisation.
    # For bool/None inputs this accepts exactly what the old
    # ``across_spatial == False`` accepted.
    if across_spatial is None or across_spatial:
        raise NotImplementedError(
            "Only support across_spatial == False for Normalize")

    data = inputs[0]
    l2_norm = fluid.layers.l2_normalize(data, axis=1, name=name + '_l2')

    if channel_shared:
        # A single scalar scale, broadcast over the whole tensor.
        param_shape = [1]
        scale_shape = [1]
        mul_axis = -1
    else:
        channels = input_shape[0][1]
        # NOTE(review): the parameter is created as [1, 1, 1, C] and then
        # reshaped to [C]. Presumably the 4-D shape matches how the scale
        # weights are stored -- confirm before changing it.
        param_shape = [1, 1, 1, channels]
        scale_shape = [channels]
        # Align the [C] scale with dimension 1 of ``l2_norm``.
        mul_axis = 1

    scale_param = fluid.layers.create_parameter(
        shape=param_shape, dtype=data.dtype, attr=name + '_scale')
    scale_param = fluid.layers.reshape(x=scale_param, shape=scale_shape)
    out = fluid.layers.elementwise_mul(x=l2_norm, y=scale_param, axis=mul_axis)
    return out

# blame: line 27 committed by SunAhong1993

# blame: lines 28-32 committed by SunAhong1993
def normalize_weights(name, data=None):
    """Return the list of weight names owned by the Normalize layer ``name``.

    Only the learnable scale parameter, ``name + '_scale'``, is listed.
    ``data`` is accepted but not used here.
    """
    scale_name = name + '_scale'
    return [scale_name]


# blame: lines 33-37 committed by jiangjiajun
# Module-level side effect: hand the three Normalize helpers defined above
# (shape inference, Fluid layer construction, weight naming) to `register`
# from .register under kind 'Normalize'.
register(
    kind='Normalize',
    shape=normalize_shape,
    layer=normalize_layer,
    weights=normalize_weights)