"""shufflenetv2 in pytorch [1] Ningning Ma, Xiangyu Zhang, Hai-Tao Zheng, Jian Sun ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design https://arxiv.org/abs/1807.11164 """ import torch import torch.nn as nn import torch.nn.functional as F def channel_split(x, split): """split a tensor into two pieces along channel dimension Args: x: input tensor split:(int) channel size for each pieces """ assert x.size(1) == split * 2 return torch.split(x, split, dim=1) def channel_shuffle(x, groups): """channel shuffle operation Args: x: input tensor groups: input branch number """ batch_size, channels, height, width = x.size() channels_per_group = int(channels // groups) x = x.view(batch_size, groups, channels_per_group, height, width) x = x.transpose(1, 2).contiguous() x = x.view(batch_size, -1, height, width) return x class ShuffleUnit(nn.Module): def __init__(self, in_channels, out_channels, stride): super().__init__() self.stride = stride self.in_channels = in_channels self.out_channels = out_channels if stride != 1 or in_channels != out_channels: self.residual = nn.Sequential( nn.Conv2d(in_channels, in_channels, 1), nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True), nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels), nn.BatchNorm2d(in_channels), nn.Conv2d(in_channels, int(out_channels / 2), 1), nn.BatchNorm2d(int(out_channels / 2)), nn.ReLU(inplace=True) ) self.shortcut = nn.Sequential( nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels), nn.BatchNorm2d(in_channels), nn.Conv2d(in_channels, int(out_channels / 2), 1), nn.BatchNorm2d(int(out_channels / 2)), nn.ReLU(inplace=True) ) else: self.shortcut = nn.Sequential() in_channels = int(in_channels / 2) self.residual = nn.Sequential( nn.Conv2d(in_channels, in_channels, 1), nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True), nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels), nn.BatchNorm2d(in_channels), nn.Conv2d(in_channels, in_channels, 1), nn.BatchNorm2d(in_channels), nn.ReLU(inplace=True) ) def forward(self, x): if self.stride == 1 and self.out_channels == self.in_channels: shortcut, residual = channel_split(x, int(self.in_channels / 2)) else: shortcut = x residual = x shortcut = self.shortcut(shortcut) residual = self.residual(residual) x = torch.cat([shortcut, residual], dim=1) x = channel_shuffle(x, 2) return x class ShuffleNetV2(nn.Module): def __init__(self, ratio=1., class_num=100, dropout_factor = 1.0): super().__init__() if ratio == 0.5: out_channels = [48, 96, 192, 1024] elif ratio == 1: out_channels = [116, 232, 464, 1024] elif ratio == 1.5: out_channels = [176, 352, 704, 1024] elif ratio == 2: out_channels = [244, 488, 976, 2048] else: ValueError('unsupported ratio number') self.pre = nn.Sequential( nn.Conv2d(3, 24, 3, padding=1), nn.BatchNorm2d(24) ) self.stage2 = self._make_stage(24, out_channels[0], 3) self.stage3 = self._make_stage(out_channels[0], out_channels[1], 7) self.stage4 = self._make_stage(out_channels[1], out_channels[2], 3) self.conv5 = nn.Sequential( nn.Conv2d(out_channels[2], out_channels[3], 1), nn.BatchNorm2d(out_channels[3]), nn.ReLU(inplace=True) ) self.fc = nn.Linear(out_channels[3], class_num) self.dropout = nn.Dropout(dropout_factor) def forward(self, x): x = self.pre(x) x = self.stage2(x) x = self.stage3(x) x = self.stage4(x) x = self.conv5(x) x = F.adaptive_avg_pool2d(x, 1) x = x.view(x.size(0), -1) x = self.dropout(x) x = self.fc(x) return x def _make_stage(self, in_channels, out_channels, repeat): layers = [] layers.append(ShuffleUnit(in_channels, out_channels, 2)) while repeat: layers.append(ShuffleUnit(out_channels, out_channels, 1)) repeat -= 1 return nn.Sequential(*layers) def shufflenetv2(): return ShuffleNetV2()