Commit 963dcb9c authored by gaotingquan, committed by Tingquan Gao

support dbb module for ResNet

1. add the DiverseBranchBlock module;
2. ResNet supports a DBB variant built from DiverseBranchBlock, enabled by setting micro_block="DiverseBranchBlock";
3. ResNet supports the official vb version by setting use_first_short_conv=False;
4. add the DBB training config ResNet18_dbb.yaml.
Parent 4292c1a9
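The commit has three parts, shown below in order: the new `dbb_block` module implementing `DiverseBranchBlock` (the training-time multi-branch convolution of Ding et al., "Diverse Branch Block: Building a Convolution as an Inception-like Unit", CVPR 2021), a diff of the ResNet backbone that makes the micro block pluggable, and the new `ResNet18_dbb.yaml` training config.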
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np


def conv_bn(in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            padding_mode='zeros'):
    """Return a conv + BN pair wrapped in a Sequential, the basic unit of each DBB branch."""
    conv_layer = nn.Conv2D(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias_attr=False,
        padding_mode=padding_mode)
    bn_layer = nn.BatchNorm2D(num_features=out_channels)
    se = nn.Sequential()
    se.add_sublayer('conv', conv_layer)
    se.add_sublayer('bn', bn_layer)
    return se
class IdentityBasedConv1x1(nn.Conv2D):
    """A 1x1 conv whose effective kernel is (weight + identity), so it acts as an
    identity mapping at initialization (the learnable weight starts at zero)."""

    def __init__(self, channels, groups=1):
        super(IdentityBasedConv1x1, self).__init__(
            in_channels=channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=groups,
            bias_attr=False)

        assert channels % groups == 0
        input_dim = channels // groups
        id_value = np.zeros((channels, input_dim, 1, 1))
        for i in range(channels):
            id_value[i, i % input_dim, 0, 0] = 1
        # cast to the conv weight's dtype so the addition in forward() does not mix
        # float64 (numpy default) with float32
        self.id_tensor = paddle.to_tensor(id_value, dtype=self.weight.dtype)
        self.weight.set_value(paddle.zeros_like(self.weight))

    def forward(self, input):
        kernel = self.weight + self.id_tensor
        result = F.conv2d(
            input,
            kernel,
            None,
            stride=1,
            padding=0,
            dilation=self._dilation,
            groups=self._groups)
        return result

    def get_actual_kernel(self):
        return self.weight + self.id_tensor
class BNAndPad(nn.Layer):
    """BatchNorm2D followed by padding. Instead of zero padding, the border is filled
    with the value BN would produce for an all-zero input, which keeps this branch
    equivalent to a padded conv + BN after re-parameterization."""

    def __init__(self,
                 pad_pixels,
                 num_features,
                 epsilon=1e-5,
                 momentum=0.1,
                 last_conv_bias=None,
                 bn=nn.BatchNorm2D):
        super().__init__()
        # note: Paddle's BN momentum is the keep ratio of the running stats
        # (default 0.9), the reverse of PyTorch's; 0.1 here makes the running
        # stats track each batch closely
        self.bn = bn(num_features, momentum=momentum, epsilon=epsilon)
        self.pad_pixels = pad_pixels
        self.last_conv_bias = last_conv_bias

    def forward(self, input):
        output = self.bn(input)
        if self.pad_pixels > 0:
            bias = -self.bn._mean
            if self.last_conv_bias is not None:
                bias += self.last_conv_bias
            # per-channel value BN produces for a zero input
            pad_values = self.bn.bias + self.bn.weight * (
                bias / paddle.sqrt(self.bn._variance + self.bn._epsilon))

            # pad
            # TODO: n,h,w,c format is not supported yet
            n, c, h, w = output.shape
            values = pad_values.reshape([1, -1, 1, 1])
            w_values = values.expand([n, -1, self.pad_pixels, w])
            x = paddle.concat([w_values, output, w_values], axis=2)
            h = h + self.pad_pixels * 2
            h_values = values.expand([n, -1, h, self.pad_pixels])
            x = paddle.concat([h_values, x, h_values], axis=3)
            output = x
        return output

    @property
    def weight(self):
        return self.bn.weight

    @property
    def bias(self):
        return self.bn.bias

    @property
    def _mean(self):
        return self.bn._mean

    @property
    def _variance(self):
        return self.bn._variance

    @property
    def _epsilon(self):
        return self.bn._epsilon
class DiverseBranchBlock(nn.Layer):
    """A training-time multi-branch replacement for a single KxK conv + BN. The
    branches (origin KxK, 1x1, 1x1-avg, 1x1-KxK) can be re-parameterized into one
    KxK conv for deployment."""

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None,
                 **kwargs):
        super().__init__()

        padding = (filter_size - 1) // 2
        dilation = 1
        deploy = False
        single_init = False

        in_channels = num_channels
        out_channels = num_filters
        kernel_size = filter_size
        internal_channels_1x1_3x3 = None
        nonlinear = act

        self.deploy = deploy

        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nn.ReLU()

        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.groups = groups
        assert padding == kernel_size // 2

        if deploy:
            # single fused conv used after re-parameterization
            self.dbb_reparam = nn.Conv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias_attr=True)
        else:
            # branch 1: the original KxK conv + BN
            self.dbb_origin = conv_bn(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups)

            # branch 2: (1x1 conv ->) BN -> average pooling
            self.dbb_avg = nn.Sequential()
            if groups < out_channels:
                self.dbb_avg.add_sublayer(
                    'conv',
                    nn.Conv2D(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=groups,
                        bias_attr=False))
                self.dbb_avg.add_sublayer(
                    'bn',
                    BNAndPad(
                        pad_pixels=padding, num_features=out_channels))
                self.dbb_avg.add_sublayer(
                    'avg',
                    nn.AvgPool2D(
                        kernel_size=kernel_size, stride=stride, padding=0))
                # branch 3: a plain 1x1 conv + BN (only when groups < out_channels)
                self.dbb_1x1 = conv_bn(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=stride,
                    padding=0,
                    groups=groups)
            else:
                self.dbb_avg.add_sublayer(
                    'avg',
                    nn.AvgPool2D(
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=padding))

            self.dbb_avg.add_sublayer('avgbn', nn.BatchNorm2D(out_channels))

            if internal_channels_1x1_3x3 is None:
                # For MobileNet, it is better to have 2x internal channels
                internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels

            # branch 4: 1x1 conv -> BN -> KxK conv -> BN
            self.dbb_1x1_kxk = nn.Sequential()
            if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_sublayer(
                    'idconv1',
                    IdentityBasedConv1x1(
                        channels=in_channels, groups=groups))
            else:
                self.dbb_1x1_kxk.add_sublayer(
                    'conv1',
                    nn.Conv2D(
                        in_channels=in_channels,
                        out_channels=internal_channels_1x1_3x3,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        groups=groups,
                        bias_attr=False))
            self.dbb_1x1_kxk.add_sublayer(
                'bn1',
                BNAndPad(
                    pad_pixels=padding,
                    num_features=internal_channels_1x1_3x3))
            self.dbb_1x1_kxk.add_sublayer(
                'conv2',
                nn.Conv2D(
                    in_channels=internal_channels_1x1_3x3,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=0,
                    groups=groups,
                    bias_attr=False))
            self.dbb_1x1_kxk.add_sublayer('bn2', nn.BatchNorm2D(out_channels))

        # The experiments reported in the paper used the default initialization of
        # bn.weight (all 1s), but changing the initialization may be useful in some cases.
        if single_init:
            # Initialize the bn.weight of dbb_origin as 1 and the others as 0.
            # This is not the default setting.
            self.single_init()

    def forward(self, inputs):
        if hasattr(self, 'dbb_reparam'):
            return self.nonlinear(self.dbb_reparam(inputs))

        out = self.dbb_origin(inputs)
        if hasattr(self, 'dbb_1x1'):
            out += self.dbb_1x1(inputs)
        out += self.dbb_avg(inputs)
        out += self.dbb_1x1_kxk(inputs)
        return self.nonlinear(out)

    def init_gamma(self, gamma_value):
        # Paddle has no paddle.nn.init; fill the BN scales (gamma) in place instead.
        def _fill(weight, value):
            weight.set_value(paddle.full_like(weight, value))

        if hasattr(self, "dbb_origin"):
            _fill(self.dbb_origin.bn.weight, gamma_value)
        if hasattr(self, "dbb_1x1"):
            _fill(self.dbb_1x1.bn.weight, gamma_value)
        if hasattr(self, "dbb_avg"):
            _fill(self.dbb_avg.avgbn.weight, gamma_value)
        if hasattr(self, "dbb_1x1_kxk"):
            _fill(self.dbb_1x1_kxk.bn2.weight, gamma_value)

    def single_init(self):
        self.init_gamma(0.0)
        if hasattr(self, "dbb_origin"):
            self.dbb_origin.bn.weight.set_value(
                paddle.full_like(self.dbb_origin.bn.weight, 1.0))
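As a quick sanity check of the module above, a minimal sketch (not part of the commit) that builds one 3x3 block and runs a forward pass:

import paddle

# 64 -> 64 channels, 3x3, stride 1, ReLU; groups=1 < out_channels,
# so all four branches (origin, 1x1, avg, 1x1-kxk) are constructed
block = DiverseBranchBlock(
    num_channels=64, num_filters=64, filter_size=3, act="relu")

x = paddle.randn([2, 64, 56, 56])
y = block(x)    # sum of the four branch outputs, then ReLU
print(y.shape)  # [2, 64, 56, 56]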
@@ -28,6 +28,7 @@ import math
 from ....utils import logger
 from ..base.theseus_layer import TheseusLayer
+from ..base.dbb_block import DiverseBranchBlock
 from ....utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url

 MODEL_URLS = {
@@ -163,17 +164,18 @@ class BottleneckBlock(TheseusLayer):
                  stride,
                  shortcut=True,
                  if_first=False,
+                 micro_block=ConvBNLayer,
                  lr_mult=1.0,
                  data_format="NCHW"):
         super().__init__()
-        self.conv0 = ConvBNLayer(
+        self.conv0 = micro_block(
             num_channels=num_channels,
             num_filters=num_filters,
             filter_size=1,
             act="relu",
             lr_mult=lr_mult,
             data_format=data_format)
-        self.conv1 = ConvBNLayer(
+        self.conv1 = micro_block(
             num_channels=num_filters,
             num_filters=num_filters,
             filter_size=3,
@@ -181,7 +183,7 @@ class BottleneckBlock(TheseusLayer):
             act="relu",
             lr_mult=lr_mult,
             data_format=data_format)
-        self.conv2 = ConvBNLayer(
+        self.conv2 = micro_block(
             num_channels=num_filters,
             num_filters=num_filters * 4,
             filter_size=1,
@@ -224,12 +226,13 @@ class BasicBlock(TheseusLayer):
                  stride,
                  shortcut=True,
                  if_first=False,
+                 micro_block=ConvBNLayer,
                  lr_mult=1.0,
                  data_format="NCHW"):
         super().__init__()

         self.stride = stride
-        self.conv0 = ConvBNLayer(
+        self.conv0 = micro_block(
             num_channels=num_channels,
             num_filters=num_filters,
             filter_size=3,
@@ -237,7 +240,7 @@ class BasicBlock(TheseusLayer):
             act="relu",
             lr_mult=lr_mult,
             data_format=data_format)
-        self.conv1 = ConvBNLayer(
+        self.conv1 = micro_block(
             num_channels=num_filters,
             num_filters=num_filters,
             filter_size=3,
@@ -274,7 +277,7 @@ class ResNet(TheseusLayer):
     ResNet
     Args:
         config: dict. config of ResNet.
-        version: str="vb". Different version of ResNet, version vd can perform better. 
+        version: str="vb". Different version of ResNet, version vd can perform better.
        class_num: int=1000. The number of classes.
        lr_mult_list: list. Control the learning rate of different stages.
    Returns:
@@ -293,6 +296,8 @@ class ResNet(TheseusLayer):
                  input_image_channel=3,
                  return_patterns=None,
                  return_stages=None,
+                 micro_block="ConvBNLayer",
+                 use_first_short_conv=True,
                  **kargs):
         super().__init__()
@@ -307,6 +312,13 @@ class ResNet(TheseusLayer):
         self.num_channels = self.cfg["num_channels"]
         self.channels_mult = 1 if self.num_channels[-1] == 256 else 4

+        if micro_block == "ConvBNLayer":
+            micro_block = ConvBNLayer
+        elif micro_block == "DiverseBranchBlock":
+            micro_block = DiverseBranchBlock
+        else:
+            raise ValueError("unsupported micro_block: {}".format(micro_block))
+
         assert isinstance(self.lr_mult_list, (
             list, tuple
         )), "lr_mult_list should be in (list, tuple) but got {}".format(
@@ -351,7 +363,11 @@ class ResNet(TheseusLayer):
             data_format=data_format)

         block_list = []
         for block_idx in range(len(self.block_depth)):
+            # PaddleClas' improved version uses a conv shortcut in the first stage
             shortcut = False
+            # the official resnet_vb version uses an identity shortcut there instead
+            if not use_first_short_conv and block_idx == 0:
+                shortcut = True
             for i in range(self.block_depth[block_idx]):
                 block_list.append(globals()[self.block_type](
                     num_channels=self.num_channels[block_idx] if i == 0 else
@@ -361,6 +377,7 @@ class ResNet(TheseusLayer):
                     if i == 0 and block_idx != 0 else 1,
                     shortcut=shortcut,
                     if_first=block_idx == i == 0 if version == "vd" else True,
+                    micro_block=micro_block,
                     lr_mult=self.lr_mult_list[block_idx + 1],
                     data_format=data_format))
                 shortcut = True
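Putting the backbone changes together, a usage sketch (it assumes the `ResNet18` helper in this file forwards its keyword arguments to `ResNet`, which the diff above does not show; the import path follows PaddleClas conventions):

import paddle
from ppcls.arch.backbone.legendary_models.resnet import ResNet18  # assumed path

# mirrors the Arch section of ResNet18_dbb.yaml below
model = ResNet18(
    class_num=1000,
    micro_block="DiverseBranchBlock",  # every ConvBNLayer becomes a DiverseBranchBlock
    use_first_short_conv=False)        # official vb-style first stage

logits = model(paddle.randn([1, 3, 224, 224]))
print(logits.shape)  # [1, 1000]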
# global configs
Global:
  checkpoints: null
  pretrained_model: null
  output_dir: ./output/
  device: gpu
  save_interval: 1
  eval_during_train: True
  eval_interval: 1
  epochs: 120
  print_batch_step: 10
  use_visualdl: False
  # used for static mode and model export
  image_shape: [3, 224, 224]
  save_inference_dir: ./inference

# model architecture
Arch:
  name: ResNet18
  class_num: 1000
  micro_block: DiverseBranchBlock
  use_first_short_conv: False

# loss function config for training/eval process
Loss:
  Train:
    - CELoss:
        weight: 1.0
  Eval:
    - CELoss:
        weight: 1.0

Optimizer:
  name: Momentum
  momentum: 0.9
  lr:
    name: Cosine
    learning_rate: 0.1
  regularizer:
    name: 'L2'
    coeff: 0.0001

# data loader for train and eval
DataLoader:
  Train:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/train_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - RandCropImage:
            size: 224
            backend: pil
            interpolation: bilinear
        - RandFlipImage:
            flip_code: 1
        - ColorJitter:
            brightness: 0.4
            saturation: 0.4
            hue: 0.4
        - PCALighting:
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: True
    loader:
      num_workers: 4
      use_shared_memory: True

  Eval:
    dataset:
      name: ImageNetDataset
      image_root: ./dataset/ILSVRC2012/
      cls_label_path: ./dataset/ILSVRC2012/val_list.txt
      transform_ops:
        - DecodeImage:
            to_rgb: True
            channel_first: False
        - ResizeImage:
            resize_short: 256
            backend: pil
            interpolation: bilinear
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
    sampler:
      name: DistributedBatchSampler
      batch_size: 64
      drop_last: False
      shuffle: False
    loader:
      num_workers: 4
      use_shared_memory: True

Infer:
  infer_imgs: docs/images/inference_deployment/whl_demo.jpg
  batch_size: 10
  transforms:
    - DecodeImage:
        to_rgb: True
        channel_first: False
    - ResizeImage:
        resize_short: 256
    - CropImage:
        size: 224
    - NormalizeImage:
        scale: 1.0/255.0
        mean: [0.485, 0.456, 0.406]
        std: [0.229, 0.224, 0.225]
        order: ''
    - ToCHWImage:
  PostProcess:
    name: Topk
    topk: 5
    class_id_map_file: ppcls/utils/imagenet1k_label_list.txt

Metric:
  Train:
    - TopkAcc:
        topk: [1, 5]
  Eval:
    - TopkAcc:
        topk: [1, 5]
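With the config in place, training would be launched the usual PaddleClas way, e.g. `python3 tools/train.py -c ResNet18_dbb.yaml` (adjust the path to wherever the file lands in the repo; the commit page does not show it).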