From bdad0cefe69cfe41cebe040112a293f738eb526e Mon Sep 17 00:00:00 2001
From: WenmuZhou
Date: Fri, 16 Oct 2020 20:19:17 +0800
Subject: [PATCH] Use the ResNet from paddleclass for the detection and
 recognition backbones
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ppocr/modeling/backbones/det_resnet_vd.py | 428 ++++++++++------------
 ppocr/modeling/backbones/rec_resnet_vd.py | 427 ++++++++++-----------
 2 files changed, 387 insertions(+), 468 deletions(-)

diff --git a/ppocr/modeling/backbones/det_resnet_vd.py b/ppocr/modeling/backbones/det_resnet_vd.py
index b501bec8..6fa52716 100644
--- a/ppocr/modeling/backbones/det_resnet_vd.py
+++ b/ppocr/modeling/backbones/det_resnet_vd.py
@@ -16,143 +16,30 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from paddle import nn
-from paddle.nn import functional as F
+import paddle
 from paddle import ParamAttr
+import paddle.nn as nn
 
 __all__ = ["ResNet"]
 
 
-class ResNet(nn.Layer):
-    def __init__(self, in_channels=3, layers=50, **kwargs):
-        """
-        the Resnet backbone network for detection module.
-        Args:
-            params(dict): the super parameters for network build
-        """
-        super(ResNet, self).__init__()
-        supported_layers = {
-            18: {
-                'depth': [2, 2, 2, 2],
-                'block_class': BasicBlock
-            },
-            34: {
-                'depth': [3, 4, 6, 3],
-                'block_class': BasicBlock
-            },
-            50: {
-                'depth': [3, 4, 6, 3],
-                'block_class': BottleneckBlock
-            },
-            101: {
-                'depth': [3, 4, 23, 3],
-                'block_class': BottleneckBlock
-            },
-            152: {
-                'depth': [3, 8, 36, 3],
-                'block_class': BottleneckBlock
-            },
-            200: {
-                'depth': [3, 12, 48, 3],
-                'block_class': BottleneckBlock
-            }
-        }
-        assert layers in supported_layers, \
-            "supported layers are {} but input layer is {}".format(supported_layers.keys(), layers)
-        is_3x3 = True
-
-        depth = supported_layers[layers]['depth']
-        block_class = supported_layers[layers]['block_class']
-
-        num_filters = [64, 128, 256, 512]
-
-        conv = []
-        if is_3x3 == False:
-            conv.append(
-                ConvBNLayer(
-                    in_channels=in_channels,
-                    out_channels=64,
-                    kernel_size=7,
-                    stride=2,
-                    act='relu'))
-        else:
-            conv.append(
-                ConvBNLayer(
-                    in_channels=3,
-                    out_channels=32,
-                    kernel_size=3,
-                    stride=2,
-                    act='relu',
-                    name='conv1_1'))
-            conv.append(
-                ConvBNLayer(
-                    in_channels=32,
-                    out_channels=32,
-                    kernel_size=3,
-                    stride=1,
-                    act='relu',
-                    name='conv1_2'))
-            conv.append(
-                ConvBNLayer(
-                    in_channels=32,
-                    out_channels=64,
-                    kernel_size=3,
-                    stride=1,
-                    act='relu',
-                    name='conv1_3'))
-        self.conv1 = nn.Sequential(*conv)
-        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.stages = []
-        self.out_channels = []
-        in_ch = 64
-        for block_index in range(len(depth)):
-            block_list = []
-            for i in range(depth[block_index]):
-                if layers >= 50:
-                    if layers in [101, 152, 200] and block_index == 2:
-                        if i == 0:
-                            conv_name = "res" + str(block_index + 2) + "a"
-                        else:
-                            conv_name = "res" + str(block_index +
-                                                    2) + "b" + str(i)
-                    else:
-                        conv_name = "res" + str(block_index + 2) + chr(97 + i)
-                else:
-                    conv_name = "res" + str(block_index + 2) + chr(97 + i)
-                block_list.append(
-                    block_class(
-                        in_channels=in_ch,
-                        out_channels=num_filters[block_index],
-                        stride=2 if i == 0 and block_index != 0 else 1,
-                        if_first=block_index == i == 0,
-                        name=conv_name))
-                in_ch = block_list[-1].out_channels
-            self.out_channels.append(in_ch)
-            self.stages.append(nn.Sequential(*block_list))
-        for i, stage in enumerate(self.stages):
-
self.add_sublayer(sublayer=stage, name="stage{}".format(i)) - - def forward(self, x): - x = self.conv1(x) - x = self.pool(x) - out_list = [] - for stage in self.stages: - x = stage(x) - out_list.append(x) - return out_list - - class ConvBNLayer(nn.Layer): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - groups=1, - act=None, - name=None): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_vd_mode=False, + act=None, + name=None, ): super(ConvBNLayer, self).__init__() - self.conv = nn.Conv2d( + + self.is_vd_mode = is_vd_mode + self._pool2d_avg = nn.AvgPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True) + self._conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -165,87 +52,32 @@ class ConvBNLayer(nn.Layer): bn_name = "bn_" + name else: bn_name = "bn" + name[3:] - self.bn = nn.BatchNorm( - num_channels=out_channels, + self._batch_norm = nn.BatchNorm( + out_channels, act=act, - param_attr=ParamAttr(name=bn_name + "_scale"), - bias_attr=ParamAttr(name=bn_name + "_offset"), - moving_mean_name=bn_name + "_mean", - moving_variance_name=bn_name + "_variance") + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') - def __call__(self, x): - x = self.conv(x) - x = self.bn(x) - return x + def forward(self, inputs): + if self.is_vd_mode: + inputs = self._pool2d_avg(inputs) + y = self._conv(inputs) + y = self._batch_norm(y) + return y -class ConvBNLayerNew(nn.Layer): +class BottleneckBlock(nn.Layer): def __init__(self, in_channels, out_channels, - kernel_size, - stride=1, - groups=1, - act=None, + stride, + shortcut=True, + if_first=False, name=None): - super(ConvBNLayerNew, self).__init__() - self.pool = nn.AvgPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True) - - self.conv = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - groups=groups, - weight_attr=ParamAttr(name=name + "_weights"), - bias_attr=False) - if name == "conv1": - bn_name = "bn_" + name - else: - bn_name = "bn" + name[3:] - self.bn = nn.BatchNorm( - num_channels=out_channels, - act=act, - param_attr=ParamAttr(name=bn_name + "_scale"), - bias_attr=ParamAttr(name=bn_name + "_offset"), - moving_mean_name=bn_name + "_mean", - moving_variance_name=bn_name + "_variance") - - def __call__(self, x): - x = self.pool(x) - x = self.conv(x) - x = self.bn(x) - return x - - -class ShortCut(nn.Layer): - def __init__(self, in_channels, out_channels, stride, name, if_first=False): - super(ShortCut, self).__init__() - self.use_conv = True - if in_channels != out_channels or stride != 1: - if if_first: - self.conv = ConvBNLayer( - in_channels, out_channels, 1, stride, name=name) - else: - self.conv = ConvBNLayerNew( - in_channels, out_channels, 1, stride, name=name) - elif if_first: - self.conv = ConvBNLayer( - in_channels, out_channels, 1, stride, name=name) - else: - self.use_conv = False - - def forward(self, x): - if self.use_conv: - x = self.conv(x) - return x - - -class BottleneckBlock(nn.Layer): - def __init__(self, in_channels, out_channels, stride, name, if_first): super(BottleneckBlock, self).__init__() + self.conv0 = ConvBNLayer( in_channels=in_channels, out_channels=out_channels, @@ -266,32 +98,46 @@ class BottleneckBlock(nn.Layer): act=None, name=name + "_branch2c") - self.short = 
ShortCut( - in_channels=in_channels, - out_channels=out_channels * 4, - stride=stride, - if_first=if_first, - name=name + "_branch1") - self.out_channels = out_channels * 4 + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels * 4, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + conv2 = self.conv2(conv1) - def forward(self, x): - y = self.conv0(x) - y = self.conv1(y) - y = self.conv2(y) - y = y + self.short(x) - y = F.relu(y) + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.elementwise_add(x=short, y=conv2, act='relu') return y class BasicBlock(nn.Layer): - def __init__(self, in_channels, out_channels, stride, name, if_first): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): super(BasicBlock, self).__init__() + self.stride = stride self.conv0 = ConvBNLayer( in_channels=in_channels, out_channels=out_channels, kernel_size=3, - act='relu', stride=stride, + act='relu', name=name + "_branch2a") self.conv1 = ConvBNLayer( in_channels=out_channels, @@ -299,31 +145,133 @@ class BasicBlock(nn.Layer): kernel_size=3, act=None, name=name + "_branch2b") - self.short = ShortCut( - in_channels=in_channels, - out_channels=out_channels, - stride=stride, - if_first=if_first, - name=name + "_branch1") - self.out_channels = out_channels - def forward(self, x): - y = self.conv0(x) - y = self.conv1(y) - y = y + self.short(x) - return F.relu(y) + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=1, + is_vd_mode=False if if_first else True, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.elementwise_add(x=short, y=conv1, act='relu') + return y -if __name__ == '__main__': - import paddle - paddle.disable_static() - x = paddle.zeros([1, 3, 640, 640]) - x = paddle.to_variable(x) - print(x.shape) - net = ResNet(layers=18) - y = net(x) +class ResNet(nn.Layer): + def __init__(self, in_channels=3, layers=50, **kwargs): + super(ResNet, self).__init__() + + self.layers = layers + supported_layers = [18, 34, 50, 101, 152, 200] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format( + supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + elif layers == 200: + depth = [3, 12, 48, 3] + num_channels = [64, 256, 512, + 1024] if layers >= 50 else [64, 64, 128, 256] + num_filters = [64, 128, 256, 512] + + self.conv1_1 = ConvBNLayer( + in_channels=in_channels, + out_channels=32, + kernel_size=3, + stride=2, + act='relu', + name="conv1_1") + self.conv1_2 = ConvBNLayer( + in_channels=32, + out_channels=32, + kernel_size=3, + stride=1, + act='relu', + name="conv1_2") + self.conv1_3 = ConvBNLayer( + in_channels=32, + out_channels=64, + kernel_size=3, + stride=1, + act='relu', + name="conv1_3") + self.pool2d_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + self.stages = [] + self.out_channels = [] + if layers >= 50: + for block in range(len(depth)): + block_list = [] + 
shortcut = False + for i in range(depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + bottleneck_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BottleneckBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block] * 4, + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(bottleneck_block) + self.out_channels.append(num_filters[block] * 4) + self.stages.append(nn.Sequential(*block_list)) + else: + for block in range(len(depth)): + block_list = [] + shortcut = False + for i in range(depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + basic_block = self.add_sublayer( + 'bb_%d_%d' % (block, i), + BasicBlock( + in_channels=num_channels[block] + if i == 0 else num_filters[block], + out_channels=num_filters[block], + stride=2 if i == 0 and block != 0 else 1, + shortcut=shortcut, + if_first=block == i == 0, + name=conv_name)) + shortcut = True + block_list.append(basic_block) + self.out_channels.append(num_filters[block]) + self.stages.append(nn.Sequential(*block_list)) - for stage in y: - print(stage.shape) - # paddle.save(net.state_dict(),'1.pth') + def forward(self, inputs): + y = self.conv1_1(inputs) + y = self.conv1_2(y) + y = self.conv1_3(y) + y = self.pool2d_max(y) + out = [] + for block in self.stages: + y = block(y) + out.append(y) + return out diff --git a/ppocr/modeling/backbones/rec_resnet_vd.py b/ppocr/modeling/backbones/rec_resnet_vd.py index d8602498..20b03c3d 100644 --- a/ppocr/modeling/backbones/rec_resnet_vd.py +++ b/ppocr/modeling/backbones/rec_resnet_vd.py @@ -16,144 +16,34 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from paddle import nn, ParamAttr -from paddle.nn import functional as F +import paddle +from paddle import ParamAttr +import paddle.nn as nn __all__ = ["ResNet"] -class ResNet(nn.Layer): - def __init__(self, in_channels=3, layers=34): - super(ResNet, self).__init__() - supported_layers = { - 18: { - 'depth': [2, 2, 2, 2], - 'block_class': BasicBlock - }, - 34: { - 'depth': [3, 4, 6, 3], - 'block_class': BasicBlock - }, - 50: { - 'depth': [3, 4, 6, 3], - 'block_class': BottleneckBlock - }, - 101: { - 'depth': [3, 4, 23, 3], - 'block_class': BottleneckBlock - }, - 152: { - 'depth': [3, 8, 36, 3], - 'block_class': BottleneckBlock - }, - 200: { - 'depth': [3, 12, 48, 3], - 'block_class': BottleneckBlock - } - } - assert layers in supported_layers, \ - "supported layers are {} but input layer is {}".format(supported_layers.keys(), layers) - is_3x3 = True - - num_filters = [64, 128, 256, 512] - depth = supported_layers[layers]['depth'] - block_class = supported_layers[layers]['block_class'] - conv = [] - if is_3x3 == False: - conv.append( - ConvBNLayer( - in_channels=in_channels, - out_channels=64, - kernel_size=7, - stride=1, - act='relu')) - else: - conv.append( - ConvBNLayer( - in_channels=in_channels, - out_channels=32, - kernel_size=3, - stride=1, - act='relu', - name='conv1_1')) - conv.append( - ConvBNLayer( - in_channels=32, - out_channels=32, - kernel_size=3, - stride=1, - act='relu', - name='conv1_2')) - conv.append( - ConvBNLayer( - in_channels=32, - out_channels=64, - kernel_size=3, - stride=1, - act='relu', - name='conv1_3')) - self.conv1 = 
nn.Sequential(*conv) - - self.pool = nn.MaxPool2d( - kernel_size=3, - stride=2, - padding=1, ) - - block_list = [] - in_ch = 64 - for block_index in range(len(depth)): - for i in range(depth[block_index]): - if layers >= 50: - if layers in [101, 152, 200] and block_index == 2: - if i == 0: - conv_name = "res" + str(block_index + 2) + "a" - else: - conv_name = "res" + str(block_index + - 2) + "b" + str(i) - else: - conv_name = "res" + str(block_index + 2) + chr(97 + i) - else: - conv_name = "res" + str(block_index + 2) + chr(97 + i) - if i == 0 and block_index != 0: - stride = (2, 1) - else: - stride = (1, 1) - block_list.append( - block_class( - in_channels=in_ch, - out_channels=num_filters[block_index], - stride=stride, - if_first=block_index == i == 0, - name=conv_name)) - in_ch = block_list[-1].out_channels - self.block_list = nn.Sequential(*block_list) - self.add_sublayer(sublayer=self.block_list, name="block_list") - self.pool_out = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) - self.out_channels = in_ch - - def forward(self, x): - x = self.conv1(x) - x = self.pool(x) - x = self.block_list(x) - x = self.pool_out(x) - return x - - class ConvBNLayer(nn.Layer): - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride=1, - groups=1, - act=None, - name=None): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + is_vd_mode=False, + act=None, + name=None, ): super(ConvBNLayer, self).__init__() - self.conv = nn.Conv2d( + + self.is_vd_mode = is_vd_mode + self._pool2d_avg = nn.AvgPool2d( + kernel_size=stride, stride=stride, padding=0, ceil_mode=True) + self._conv = nn.Conv2d( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, - stride=stride, + stride=1 if is_vd_mode else stride, padding=(kernel_size - 1) // 2, groups=groups, weight_attr=ParamAttr(name=name + "_weights"), @@ -162,88 +52,32 @@ class ConvBNLayer(nn.Layer): bn_name = "bn_" + name else: bn_name = "bn" + name[3:] - self.bn = nn.BatchNorm( - num_channels=out_channels, + self._batch_norm = nn.BatchNorm( + out_channels, act=act, - param_attr=ParamAttr(name=bn_name + "_scale"), - bias_attr=ParamAttr(name=bn_name + "_offset"), - moving_mean_name=bn_name + "_mean", - moving_variance_name=bn_name + "_variance") + param_attr=ParamAttr(name=bn_name + '_scale'), + bias_attr=ParamAttr(bn_name + '_offset'), + moving_mean_name=bn_name + '_mean', + moving_variance_name=bn_name + '_variance') - def __call__(self, x): - x = self.conv(x) - x = self.bn(x) - return x + def forward(self, inputs): + if self.is_vd_mode: + inputs = self._pool2d_avg(inputs) + y = self._conv(inputs) + y = self._batch_norm(y) + return y -class ConvBNLayerNew(nn.Layer): +class BottleneckBlock(nn.Layer): def __init__(self, in_channels, out_channels, - kernel_size, - stride=1, - groups=1, - act=None, + stride, + shortcut=True, + if_first=False, name=None): - super(ConvBNLayerNew, self).__init__() - self.pool = nn.AvgPool2d( - kernel_size=stride, stride=stride, padding=0, ceil_mode=True) - - self.conv = nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=1, - padding=(kernel_size - 1) // 2, - groups=groups, - weight_attr=ParamAttr(name=name + "_weights"), - bias_attr=False) - if name == "conv1": - bn_name = "bn_" + name - else: - bn_name = "bn" + name[3:] - self.bn = nn.BatchNorm( - num_channels=out_channels, - act=act, - param_attr=ParamAttr(name=bn_name + "_scale"), - bias_attr=ParamAttr(name=bn_name + "_offset"), - 
moving_mean_name=bn_name + "_mean", - moving_variance_name=bn_name + "_variance") - - def __call__(self, x): - x = self.pool(x) - x = self.conv(x) - x = self.bn(x) - return x - - -class ShortCut(nn.Layer): - def __init__(self, in_channels, out_channels, stride, name, if_first=False): - super(ShortCut, self).__init__() - self.use_conv = True - - if in_channels != out_channels or stride[0] != 1: - if if_first: - self.conv = ConvBNLayer( - in_channels, out_channels, 1, stride, name=name) - else: - self.conv = ConvBNLayerNew( - in_channels, out_channels, 1, stride, name=name) - elif if_first: - self.conv = ConvBNLayer( - in_channels, out_channels, 1, stride, name=name) - else: - self.use_conv = False - - def forward(self, x): - if self.use_conv: - x = self.conv(x) - return x - - -class BottleneckBlock(nn.Layer): - def __init__(self, in_channels, out_channels, stride, name, if_first): super(BottleneckBlock, self).__init__() + self.conv0 = ConvBNLayer( in_channels=in_channels, out_channels=out_channels, @@ -264,32 +98,47 @@ class BottleneckBlock(nn.Layer): act=None, name=name + "_branch2c") - self.short = ShortCut( - in_channels=in_channels, - out_channels=out_channels * 4, - stride=stride, - if_first=if_first, - name=name + "_branch1") - self.out_channels = out_channels * 4 + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels * 4, + kernel_size=1, + stride=stride, + is_vd_mode=not if_first and stride[0] != 1, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + + conv1 = self.conv1(y) + conv2 = self.conv2(conv1) - def forward(self, x): - y = self.conv0(x) - y = self.conv1(y) - y = self.conv2(y) - y = y + self.short(x) - y = F.relu(y) + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.elementwise_add(x=short, y=conv2, act='relu') return y class BasicBlock(nn.Layer): - def __init__(self, in_channels, out_channels, stride, name, if_first): + def __init__(self, + in_channels, + out_channels, + stride, + shortcut=True, + if_first=False, + name=None): super(BasicBlock, self).__init__() + self.stride = stride self.conv0 = ConvBNLayer( in_channels=in_channels, out_channels=out_channels, kernel_size=3, - act='relu', stride=stride, + act='relu', name=name + "_branch2a") self.conv1 = ConvBNLayer( in_channels=out_channels, @@ -297,16 +146,138 @@ class BasicBlock(nn.Layer): kernel_size=3, act=None, name=name + "_branch2b") - self.short = ShortCut( + + if not shortcut: + self.short = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + stride=stride, + is_vd_mode=not if_first and stride[0] != 1, + name=name + "_branch1") + + self.shortcut = shortcut + + def forward(self, inputs): + y = self.conv0(inputs) + conv1 = self.conv1(y) + + if self.shortcut: + short = inputs + else: + short = self.short(inputs) + y = paddle.elementwise_add(x=short, y=conv1, act='relu') + return y + + +class ResNet(nn.Layer): + def __init__(self, in_channels=3, layers=50, **kwargs): + super(ResNet, self).__init__() + + self.layers = layers + supported_layers = [18, 34, 50, 101, 152, 200] + assert layers in supported_layers, \ + "supported layers are {} but input layer is {}".format( + supported_layers, layers) + + if layers == 18: + depth = [2, 2, 2, 2] + elif layers == 34 or layers == 50: + depth = [3, 4, 6, 3] + elif layers == 101: + depth = [3, 4, 23, 3] + elif layers == 152: + depth = [3, 8, 36, 3] + elif layers == 200: + depth = [3, 12, 48, 3] + 
num_channels = [64, 256, 512,
+                        1024] if layers >= 50 else [64, 64, 128, 256]
+        num_filters = [64, 128, 256, 512]
+
+        self.conv1_1 = ConvBNLayer(
             in_channels=in_channels,
-            out_channels=out_channels,
-            stride=stride,
-            if_first=if_first,
-            name=name + "_branch1")
-        self.out_channels = out_channels
+            out_channels=32,
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name="conv1_1")
+        self.conv1_2 = ConvBNLayer(
+            in_channels=32,
+            out_channels=32,
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name="conv1_2")
+        self.conv1_3 = ConvBNLayer(
+            in_channels=32,
+            out_channels=64,
+            kernel_size=3,
+            stride=1,
+            act='relu',
+            name="conv1_3")
+        self.pool2d_max = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        self.block_list = []
+        if layers >= 50:
+            for block in range(len(depth)):
+                shortcut = False
+                for i in range(depth[block]):
+                    if layers in [101, 152, 200] and block == 2:
+                        if i == 0:
+                            conv_name = "res" + str(block + 2) + "a"
+                        else:
+                            conv_name = "res" + str(block + 2) + "b" + str(i)
+                    else:
+                        conv_name = "res" + str(block + 2) + chr(97 + i)
-    def forward(self, x):
-        y = self.conv0(x)
-        y = self.conv1(y)
-        y = y + self.short(x)
-        return F.relu(y)
+                    if i == 0 and block != 0:
+                        stride = (2, 1)
+                    else:
+                        stride = (1, 1)
+                    bottleneck_block = self.add_sublayer(
+                        'bb_%d_%d' % (block, i),
+                        BottleneckBlock(
+                            in_channels=num_channels[block]
+                            if i == 0 else num_filters[block] * 4,
+                            out_channels=num_filters[block],
+                            stride=stride,
+                            shortcut=shortcut,
+                            if_first=block == i == 0,
+                            name=conv_name))
+                    shortcut = True
+                    self.block_list.append(bottleneck_block)
+                self.out_channels = num_filters[block] * 4
+        else:
+            for block in range(len(depth)):
+                shortcut = False
+                for i in range(depth[block]):
+                    conv_name = "res" + str(block + 2) + chr(97 + i)
+                    if i == 0 and block != 0:
+                        stride = (2, 1)
+                    else:
+                        stride = (1, 1)
+
+                    basic_block = self.add_sublayer(
+                        'bb_%d_%d' % (block, i),
+                        BasicBlock(
+                            in_channels=num_channels[block]
+                            if i == 0 else num_filters[block],
+                            out_channels=num_filters[block],
+                            stride=stride,
+                            shortcut=shortcut,
+                            if_first=block == i == 0,
+                            name=conv_name))
+                    shortcut = True
+                    self.block_list.append(basic_block)
+                self.out_channels = num_filters[block]
+        self.out_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
+
+    def forward(self, inputs):
+        y = self.conv1_1(inputs)
+        y = self.conv1_2(y)
+        y = self.conv1_3(y)
+        y = self.pool2d_max(y)
+        for block in self.block_list:
+            y = block(y)
+        y = self.out_pool(y)
+        return y
-- 
GitLab
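
A quick smoke test for the two refactored backbones, in the spirit of the
__main__ block this patch deletes from det_resnet_vd.py. This is a minimal
sketch, not part of the patch: it assumes the October-2020 Paddle 2.0-rc API
that the new code itself targets (paddle.disable_static(), the lowercase
nn.Conv2d / nn.MaxPool2d spellings) and the ppocr package layout shown in the
file paths above; the shapes in the comments are the values implied by the
strides in the diff.

    import paddle

    from ppocr.modeling.backbones.det_resnet_vd import ResNet as DetResNet
    from ppocr.modeling.backbones.rec_resnet_vd import ResNet as RecResNet

    paddle.disable_static()  # dynamic-graph mode, as in the removed __main__ block

    # Detection backbone: returns one feature map per residual stage, so an
    # FPN-style neck can consume strides 4/8/16/32 directly.
    det_net = DetResNet(in_channels=3, layers=50)
    feats = det_net(paddle.zeros([1, 3, 640, 640]))
    for ch, feat in zip(det_net.out_channels, feats):
        print(ch, feat.shape)
    # -> 256 [1, 256, 160, 160] ... 2048 [1, 2048, 20, 20]

    # Recognition backbone: its (2, 1) strides downsample height only, so a
    # 32-pixel-high crop ends as a single-row feature map for the sequence head.
    rec_net = RecResNet(in_channels=3, layers=34)
    out = rec_net(paddle.zeros([1, 3, 32, 320]))
    print(rec_net.out_channels, out.shape)
    # -> 512 [1, 512, 1, 80]

Note that the sketch is pinned to the patch-era framework: later Paddle
releases rename these layers to nn.Conv2D / nn.MaxPool2D / nn.AvgPool2D and
replace paddle.elementwise_add(..., act=...) with paddle.add plus an explicit
activation, so the backbones as written here will not import unchanged on
Paddle 2.0 final or newer.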