From 4fb3ab78820965c1e0c0d265f8dc255ffd9de51a Mon Sep 17 00:00:00 2001 From: chengxianbin Date: Fri, 29 May 2020 00:18:02 +0800 Subject: [PATCH] modify ssd script for merging backbone --- mindspore/model_zoo/ssd.py | 146 +++++++++++++++++++++++++++++++++---- 1 file changed, 133 insertions(+), 13 deletions(-) diff --git a/mindspore/model_zoo/ssd.py b/mindspore/model_zoo/ssd.py index b92e8457d..b69942cd5 100644 --- a/mindspore/model_zoo/ssd.py +++ b/mindspore/model_zoo/ssd.py @@ -24,7 +24,8 @@ from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops import composite as C from mindspore.common.initializer import initializer -from .mobilenet import InvertedResidual, ConvBNReLU +from mindspore.ops.operations import TensorAdd +from mindspore import Parameter def _conv2d(in_channel, out_channel, kernel_size=3, stride=1, pad_mod='same'): @@ -45,6 +46,129 @@ def _make_divisible(v, divisor, min_value=None): return new_v +class DepthwiseConv(nn.Cell): + """ + Depthwise Convolution warpper definition. + + Args: + in_planes (int): Input channel. + kernel_size (int): Input kernel size. + stride (int): Stride size. + pad_mode (str): pad mode in (pad, same, valid) + channel_multiplier (int): Output channel multiplier + has_bias (bool): has bias or not + + Returns: + Tensor, output tensor. + + Examples: + >>> DepthwiseConv(16, 3, 1, 'pad', 1, channel_multiplier=1) + """ + def __init__(self, in_planes, kernel_size, stride, pad_mode, pad, channel_multiplier=1, has_bias=False): + super(DepthwiseConv, self).__init__() + self.has_bias = has_bias + self.in_channels = in_planes + self.channel_multiplier = channel_multiplier + self.out_channels = in_planes * channel_multiplier + self.kernel_size = (kernel_size, kernel_size) + self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=channel_multiplier, + kernel_size=self.kernel_size, + stride=stride, pad_mode=pad_mode, pad=pad) + self.bias_add = P.BiasAdd() + weight_shape = [channel_multiplier, in_planes, *self.kernel_size] + self.weight = Parameter(initializer('ones', weight_shape), name='weight') + + if has_bias: + bias_shape = [channel_multiplier * in_planes] + self.bias = Parameter(initializer('zeros', bias_shape), name='bias') + else: + self.bias = None + + def construct(self, x): + output = self.depthwise_conv(x, self.weight) + if self.has_bias: + output = self.bias_add(output, self.bias) + return output + + +class ConvBNReLU(nn.Cell): + """ + Convolution/Depthwise fused with Batchnorm and ReLU block definition. + + Args: + in_planes (int): Input channel. + out_planes (int): Output channel. + kernel_size (int): Input kernel size. + stride (int): Stride size for the first convolutional layer. Default: 1. + groups (int): channel group. Convolution is 1 while Depthiwse is input channel. Default: 1. + + Returns: + Tensor, output tensor. + + Examples: + >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1) + """ + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + super(ConvBNReLU, self).__init__() + padding = (kernel_size - 1) // 2 + if groups == 1: + conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', + padding=padding) + else: + conv = DepthwiseConv(in_planes, kernel_size, stride, pad_mode='pad', pad=padding) + layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()] + self.features = nn.SequentialCell(layers) + + def construct(self, x): + output = self.features(x) + return output + + +class InvertedResidual(nn.Cell): + """ + Mobilenetv2 residual block definition. + + Args: + inp (int): Input channel. + oup (int): Output channel. + stride (int): Stride size for the first convolutional layer. Default: 1. + expand_ratio (int): expand ration of input channel + + Returns: + Tensor, output tensor. + + Examples: + >>> ResidualBlock(3, 256, 1, 1) + """ + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, kernel_size=1, stride=1, has_bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.SequentialCell(layers) + self.add = TensorAdd() + self.cast = P.Cast() + + def construct(self, x): + identity = x + x = self.conv(x) + if self.use_res_connect: + return self.add(identity, x) + return x + + class FlattenConcat(nn.Cell): """ Concatenate predictions into a single tensor. @@ -57,20 +181,17 @@ class FlattenConcat(nn.Cell): """ def __init__(self, config): super(FlattenConcat, self).__init__() - self.sizes = config.FEATURE_SIZE - self.length = len(self.sizes) - self.num_default = config.NUM_DEFAULT - self.concat = P.Concat(axis=-1) + self.num_ssd_boxes = config.NUM_SSD_BOXES + self.concat = P.Concat(axis=1) self.transpose = P.Transpose() - def construct(self, x): + def construct(self, inputs): output = () - for i in range(self.length): - shape = F.shape(x[i]) - mid_shape = (shape[0], -1, self.num_default[i], self.sizes[i], self.sizes[i]) - final_shape = (shape[0], -1, self.num_default[i] * self.sizes[i] * self.sizes[i]) - output += (F.reshape(F.reshape(x[i], mid_shape), final_shape),) + batch_size = F.shape(inputs[0])[0] + for x in inputs: + x = self.transpose(x, (0, 2, 3, 1)) + output += (F.reshape(x, (batch_size, -1)),) res = self.concat(output) - return self.transpose(res, (0, 2, 1)) + return F.reshape(res, (batch_size, self.num_ssd_boxes, -1)) class MultiBox(nn.Cell): @@ -145,7 +266,6 @@ class SSD300(nn.Cell): if not is_training: self.softmax = P.Softmax() - def construct(self, x): layer_out_13, output = self.backbone(x) multi_feature = (layer_out_13, output) -- GitLab