gc_block.py 4.3 KB
Newer Older
littletomatodonkey's avatar
littletomatodonkey 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import paddle.fluid as fluid
from paddle.fluid import ParamAttr
from paddle.fluid.initializer import ConstantInitializer


def spatial_pool(x, pooling_type, name):
    """Collapse the spatial dimensions of ``x`` into a per-channel context.

    Args:
        x (Variable): 4-D input feature map with channels on axis 1
            (presumably NCHW -- TODO confirm against callers).
        pooling_type (str): 'att' selects attention pooling; any other
            value falls back to global average pooling.
        name (str): prefix for the attention conv parameters
            (``<name>_weights`` and ``<name>_bias``).

    Returns:
        Variable: context tensor of shape [N, C, 1, 1].
    """
    # Unpack all four dims, not just the channel count, so a non-4-D input
    # still fails fast here in both branches.
    _, channel, _, _ = x.shape

    if pooling_type != 'att':
        # Global average pooling: [N, C, 1, 1]
        return fluid.layers.pool2d(
            input=x, pool_type='avg', global_pooling=True)

    layers = fluid.layers
    # View the features as [N, 1, C, H * W] so they can be batch-multiplied
    # against the per-position attention weights computed below.
    feats = layers.reshape(x, shape=(0, 1, channel, -1))

    # A 1x1 conv with a single filter yields one attention logit per position.
    logits = layers.conv2d(
        input=x,
        num_filters=1,
        filter_size=1,
        stride=1,
        padding=0,
        param_attr=ParamAttr(name=name + "_weights"),
        bias_attr=ParamAttr(name=name + "_bias"))
    logits = layers.reshape(logits, shape=(0, 0, -1))  # [N, 1, H * W]
    attn = layers.softmax(logits, axis=2)  # normalize over positions
    attn = layers.reshape(attn, shape=(0, 0, -1, 1))  # [N, 1, H * W, 1]

    # Attention-weighted sum over positions: [N, 1, C, 1]
    pooled = layers.matmul(feats, attn)
    return layers.reshape(pooled, shape=(0, channel, 1, 1))  # [N, C, 1, 1]


def channel_conv(input, inner_ch, out_ch, name):
    """Bottleneck transform: 1x1 conv -> LayerNorm + ReLU -> 1x1 conv.

    Args:
        input (Variable): pooled context tensor.
        inner_ch (int): number of filters in the first (reducing) conv.
        out_ch (int): number of filters in the second (restoring) conv.
        name (str): prefix for every parameter and layer name created here.

    Returns:
        Variable: output of the second conv. Its weights and bias are
        zero-initialized, so the result is all zeros at initialization.
    """

    def _attr(suffix, zero_init=False):
        # Build a named ParamAttr, optionally with a constant-zero initializer.
        if zero_init:
            return ParamAttr(
                name=name + suffix,
                initializer=ConstantInitializer(value=0.0), )
        return ParamAttr(name=name + suffix)

    reduced = fluid.layers.conv2d(
        input=input,
        num_filters=inner_ch,
        filter_size=1,
        stride=1,
        padding=0,
        param_attr=_attr("_conv1_weights"),
        bias_attr=_attr("_conv1_bias"),
        name=name + "_conv1", )

    # begin_norm_axis=1: normalize over every axis from 1 onward.
    normed = fluid.layers.layer_norm(
        reduced,
        begin_norm_axis=1,
        param_attr=_attr("_ln_weights"),
        bias_attr=_attr("_ln_bias"),
        act="relu",
        name=name + "_ln")

    return fluid.layers.conv2d(
        input=normed,
        num_filters=out_ch,
        filter_size=1,
        stride=1,
        padding=0,
        param_attr=_attr("_conv2_weights", zero_init=True),
        bias_attr=_attr("_conv2_bias", zero_init=True),
        name=name + "_conv2")


def add_gc_block(x,
                 ratio=1.0 / 16,
                 pooling_type='att',
                 fusion_types=('channel_add', ),
                 name=None):
    '''
    GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond, see https://arxiv.org/abs/1904.11492

    Args:
        x (Variable): input feature map; channel count is read from axis 1.
        ratio (float): channel reduction ratio of the bottleneck transform.
            The reduced width is clamped to at least 1 channel.
        pooling_type (str): pooling type, 'att' or 'avg'.
        fusion_types (list|tuple): non-empty sequence drawn from
            'channel_add' and 'channel_mul'. When both are present,
            channel_mul is applied first, then channel_add.
        name (str|None): prefix for all parameter names of this block. If
            None, a unique prefix is generated.

    Returns:
        Variable: ``x`` fused with the global context, same shape as ``x``.
    '''
    assert pooling_type in ('avg', 'att'), \
        "pooling_type must be 'avg' or 'att', got {!r}".format(pooling_type)
    assert isinstance(fusion_types, (list, tuple)), \
        'fusion_types must be a list or tuple, got {!r}'.format(fusion_types)
    valid_fusion_types = ('channel_add', 'channel_mul')
    assert all(f in valid_fusion_types for f in fusion_types), \
        'unsupported fusion type in {!r}'.format(fusion_types)
    assert len(fusion_types) > 0, 'at least one fusion should be used'

    # The original code concatenated onto ``name`` unconditionally, so the
    # documented default of None crashed; generate a unique prefix instead.
    if name is None:
        name = fluid.unique_name.generate("gc_block")

    out_ch = x.shape[1]
    # Guard against a zero-width bottleneck when out_ch * ratio < 1.
    inner_ch = max(1, int(ratio * out_ch))

    context = spatial_pool(x, pooling_type, name + "_spatial_pool")
    out = x
    if 'channel_mul' in fusion_types:
        inner_out = channel_conv(context, inner_ch, out_ch, name + "_mul")
        # Sigmoid gate in (0, 1), broadcast over the spatial dims of x.
        channel_mul_term = fluid.layers.sigmoid(inner_out)
        out = out * channel_mul_term

    if 'channel_add' in fusion_types:
        channel_add_term = channel_conv(context, inner_ch, out_ch,
                                        name + "_add")
        out = out + channel_add_term

    return out