# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
from ppocr.modeling.backbones.det_mobilenet_v3 import ConvBNLayer


def get_bias_attr(k):
    # Uniform bias initializer over [-1/sqrt(k), 1/sqrt(k)], the standard
    # fan-in bound.
    stdv = 1.0 / math.sqrt(k * 1.0)
    initializer = paddle.nn.initializer.Uniform(-stdv, stdv)
    bias_attr = ParamAttr(initializer=initializer)
    return bias_attr


class Head(nn.Layer):
    def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs):
        super(Head, self).__init__()
        self.conv1 = nn.Conv2D(
            in_channels=in_channels,
            out_channels=in_channels // 4,
            kernel_size=kernel_list[0],
            padding=int(kernel_list[0] // 2),
            weight_attr=ParamAttr(),
            bias_attr=False)
        self.conv_bn1 = nn.BatchNorm(
            num_channels=in_channels // 4,
            param_attr=ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1.0)),
            bias_attr=ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1e-4)),
            act='relu')
        self.conv2 = nn.Conv2DTranspose(
            in_channels=in_channels // 4,
            out_channels=in_channels // 4,
            kernel_size=kernel_list[1],
            stride=2,
            weight_attr=ParamAttr(
                initializer=paddle.nn.initializer.KaimingUniform()),
            bias_attr=get_bias_attr(in_channels // 4))
        self.conv_bn2 = nn.BatchNorm(
            num_channels=in_channels // 4,
            param_attr=ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1.0)),
            bias_attr=ParamAttr(
                initializer=paddle.nn.initializer.Constant(value=1e-4)),
            act="relu")
        self.conv3 = nn.Conv2DTranspose(
            in_channels=in_channels // 4,
            out_channels=1,
            kernel_size=kernel_list[2],
            stride=2,
            weight_attr=ParamAttr(
                initializer=paddle.nn.initializer.KaimingUniform()),
            bias_attr=get_bias_attr(in_channels // 4))

    def forward(self, x, return_f=False):
        x = self.conv1(x)
        x = self.conv_bn1(x)
        x = self.conv2(x)
        x = self.conv_bn2(x)
        if return_f:
            f = x
        x = self.conv3(x)
        x = F.sigmoid(x)
        if return_f:
            return x, f
        return x
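

# A hedged usage sketch (added for illustration; not part of the original
# file): Head turns an N x C x H x W feature map into a single-channel
# probability map at 4x resolution, and with return_f=True also exposes the
# intermediate features before the final transposed conv. The 256-channel
# input below is an assumed value, not one taken from a PaddleOCR config.
def _head_shape_check():
    head = Head(in_channels=256)
    x = paddle.rand([1, 256, 160, 160])
    prob = head(x)                    # [1, 1, 640, 640]
    prob, f = head(x, return_f=True)  # f: [1, 64, 320, 320]
    print(prob.shape, f.shape)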


class DBHead(nn.Layer):
    """
    Differentiable Binarization (DB) for text detection:
        see https://arxiv.org/abs/1911.08947
    args:
        params(dict): super parameters for build DB network
    """

    def __init__(self, in_channels, k=50, **kwargs):
        super(DBHead, self).__init__()
        self.k = k
        self.binarize = Head(in_channels, **kwargs)
        self.thresh = Head(in_channels, **kwargs)

    def step_function(self, x, y):
        # Differentiable binarization: sigmoid(k * (x - y)) computed as
        # 1 / (1 + exp(-k * (x - y))); x is the shrink map, y the threshold map.
        return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))

    def forward(self, x, targets=None):
        shrink_maps = self.binarize(x)
        if not self.training:
            return {'maps': shrink_maps}

        threshold_maps = self.thresh(x)
        binary_maps = self.step_function(shrink_maps, threshold_maps)
        y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
        return {'maps': y}
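

# A hedged usage sketch (added for illustration; not part of the original
# file): a shape check for DBHead with an assumed 256-channel dummy input.
def _db_head_shape_check():
    head = DBHead(in_channels=256)
    feats = paddle.rand([1, 256, 160, 160])
    head.eval()
    # inference: only the shrink (probability) map is returned
    print(head(feats)['maps'].shape)  # [1, 1, 640, 640]
    head.train()
    # training: shrink, threshold, and binary maps concatenated for the loss
    print(head(feats)['maps'].shape)  # [1, 3, 640, 640]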


class LocalModule(nn.Layer):
    def __init__(self, in_c, mid_c, use_distance=True):
        super(LocalModule, self).__init__()
        # `use_distance` is accepted for config compatibility but is unused here.
        self.last_3 = ConvBNLayer(in_c + 1, mid_c, 3, 1, 1, act='relu')
        self.last_1 = nn.Conv2D(mid_c, 1, 1, 1, 0)

    def forward(self, x, init_map, distance_map):
        # `distance_map` is unused; the initial shrink map is concatenated
        # with the upsampled features along the channel axis.
        outf = paddle.concat([init_map, x], axis=1)
        # fuse with a 3x3 conv-bn-relu, then project to a single-channel map
        out = self.last_1(self.last_3(outf))
        return out


class CBNHeadLocal(DBHead):
    def __init__(self, in_channels, k=50, mode='small', **kwargs):
        super(CBNHeadLocal, self).__init__(in_channels, k, **kwargs)
        self.mode = mode

        self.up_conv = nn.Upsample(scale_factor=2, mode="nearest", align_mode=1)
        if self.mode == 'large':
            self.cbn_layer = LocalModule(in_channels // 4, in_channels // 4)
        elif self.mode == 'small':
            self.cbn_layer = LocalModule(in_channels // 4, in_channels // 8)

    def forward(self, x, targets=None):
        shrink_maps, f = self.binarize(x, return_f=True)
        base_maps = shrink_maps
        # refine the shrink map using the upsampled intermediate features
        cbn_maps = self.cbn_layer(self.up_conv(f), shrink_maps, None)
        cbn_maps = F.sigmoid(cbn_maps)
        if not self.training:
            # inference: average the base and refined shrink maps
            return {'maps': 0.5 * (base_maps + cbn_maps), 'cbn_maps': cbn_maps}

        threshold_maps = self.thresh(x)
        binary_maps = self.step_function(shrink_maps, threshold_maps)
        y = paddle.concat([cbn_maps, threshold_maps, binary_maps], axis=1)
        return {'maps': y, 'distance_maps': cbn_maps, 'cbn_maps': binary_maps}
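

if __name__ == "__main__":
    # A hedged usage sketch (added for illustration; not part of the original
    # file): exercises CBNHeadLocal on an assumed 256-channel dummy input.
    feats = paddle.rand([1, 256, 160, 160])
    head = CBNHeadLocal(in_channels=256, mode='small')

    head.eval()
    out = head(feats)
    # inference: averaged base/refined map plus the raw refined (CBN) map
    print(out['maps'].shape, out['cbn_maps'].shape)  # [1, 1, 640, 640] each

    head.train()
    out = head(feats)
    # training: concatenated maps consumed by the DB loss
    print(out['maps'].shape)  # [1, 3, 640, 640]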