diff --git a/ppocr/modeling/necks/db_fpn.py b/ppocr/modeling/necks/db_fpn.py index 47a258e2a821366ef977251d47eb05bfeb96b45e..1c81be4a6246287ea8d03acb6efe75b71a73f466 100644 --- a/ppocr/modeling/necks/db_fpn.py +++ b/ppocr/modeling/necks/db_fpn.py @@ -39,7 +39,8 @@ class DSConv(nn.Layer): stride=1, groups=None, if_act=True, - act="relu"): + act="relu", + **kwargs): super(DSConv, self).__init__() if groups == None: groups = in_channels @@ -263,7 +264,7 @@ class CAFPN(nn.Layer): class FEPAN(nn.Layer): - def __init__(self, in_channels, out_channels, **kwargs): + def __init__(self, in_channels, out_channels, mode='large', **kwargs): super(FEPAN, self).__init__() self.out_channels = out_channels weight_attr = paddle.nn.initializer.KaimingUniform() @@ -274,6 +275,15 @@ class FEPAN(nn.Layer): self.pan_head_conv = nn.LayerList() self.pan_lat_conv = nn.LayerList() + if mode.lower() == 'lite': + p_layer = DSConv + elif mode.lower() == 'large': + p_layer = nn.Conv2D + else: + raise ValueError( + "mode can only be one of ['lite', 'large'], but received {}". + format(mode)) + for i in range(len(in_channels)): self.ins_conv.append( nn.Conv2D( @@ -284,7 +294,7 @@ class FEPAN(nn.Layer): bias_attr=False)) self.inp_conv.append( - nn.Conv2D( + p_layer( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=9, @@ -303,7 +313,7 @@ class FEPAN(nn.Layer): weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)) self.pan_lat_conv.append( - nn.Conv2D( + p_layer( in_channels=self.out_channels // 4, out_channels=self.out_channels // 4, kernel_size=9, @@ -346,86 +356,3 @@ class FEPAN(nn.Layer): fuse = paddle.concat([p5, p4, p3, p2], axis=1) return fuse - - -class FEPANLite(nn.Layer): - def __init__(self, in_channels, out_channels, **kwargs): - super(FEPANLite, self).__init__() - self.out_channels = out_channels - weight_attr = paddle.nn.initializer.KaimingUniform() - - self.ins_conv = nn.LayerList() - self.inp_conv = nn.LayerList() - # pan head - self.pan_head_conv = nn.LayerList() - self.pan_lat_conv = nn.LayerList() - - for i in range(len(in_channels)): - self.ins_conv.append( - nn.Conv2D( - in_channels=in_channels[i], - out_channels=self.out_channels, - kernel_size=1, - weight_attr=ParamAttr(initializer=weight_attr), - bias_attr=False)) - - self.inp_conv.append( - DSConv( - in_channels=self.out_channels, - out_channels=self.out_channels // 4, - kernel_size=9, - padding=4)) - - if i > 0: - self.pan_head_conv.append( - nn.Conv2D( - in_channels=self.out_channels // 4, - out_channels=self.out_channels // 4, - kernel_size=3, - padding=1, - stride=2, - weight_attr=ParamAttr(initializer=weight_attr), - bias_attr=False)) - - self.pan_lat_conv.append( - DSConv( - in_channels=self.out_channels // 4, - out_channels=self.out_channels // 4, - kernel_size=9, - padding=4)) - - def forward(self, x): - c2, c3, c4, c5 = x - - in5 = self.ins_conv[3](c5) - in4 = self.ins_conv[2](c4) - in3 = self.ins_conv[1](c3) - in2 = self.ins_conv[0](c2) - - out4 = in4 + F.upsample( - in5, scale_factor=2, mode="nearest", align_mode=1) # 1/16 - out3 = in3 + F.upsample( - out4, scale_factor=2, mode="nearest", align_mode=1) # 1/8 - out2 = in2 + F.upsample( - out3, scale_factor=2, mode="nearest", align_mode=1) # 1/4 - - f5 = self.inp_conv[3](in5) - f4 = self.inp_conv[2](out4) - f3 = self.inp_conv[1](out3) - f2 = self.inp_conv[0](out2) - - pan3 = f3 + self.pan_head_conv[0](f2) - pan4 = f4 + self.pan_head_conv[1](pan3) - pan5 = f5 + self.pan_head_conv[2](pan4) - - p2 = self.pan_lat_conv[0](f2) - p3 = self.pan_lat_conv[1](pan3) - p4 = self.pan_lat_conv[2](pan4) - p5 = self.pan_lat_conv[3](pan5) - - p5 = F.upsample(p5, scale_factor=8, mode="nearest", align_mode=1) - p4 = F.upsample(p4, scale_factor=4, mode="nearest", align_mode=1) - p3 = F.upsample(p3, scale_factor=2, mode="nearest", align_mode=1) - - fuse = paddle.concat([p5, p4, p3, p2], axis=1) - return fuse