# File: DBHead.py (4.8 KB), committed by zhoujun.
# -*- coding: utf-8 -*-
# @Time    : 2019/12/4 14:54
# @Author  : zhoujun
import paddle
from paddle import nn, ParamAttr


class DBHead(nn.Layer):
    """Prediction head producing shrink, threshold and (when training) binary maps.

    NOTE(review): the name and ``step_function`` suggest this is the
    "Differentiable Binarization" text-detection head -- confirm against the
    surrounding project.

    Two parallel branches (``binarize`` and ``thresh``) consume the same input
    feature map. Each branch is: 3x3 conv -> BN -> ReLU -> 2x upsample -> BN ->
    ReLU -> 2x upsample to a single channel -> Sigmoid.

    Args:
        in_channels: Number of channels of the input feature map. The
            intermediate layers use ``in_channels // 4`` channels.
        out_channels: Accepted for interface compatibility; not used by this
            class.
        k: Steepness factor used by :meth:`step_function`.
    """

    def __init__(self, in_channels, out_channels, k=50):
        super().__init__()
        self.k = k
        mid_channels = in_channels // 4

        # Layer order is kept exactly as before so parameter names
        # (binarize.0, binarize.1, ...) stay the same.
        self.binarize = nn.Sequential(
            nn.Conv2D(
                in_channels,
                mid_channels,
                3,
                padding=1,
                weight_attr=self._kaiming_weight_attr()),
            self._make_bn(mid_channels),
            nn.ReLU(),
            nn.Conv2DTranspose(
                mid_channels,
                mid_channels,
                2,
                2,
                weight_attr=self._kaiming_weight_attr()),
            self._make_bn(mid_channels),
            nn.ReLU(),
            # Wrapped in ParamAttr for consistency with every other layer.
            nn.Conv2DTranspose(
                mid_channels,
                1,
                2,
                2,
                weight_attr=self._kaiming_weight_attr()),
            nn.Sigmoid())

        self.thresh = self._init_thresh(in_channels)

    @staticmethod
    def _kaiming_weight_attr():
        """Return a fresh ``ParamAttr`` using the KaimingNormal initializer."""
        return ParamAttr(initializer=nn.initializer.KaimingNormal())

    @staticmethod
    def _make_bn(channels):
        """Return a BatchNorm2D with weight initialised to 1 and bias to 1e-4."""
        return nn.BatchNorm2D(
            channels,
            weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
            bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)))

    def forward(self, x):
        """Run both branches on ``x`` and concatenate their outputs on axis 1.

        Returns:
            In training mode: ``concat(shrink, threshold, binary)``.
            Otherwise: ``concat(shrink, threshold)``.
        """
        shrink_maps = self.binarize(x)
        threshold_maps = self.thresh(x)
        if self.training:
            binary_maps = self.step_function(shrink_maps, threshold_maps)
            return paddle.concat(
                (shrink_maps, threshold_maps, binary_maps), axis=1)
        return paddle.concat((shrink_maps, threshold_maps), axis=1)

    def _init_thresh(self,
                     inner_channels,
                     serial=False,
                     smooth=False,
                     bias=False):
        """Build and return the threshold branch.

        Args:
            inner_channels: Channel count the branch is sized from; the
                intermediate layers use ``inner_channels // 4``.
            serial: If true, the first conv expects one extra input channel.
            smooth: Forwarded to :meth:`_init_upsample` to pick the upsampling
                style.
            bias: ``bias_attr`` for the first conv; also forwarded to
                :meth:`_init_upsample`.

        Returns:
            An ``nn.Sequential`` ending in a Sigmoid. The caller is responsible
            for assigning it to an attribute.
        """
        in_channels = inner_channels + 1 if serial else inner_channels
        mid_channels = inner_channels // 4
        return nn.Sequential(
            nn.Conv2D(
                in_channels,
                mid_channels,
                3,
                padding=1,
                bias_attr=bias,
                weight_attr=self._kaiming_weight_attr()),
            self._make_bn(mid_channels),
            nn.ReLU(),
            self._init_upsample(
                mid_channels, mid_channels, smooth=smooth, bias=bias),
            self._make_bn(mid_channels),
            nn.ReLU(),
            self._init_upsample(mid_channels, 1, smooth=smooth, bias=bias),
            nn.Sigmoid())

    def _init_upsample(self,
                       in_channels,
                       out_channels,
                       smooth=False,
                       bias=False):
        """Build a 2x upsampling module.

        Args:
            in_channels: Input channel count.
            out_channels: Output channel count.
            smooth: If true, use nearest-neighbour upsampling followed by a
                3x3 conv (plus a 1x1 projection when ``out_channels == 1``);
                otherwise use a stride-2 transposed conv.
            bias: ``bias_attr`` of the 3x3 conv in the smooth variant. It is
                not applied to the transposed conv, which keeps its default
                bias.

        Returns:
            An ``nn.Layer`` that doubles the spatial resolution.
        """
        if not smooth:
            return nn.Conv2DTranspose(
                in_channels,
                out_channels,
                2,
                2,
                weight_attr=self._kaiming_weight_attr())

        to_single_channel = out_channels == 1
        # When projecting to one channel, the 3x3 conv keeps the channel count
        # and a trailing 1x1 conv performs the reduction.
        inter_out_channels = in_channels if to_single_channel else out_channels
        module_list = [
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2D(
                in_channels,
                inter_out_channels,
                3,
                1,
                1,
                bias_attr=bias,
                weight_attr=self._kaiming_weight_attr()),
        ]
        if to_single_channel:
            module_list.append(
                nn.Conv2D(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    # A 1x1 kernel needs no padding; padding=1 would enlarge
                    # the map by 2 pixels per dimension and no longer match
                    # the shrink map in forward().
                    padding=0,
                    bias_attr=True,
                    weight_attr=self._kaiming_weight_attr()))
        # Sequential expects layers as positional arguments, so the list must
        # be unpacked.
        return nn.Sequential(*module_list)

    def step_function(self, x, y):
        """Return ``1 / (1 + exp(-k * (x - y)))`` element-wise.

        This is a sigmoid of ``k * (x - y)``; larger ``k`` makes the
        transition around ``x == y`` steeper.
        """
        return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))