Unverified commit 65c776de authored by: G Guanghua Yu, committed by: GitHub

add Structural Re-parameterization implementation (#1608)

Parent 44e3306b
# Re-parameterization
This example shows how to apply structural re-parameterization when training a dygraph model. Taking the widely used MobileNetV1 model as an example, it walks through a DBB re-parameterization experiment; DBB is described in this [paper](https://arxiv.org/abs/2103.13425).
## Re-parameterization training workflow for classification models
### Prepare the data
Create a ``data`` directory under the current directory and extract the ``ImageNet`` dataset into it. After extraction, the ``data/ILSVRC2012`` directory should contain:
- a ``train`` directory holding the training images
- a ``train_list.txt`` file
- a ``val`` directory holding the validation images
- a ``val_list.txt`` file
### Prepare the model to be re-parameterized
- For the [models](https://github.com/PaddlePaddle/Paddle/tree/develop/python/paddle/vision/models) supported by paddle.vision (`[lenet, mobilenetv1, mobilenetv2, resnet, vgg]`), the built-in model definitions and ImageNet pretrained weights can be used directly, as in the sketch below.
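
For example, MobileNetV1 with its ImageNet pretrained weights can be loaded directly. This is a minimal sketch of what `train.py` does; the 1000-class setting matches ImageNet:

```python
from paddle.vision.models import mobilenet_v1

# Built-in model definition with ImageNet pretrained weights.
net = mobilenet_v1(pretrained=True, num_classes=1000)
```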
### Training commands
- MobileNetV1
The launch commands are as follows:
```bash
# single-GPU training
python train.py --model=mobilenet_v1
# multi-GPU training, using GPUs 0 to 3 as an example
python -m paddle.distributed.launch --gpus="0,1,2,3" train.py
```
### Re-parameterization results
| Model | FP32 model accuracy (Top-1) | Re-parameterization method | Re-parameterized model accuracy (Top-1) |
| ----------- | --------------------------- | -------------------------- | --------------------------------------- |
| MobileNetV1 | 70.99 | DBB | 72.01 |
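
Under the hood, `train.py` applies DBB through the `paddleslim.dygraph.rep` API and fuses the branches back into plain convolutions before export. The sketch below summarizes that workflow under the same settings as the script; the save path and input shape are illustrative:

```python
import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim.dygraph.rep import Reparameter, DBBRepConfig

net = mobilenet_v1(pretrained=True, num_classes=1000)

# Replace eligible Conv2D + BatchNorm2D pairs with multi-branch DBB blocks before training.
reper = Reparameter(DBBRepConfig())
reper.prepare(net)

# ... train the multi-branch model as usual ...

# Fuse the extra branches back into single Conv2D layers, then export an inference model.
reper.convert(net)
paddle.jit.save(
    net,
    "./output_models/inference_model/rep_model",
    input_spec=[
        paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')
    ])
```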
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
def piecewise_decay(net, device_num, args):
step = int(
math.ceil(float(args.total_images) / (args.batch_size * device_num)))
bd = [step * e for e in args.step_epochs]
lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
learning_rate = paddle.optimizer.lr.PiecewiseDecay(
boundaries=bd, values=lr, verbose=False)
optimizer = paddle.optimizer.Momentum(
parameters=net.parameters(),
learning_rate=learning_rate,
momentum=args.momentum_rate,
weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
return optimizer, learning_rate
def cosine_decay(net, device_num, args):
step = int(
math.ceil(float(args.total_images) / (args.batch_size * device_num)))
learning_rate = paddle.optimizer.lr.CosineAnnealingDecay(
learning_rate=args.lr, T_max=step * args.num_epochs, verbose=False)
optimizer = paddle.optimizer.Momentum(
parameters=net.parameters(),
learning_rate=learning_rate,
momentum=args.momentum_rate,
weight_decay=paddle.regularizer.L2Decay(args.l2_decay))
return optimizer, learning_rate
def create_optimizer(net, device_num, args):
if args.lr_strategy == "piecewise_decay":
return piecewise_decay(net, device_num, args)
elif args.lr_strategy == "cosine_decay":
return cosine_decay(net, device_num, args)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import logging
import paddle
import argparse
import functools
import math
import time
import random
import numpy as np
from paddle.distributed import ParallelEnv
from paddle.static import load_program_state
from paddle.vision.models import mobilenet_v1
import paddle.vision.transforms as T
from paddleslim.common import get_logger
from paddleslim.dygraph.rep import Reparameter, DBBRepConfig, ACBRepConfig
sys.path.append(os.path.dirname(__file__))
from optimizer import create_optimizer
sys.path.append(
    os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
from utility import add_arguments, print_arguments
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64, "Single Card Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('lr', float, 0.1, "The initial learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 0.00003, "The l2_decay parameter.")
add_arg('ls_epsilon', float, 0.0, "Label smooth epsilon.")
add_arg('use_pact', bool, False, "Whether to use PACT method.")
add_arg('ce_test', bool, False, "Whether to run the CE (continuous evaluation) test.")
add_arg('momentum_rate', float, 0.9, "The value of momentum_rate.")
add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('total_images', int, 1281167, "The number of total training images.")
add_arg('data', str, "imagenet", "Which data to use. 'cifar10' or 'imagenet'")
add_arg('log_period', int, 10, "Log period in batches.")
add_arg('model_save_dir', str, "./output_models", "model save directory.")
parser.add_argument('--step_epochs', nargs='+', type=int, default=[30, 60, 90], help="piecewise decay step")
# yapf: enable
def load_dygraph_pretrain(model, path=None, load_static_weights=False):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError(
            "Model pretrain path {} does not exist.".format(path))
if load_static_weights:
pre_state_dict = load_program_state(path)
param_state_dict = {}
model_dict = model.state_dict()
for key in model_dict.keys():
weight_name = model_dict[key].name
if weight_name in pre_state_dict.keys():
print('Load weight: {}, shape: {}'.format(
weight_name, pre_state_dict[weight_name].shape))
param_state_dict[key] = pre_state_dict[weight_name]
else:
param_state_dict[key] = model_dict[key]
model.set_dict(param_state_dict)
return
param_state_dict = paddle.load(path + ".pdparams")
model.set_dict(param_state_dict)
return
def train(args):
num_workers = 4
shuffle = True
if args.ce_test:
# set seed
seed = 111
paddle.seed(seed)
np.random.seed(seed)
random.seed(seed)
num_workers = 0
shuffle = False
if args.data == "cifar10":
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = paddle.vision.datasets.Cifar10(
mode="train", backend="cv2", transform=transform)
val_dataset = paddle.vision.datasets.Cifar10(
mode="test", backend="cv2", transform=transform)
class_dim = 10
image_shape = [3, 32, 32]
pretrain = False
args.total_images = 50000
elif args.data == "imagenet":
import imagenet_reader as reader
train_dataset = reader.ImageNetDataset(mode='train')
val_dataset = reader.ImageNetDataset(mode='val')
class_dim = 1000
image_shape = "3,224,224"
else:
raise ValueError("{} is not supported.".format(args.data))
trainer_num = paddle.distributed.get_world_size()
use_data_parallel = trainer_num != 1
place = paddle.set_device('gpu' if args.use_gpu else 'cpu')
# model definition
if use_data_parallel:
paddle.distributed.init_parallel_env()
pretrain = True if args.data == "imagenet" else False
net = mobilenet_v1(pretrained=pretrain, num_classes=class_dim)
rep_config = DBBRepConfig()
reper = Reparameter(rep_config)
reper.prepare(net)
paddle.summary(net, (1, 3, 224, 224))
opt, lr = create_optimizer(net, trainer_num, args)
if use_data_parallel:
net = paddle.DataParallel(net)
train_batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset,
batch_size=args.batch_size,
shuffle=shuffle,
drop_last=True)
train_loader = paddle.io.DataLoader(
train_dataset,
batch_sampler=train_batch_sampler,
places=place,
return_list=True,
num_workers=num_workers)
valid_loader = paddle.io.DataLoader(
val_dataset,
places=place,
batch_size=args.batch_size,
shuffle=False,
drop_last=False,
return_list=True,
num_workers=num_workers)
@paddle.no_grad()
def test(epoch, net):
net.eval()
batch_id = 0
acc_top1_ns = []
acc_top5_ns = []
eval_reader_cost = 0.0
eval_run_cost = 0.0
total_samples = 0
reader_start = time.time()
for data in valid_loader():
eval_reader_cost += time.time() - reader_start
image = data[0]
label = data[1]
if args.data == "cifar10":
label = paddle.reshape(label, [-1, 1])
eval_start = time.time()
out = net(image)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
eval_run_cost += time.time() - eval_start
batch_size = image.shape[0]
total_samples += batch_size
if batch_id % args.log_period == 0:
log_period = 1 if batch_id == 0 else args.log_period
_logger.info(
"Eval epoch[{}] batch[{}] - top1: {:.6f}; top5: {:.6f}; avg_reader_cost: {:.6f} s, avg_batch_cost: {:.6f} s, avg_samples: {}, avg_ips: {:.3f} images/s".
format(epoch, batch_id,
np.mean(acc_top1.numpy()),
np.mean(acc_top5.numpy()), eval_reader_cost /
log_period, (eval_reader_cost + eval_run_cost) /
log_period, total_samples / log_period, total_samples
/ (eval_reader_cost + eval_run_cost)))
eval_reader_cost = 0.0
eval_run_cost = 0.0
total_samples = 0
acc_top1_ns.append(np.mean(acc_top1.numpy()))
acc_top5_ns.append(np.mean(acc_top5.numpy()))
batch_id += 1
reader_start = time.time()
_logger.info(
"Final eval epoch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}".format(
epoch,
np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
return np.mean(np.array(acc_top1_ns))
def cross_entropy(input, target, ls_epsilon):
if ls_epsilon > 0:
if target.shape[-1] != class_dim:
target = paddle.nn.functional.one_hot(target, class_dim)
target = paddle.nn.functional.label_smooth(
target, epsilon=ls_epsilon)
target = paddle.reshape(target, shape=[-1, class_dim])
input = -paddle.nn.functional.log_softmax(input, axis=-1)
cost = paddle.sum(target * input, axis=-1)
else:
cost = paddle.nn.functional.cross_entropy(input=input, label=target)
avg_cost = paddle.mean(cost)
return avg_cost
def train(epoch, net):
net.train()
batch_id = 0
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
reader_start = time.time()
for data in train_loader():
train_reader_cost += time.time() - reader_start
image = data[0]
label = data[1]
if args.data == "cifar10":
label = paddle.reshape(label, [-1, 1])
train_start = time.time()
out = net(image)
avg_cost = cross_entropy(out, label, args.ls_epsilon)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
avg_cost.backward()
opt.step()
opt.clear_grad()
lr.step()
loss_n = np.mean(avg_cost.numpy())
acc_top1_n = np.mean(acc_top1.numpy())
acc_top5_n = np.mean(acc_top5.numpy())
train_run_cost += time.time() - train_start
batch_size = image.shape[0]
total_samples += batch_size
if batch_id % args.log_period == 0:
log_period = 1 if batch_id == 0 else args.log_period
_logger.info(
"epoch[{}]-batch[{}] lr: {:.6f} - loss: {:.6f}; top1: {:.6f}; top5: {:.6f}; avg_reader_cost: {:.6f} s, avg_batch_cost: {:.6f} s, avg_samples: {}, avg_ips: {:.3f} images/s".
format(epoch, batch_id,
lr.get_lr(), loss_n, acc_top1_n, acc_top5_n,
train_reader_cost / log_period, (
train_reader_cost + train_run_cost) / log_period,
total_samples / log_period, total_samples / (
train_reader_cost + train_run_cost)))
train_reader_cost = 0.0
train_run_cost = 0.0
total_samples = 0
batch_id += 1
reader_start = time.time()
# train loop
best_acc1 = 0.0
best_epoch = 0
for i in range(args.num_epochs):
train(i, net)
acc1 = test(i, net)
if paddle.distributed.get_rank() == 0:
model_prefix = os.path.join(args.model_save_dir, "epoch_" + str(i))
paddle.save(net.state_dict(), model_prefix + ".pdparams")
paddle.save(opt.state_dict(), model_prefix + ".pdopt")
if acc1 > best_acc1:
best_acc1 = acc1
best_epoch = i
if paddle.distributed.get_rank() == 0:
model_prefix = os.path.join(args.model_save_dir, "best_model")
paddle.save(net.state_dict(), model_prefix + ".pdparams")
paddle.save(opt.state_dict(), model_prefix + ".pdopt")
# Save model
reper.convert(net)
if paddle.distributed.get_rank() == 0:
# load best model
load_dygraph_pretrain(net,
os.path.join(args.model_save_dir, "best_model"))
path = os.path.join(args.model_save_dir, "inference_model", 'rep_model')
paddle.jit.save(
net,
path,
input_spec=[
paddle.static.InputSpec(
shape=[None, 3, 224, 224], dtype='float32')
])
def main():
args = parser.parse_args()
print_arguments(args)
train(args)
if __name__ == '__main__':
main()
@@ -5,3 +5,5 @@ from .prune import *
__all__ += prune.__all__
from .dist import *
__all__ += dist.__all__
from .rep import *
__all__ += rep.__all__
\ No newline at end of file
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import rep
from . import config
from . import reper
from .rep import Reparameter
from .config import *
from .reper import *
__all__ = []
__all__ += rep.__all__
__all__ += config.__all__
__all__ += reper.__all__
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Union
import paddle.nn as nn
from .reper import DiverseBranchBlock, ACBlock, RepVGGBlock, SlimRepBlock
SUPPORT_REP_TYPE_LAYERS = [nn.Conv2D, nn.Linear]
__all__ = [
"BaseRepConfig", "DBBRepConfig", "ACBRepConfig", "RepVGGConfig",
"SlimRepConfig"
]
class BaseRepConfig:
"""
Basic reparameterization configuration class.
Args:
        type_config(dict): Set the reper by the type of layer. The keys of `type_config`
            should be subclasses of `paddle.nn.Layer`. Its priority is lower than `layer_config`.
            Default is `{nn.Conv2D: ACBlock}`.
        layer_config(dict): Set the reper for individual layers. It has the highest priority
            among all the setting methods, e.g. `{model.conv1: ACBlock}`. Default is None.
"""
def __init__(
self,
type_config: Dict={nn.Conv2D: ACBlock},
layer_config: Dict=None, ):
self._type_config = self._set_type_config(type_config)
self._layer_config = self._set_layer_config(layer_config)
def add_config(
self,
type_config: Dict=None,
layer_config: Dict=None, ):
self._type_config.update(self._set_type_config(type_config))
self._layer_config.update(self._set_layer_config(layer_config))
@property
def all_config(self):
return {
'type_config': self._type_config,
'layer_config': self._layer_config,
}
def _set_type_config(self, type_config):
_type_config = {}
if type_config:
for _layer in type_config:
assert isinstance(_layer, type) and issubclass(
_layer, nn.Layer
), "Expect to get subclasses under nn.Layer, but got {}.".format(
_layer)
assert _layer in SUPPORT_REP_TYPE_LAYERS, "Expect to get one of `{}`, but got {}.".format(
SUPPORT_REP_TYPE_LAYERS, _layer)
_type_config[_layer] = type_config[_layer]
return _type_config
def _set_layer_config(self, layer_config):
_layer_config = {}
if layer_config:
for _layer in layer_config:
is_support = False
for support_type in SUPPORT_REP_TYPE_LAYERS:
if isinstance(_layer, support_type):
is_support = True
                assert is_support, "Expect layer to get one of `{}`.".format(
                    SUPPORT_REP_TYPE_LAYERS)
_layer_config[_layer.full_name()] = layer_config[_layer]
return _layer_config
def __str__(self):
result = ""
if len(self._type_config) > 0:
result += f"Type config:\n{self._type_config}\n"
if len(self._layer_config) > 0:
result += f"Layer config: \n{self._layer_config}\n"
return result
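# Illustrative usage (a sketch, not part of this module): the per-type default can be
# overridden for a specific layer via `layer_config`, which takes priority over
# `type_config`. Here `model.conv1` is a hypothetical layer attribute.
#
#     config = BaseRepConfig(type_config={nn.Conv2D: ACBlock})
#     config.add_config(layer_config={model.conv1: DiverseBranchBlock})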
class DBBRepConfig(BaseRepConfig):
"""
DBB reparameterization configuration class.
Args:
        type_config(dict): Set the reper by the type of layer. The keys of `type_config`
            should be subclasses of `paddle.nn.Layer`. Its priority is lower than `layer_config`.
            Default is `{nn.Conv2D: DiverseBranchBlock}`.
        layer_config(dict): Set the reper for individual layers. It has the highest priority
            among all the setting methods, e.g. `{model.conv1: DiverseBranchBlock}`. Default is None.
"""
def __init__(
self,
type_config: Dict={nn.Conv2D: DiverseBranchBlock},
layer_config: Dict=None, ):
self._type_config = self._set_type_config(type_config)
self._layer_config = self._set_layer_config(layer_config)
class ACBRepConfig(BaseRepConfig):
"""
ACBlock reparameterization configuration class.
Args:
        type_config(dict): Set the reper by the type of layer. The keys of `type_config`
            should be subclasses of `paddle.nn.Layer`. Its priority is lower than `layer_config`.
            Default is `{nn.Conv2D: ACBlock}`.
        layer_config(dict): Set the reper for individual layers. It has the highest priority
            among all the setting methods, e.g. `{model.conv1: ACBlock}`. Default is None.
"""
def __init__(
self,
type_config: Dict={nn.Conv2D: ACBlock},
layer_config: Dict=None, ):
self._type_config = self._set_type_config(type_config)
self._layer_config = self._set_layer_config(layer_config)
class RepVGGConfig(BaseRepConfig):
"""
RepVGG reparameterization configuration class.
Args:
        type_config(dict): Set the reper by the type of layer. The keys of `type_config`
            should be subclasses of `paddle.nn.Layer`. Its priority is lower than `layer_config`.
            Default is `{nn.Conv2D: RepVGGBlock}`.
        layer_config(dict): Set the reper for individual layers. It has the highest priority
            among all the setting methods, e.g. `{model.conv1: RepVGGBlock}`. Default is None.
"""
def __init__(
self,
type_config: Dict={nn.Conv2D: RepVGGBlock},
layer_config: Dict=None, ):
self._type_config = self._set_type_config(type_config)
self._layer_config = self._set_layer_config(layer_config)
class SlimRepConfig(BaseRepConfig):
"""
SlimRepBlock reparameterization configuration class.
Args:
        type_config(dict): Set the reper by the type of layer. The keys of `type_config`
            should be subclasses of `paddle.nn.Layer`. Its priority is lower than `layer_config`.
            Default is `{nn.Conv2D: SlimRepBlock}`.
        layer_config(dict): Set the reper for individual layers. It has the highest priority
            among all the setting methods, e.g. `{model.conv1: SlimRepBlock}`. Default is None.
"""
def __init__(
self,
type_config: Dict={nn.Conv2D: SlimRepBlock},
layer_config: Dict=None, ):
self._type_config = self._set_type_config(type_config)
self._layer_config = self._set_layer_config(layer_config)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import paddle.nn as nn
from ...common import get_logger
from .config import BaseRepConfig, SUPPORT_REP_TYPE_LAYERS
_logger = get_logger(__name__, level=logging.INFO)
__all__ = ["Reparameter"]
class Reparameter:
"""
    Re-parameterization interface for dygraph models.
    Args:
        config(BaseRepConfig): Re-parameterization config. Default is `BaseRepConfig()`.
"""
    def __init__(self, config=BaseRepConfig()):
        assert config is not None, "config cannot be None."
self._config = config.all_config
self._layer2reper_config = {}
def prepare(self, model):
"""
Re-parameterization prepare model callback interface.
Args:
model(nn.Layer): The model to be reparameterized.
"""
self._layer2reper_config = self._parser_rep_config(model)
# Conv2D
if "Conv2D" in self._layer2reper_config:
conv2d2reper_config = self._layer2reper_config["Conv2D"]
conv_bn_pairs = self._get_conv_bn_pair(model)
if not conv_bn_pairs:
_logger.info(
"No conv-bn layer found, so skip the reparameterization.")
return model
for layer_name in conv2d2reper_config:
if layer_name in list(conv_bn_pairs.keys()):
per_conv_bn_pair = [layer_name, conv_bn_pairs[layer_name]]
self._replace_conv_bn_with_reper(
model, conv2d2reper_config[layer_name],
per_conv_bn_pair)
return model
def convert(self, model):
"""
Re-parameterization export interface, it will run fusion operation.
Args:
model(nn.Layer): The model that has been reparameterized.
"""
for layer in model.sublayers():
if hasattr(layer, 'convert_to_deploy'):
layer.convert_to_deploy()
def _parser_rep_config(self, model):
_layer2reper_config = {}
for name, layer in model.named_sublayers():
support_type_layers = list(self._config['type_config'].keys())
refine_layer_full_names = list(self._config['layer_config'].keys())
cur_layer_reper = None
            # First, match the layer by its type.
for layer_type in support_type_layers:
if isinstance(layer, layer_type):
cur_layer_reper = self._config['type_config'][layer_type]
            # Then, match by the layer's full name; this overrides the type match.
if name in refine_layer_full_names:
cur_layer_reper = self._config['layer_config'][name]
# Conv2d
if cur_layer_reper and isinstance(layer, nn.Conv2D):
if "Conv2D" in _layer2reper_config:
_layer2reper_config["Conv2D"].update({
name: cur_layer_reper
})
else:
_layer2reper_config["Conv2D"] = {name: cur_layer_reper}
# Linear
elif cur_layer_reper and isinstance(layer, nn.Linear):
if "Linear" in _layer2reper_config:
_layer2reper_config["Linear"].update({
name: cur_layer_reper
})
else:
_layer2reper_config["Linear"] = {name: cur_layer_reper}
elif cur_layer_reper:
                _logger.info(
                    "{} does not support re-parameterization; please choose one of {}.".
                    format(name, SUPPORT_REP_TYPE_LAYERS))
return _layer2reper_config
def _get_conv_bn_pair(self, model):
"""
Get the combination of Conv2D and BatchNorm2D.
Args:
            model(nn.Layer): The model to search for Conv2D-BatchNorm2D pairs.
"""
conv_bn_pairs = {}
tmp_pair = [None, None]
for name, layer in model.named_sublayers():
if isinstance(layer, nn.Conv2D):
tmp_pair[0] = name
if isinstance(layer, nn.BatchNorm2D) or isinstance(
layer, nn.BatchNorm):
tmp_pair[1] = name
            if tmp_pair[0] and tmp_pair[1]:
conv_bn_pairs[tmp_pair[0]] = tmp_pair[1]
tmp_pair = [None, None]
return conv_bn_pairs
def _replace_conv_bn_with_reper(self, model, reper, conv_bn_pair):
"""
Replace Conv2D and BatchNorm2D with reper.
Args:
            model(nn.Layer): The model being re-parameterized.
            reper(nn.Layer): The reper used to replace the current conv-bn pair.
            conv_bn_pair(list[str]): Names of the Conv2D layer and its following BatchNorm2D layer.
"""
for layer_name in conv_bn_pair:
parent_layer, sub_name = self._find_parent_layer_and_sub_name(
model, layer_name)
module = getattr(parent_layer, sub_name)
if isinstance(module, nn.Conv2D):
new_layer = reper(
in_channels=module._in_channels,
out_channels=module._out_channels,
kernel_size=module._kernel_size[0],
stride=module._stride[0],
groups=module._groups,
padding=module._padding)
setattr(parent_layer, sub_name, new_layer)
if isinstance(module, nn.BatchNorm2D) or isinstance(
module, nn.BatchNorm):
new_layer = nn.Identity()
setattr(parent_layer, sub_name, new_layer)
def _find_parent_layer_and_sub_name(self, model, name):
"""
Given the model and the name of a layer, find the parent layer and
the sub_name of the layer.
        For example, if name is 'block_1.convbn_1.conv_1', the parent layer is
        'block_1.convbn_1' and the sub_name is `conv_1`.
Args:
model(paddle.nn.Layer): the model to be reparameterized.
name(string): the name of a layer.
Returns:
parent_layer, subname
"""
assert isinstance(model, nn.Layer), \
"The model must be the instance of paddle.nn.Layer."
assert len(name) > 0, "The input (name) should not be empty."
last_idx = 0
idx = 0
parent_layer = model
while idx < len(name):
if name[idx] == '.':
sub_name = name[last_idx:idx]
if hasattr(parent_layer, sub_name):
parent_layer = getattr(parent_layer, sub_name)
last_idx = idx + 1
idx += 1
sub_name = name[last_idx:idx]
return parent_layer, sub_name
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import diversebranchblock
from . import acblock
from . import repvgg
from . import slimrep
from . import base
from .diversebranchblock import DiverseBranchBlock
from .acblock import ACBlock
from .repvgg import RepVGGBlock
from .slimrep import SlimRepBlock
__all__ = []
__all__ += diversebranchblock.__all__
__all__ += acblock.__all__
__all__ += repvgg.__all__
__all__ += slimrep.__all__
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.nn as nn
from .base import BaseConv2DReper, ConvBNLayer
__all__ = ["ACBlock"]
class ACBlock(BaseConv2DReper):
"""
An instance of the ACBlock module, which replaces the conv-bn layer in the network.
    Reference paper: https://arxiv.org/abs/1908.03930.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
padding=None):
super(ACBlock, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
groups=groups,
padding=padding)
        if self.padding is None:
            self.padding = self.kernel_size // 2
        if self.padding - self.kernel_size // 2 >= 0:
self.crop = 0
            # Compared to the KxK layer, the padding of the 1xK and Kx1 layers should be adjusted to align the sliding windows (Fig. 2 in the paper)
hor_padding = [self.padding - self.kernel_size // 2, self.padding]
ver_padding = [self.padding, self.padding - self.kernel_size // 2]
else:
# A negative "padding" (padding - kernel_size//2 < 0, which is not a common use case) is cropping.
# Since nn.Conv2D does not support negative padding, we implement it manually
self.crop = self.kernel_size // 2 - self.padding
hor_padding = [0, self.padding]
ver_padding = [self.padding, 0]
# kxk square branch
self.square_branch = ConvBNLayer(
self.in_channels,
self.out_channels,
self.kernel_size,
self.stride,
groups=self.groups,
padding=self.padding)
# kx1 vertical branch
self.ver_branch = ConvBNLayer(
self.in_channels,
self.out_channels, (self.kernel_size, 1),
self.stride,
groups=self.groups,
padding=ver_padding)
# 1xk horizontal branch
self.hor_branch = ConvBNLayer(
self.in_channels,
self.out_channels, (1, self.kernel_size),
self.stride,
groups=self.groups,
padding=hor_padding)
def _add_to_square_kernel(self, square_kernel, asym_kernel):
asym_h = asym_kernel.shape[2]
asym_w = asym_kernel.shape[3]
square_h = square_kernel.shape[2]
square_w = square_kernel.shape[3]
square_kernel[:, :, square_h // 2 - asym_h // 2:square_h // 2 -
asym_h // 2 + asym_h, square_w // 2 - asym_w // 2:
square_w // 2 - asym_w // 2 + asym_w] += asym_kernel
def _fuse_bn(self, kernel, bn):
running_mean = bn._mean
running_var = bn._variance
gamma = bn.weight
beta = bn.bias
eps = bn._epsilon
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
return kernel * t, beta - running_mean * gamma / std
def _get_equivalent_kernel_bias(self):
hor_k, hor_b = self._fuse_bn(self.hor_branch.conv.weight,
self.hor_branch.bn)
ver_k, ver_b = self._fuse_bn(self.ver_branch.conv.weight,
self.ver_branch.bn)
square_k, square_b = self._fuse_bn(self.square_branch.conv.weight,
self.square_branch.bn)
self._add_to_square_kernel(square_k, hor_k)
self._add_to_square_kernel(square_k, ver_k)
return square_k, hor_b + ver_b + square_b
def convert_to_deploy(self):
if hasattr(self, 'fused_branch'):
return
kernel, bias = self._get_equivalent_kernel_bias()
self.fused_branch = nn.Conv2D(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
groups=self.groups,
bias_attr=True)
self.fused_branch.weight.set_value(kernel)
self.fused_branch.bias.set_value(bias)
self.__delattr__('ver_branch')
self.__delattr__('hor_branch')
self.__delattr__('square_branch')
def forward(self, input):
if hasattr(self, 'fused_branch'):
return self.fused_branch(input)
out = self.square_branch(input)
if self.crop > 0:
ver_input = input[:, :, :, self.crop:-self.crop]
hor_input = input[:, :, self.crop:-self.crop, :]
else:
ver_input = input
hor_input = input
out += self.ver_branch(ver_input)
out += self.hor_branch(hor_input)
return out
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.nn as nn
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.nn.initializer import KaimingNormal
class BaseConv2DReper(nn.Layer):
"""
    A base class for Conv2D-based re-parameterization modules.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
padding=None):
super(BaseConv2DReper, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.groups = groups
self.padding = padding
def convert_to_deploy(self):
pass
def forward(self, input):
pass
class ConvBNLayer(nn.Layer):
def __init__(self,
in_channels,
out_channels,
filter_size,
stride,
groups=1,
padding=None):
super().__init__()
        if padding is None:
            padding = filter_size // 2
self.conv = nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=filter_size,
stride=stride,
padding=padding,
groups=groups,
weight_attr=ParamAttr(initializer=KaimingNormal()),
bias_attr=False)
self.bn = nn.BatchNorm2D(
out_channels,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This code is referenced from: https://github.com/DingXiaoH/DiverseBranchBlock/blob/main/diversebranchblock.py
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from .base import BaseConv2DReper, ConvBNLayer
__all__ = ["DiverseBranchBlock"]
class IdentityBasedConv1x1(nn.Conv2D):
def __init__(self,
channels,
groups=1,
weight_attr=ParamAttr(
initializer=nn.initializer.Constant(0.0))):
super(IdentityBasedConv1x1, self).__init__(
in_channels=channels,
out_channels=channels,
kernel_size=1,
stride=1,
padding=0,
groups=groups,
weight_attr=weight_attr,
bias_attr=False)
assert channels % groups == 0
input_dim = channels // groups
id_value = np.zeros((channels, input_dim, 1, 1))
for i in range(channels):
id_value[i, i % input_dim, 0, 0] = 1
self.id_tensor = paddle.to_tensor(id_value)
self.groups = groups
def forward(self, input):
kernel = self.weight + self.id_tensor
result = F.conv2d(
input, kernel, None, stride=1, padding=0, groups=self.groups)
return result
def get_actual_kernel(self):
return self.weight + self.id_tensor
class BNAndPadLayer(nn.Layer):
    def __init__(self, pad_pixels, num_features, eps=1e-5, momentum=0.9):
        super(BNAndPadLayer, self).__init__()
        # Note: Paddle's BatchNorm momentum is the decay of the running statistics
        # (running = momentum * running + (1 - momentum) * batch), the opposite of
        # PyTorch's convention, so the PyTorch default of 0.1 corresponds to 0.9 here.
        self.bn = nn.BatchNorm2D(num_features, momentum=momentum, epsilon=eps)
self.pad_pixels = pad_pixels
def forward(self, input):
output = self.bn(input)
if self.pad_pixels > 0:
pad_values = self.bn.bias - self.bn._mean * self.bn.weight / paddle.sqrt(
self.bn._variance + self.bn._epsilon)
output = F.pad(output, [self.pad_pixels] * 4)
pad_values = pad_values.reshape((1, -1, 1, 1))
output[:, :, 0:self.pad_pixels, :] = pad_values
output[:, :, -self.pad_pixels:, :] = pad_values
output[:, :, :, 0:self.pad_pixels] = pad_values
output[:, :, :, -self.pad_pixels:] = pad_values
return output
@property
def weight(self):
return self.bn.weight
@property
def bias(self):
return self.bn.bias
@property
def _mean(self):
return self.bn._mean
@property
def _variance(self):
return self.bn._variance
@property
def _epsilon(self):
return self.bn._epsilon
class DiverseBranchBlock(BaseConv2DReper):
"""
    An instance of the DBB module, which replaces the conv-bn layer in the network.
    Reference paper: https://arxiv.org/abs/2103.13425.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
padding=None,
internal_channels_1x1_3x3=None):
super(DiverseBranchBlock, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
groups=groups,
padding=padding)
# kxk branch
self.dbb_origin = ConvBNLayer(
in_channels, out_channels, kernel_size, stride, groups=groups)
# 1x1-avg branch
self.dbb_avg = nn.Sequential()
if groups < out_channels:
self.dbb_avg.add_sublayer('conv',
nn.Conv2D(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
groups=groups,
bias_attr=False))
self.dbb_avg.add_sublayer('bn',
BNAndPadLayer(
pad_pixels=self.padding,
num_features=out_channels))
self.dbb_avg.add_sublayer('avg',
nn.AvgPool2D(
kernel_size=kernel_size,
stride=stride,
padding=0))
else:
self.dbb_avg.add_sublayer('avg',
nn.AvgPool2D(
kernel_size=kernel_size,
stride=stride,
padding=self.padding))
self.dbb_avg.add_sublayer('avgbn', nn.BatchNorm2D(out_channels))
# 1x1 branch
if groups < out_channels:
self.dbb_1x1 = ConvBNLayer(
in_channels, out_channels, 1, stride, groups=groups)
# 1x1-kxk branch
if internal_channels_1x1_3x3 is None:
# For mobilenet, it is better to have 2X internal channels
internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels
self.dbb_1x1_kxk = nn.Sequential()
if internal_channels_1x1_3x3 == in_channels:
self.dbb_1x1_kxk.add_sublayer('idconv1',
IdentityBasedConv1x1(
channels=in_channels,
groups=groups))
else:
self.dbb_1x1_kxk.add_sublayer(
'conv1',
nn.Conv2D(
in_channels=in_channels,
out_channels=internal_channels_1x1_3x3,
kernel_size=1,
stride=1,
padding=0,
groups=groups,
bias_attr=False))
self.dbb_1x1_kxk.add_sublayer(
'bn1',
BNAndPadLayer(
pad_pixels=self.padding,
num_features=internal_channels_1x1_3x3))
self.dbb_1x1_kxk.add_sublayer('conv2',
nn.Conv2D(
in_channels=internal_channels_1x1_3x3,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
groups=groups,
bias_attr=False))
self.dbb_1x1_kxk.add_sublayer('bn2', nn.BatchNorm2D(out_channels))
def _fuse_bn(self, kernel, bn):
running_mean = bn._mean
running_var = bn._variance
gamma = bn.weight
beta = bn.bias
eps = bn._epsilon
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
return kernel * t, beta - running_mean * gamma / std
def _fuse_1x1_kxk(self, k1, b1, k2, b2, groups):
if groups == 1:
k = F.conv2d(k2, k1.transpose((1, 0, 2, 3)))
b_hat = (k2 * b1.reshape((1, -1, 1, 1))).sum((1, 2, 3))
else:
k_slices = []
b_slices = []
k1_T = k1.transpose((1, 0, 2, 3))
k1_group_width = k1.shape[0] // groups
k2_group_width = k2.shape[0] // groups
for g in range(groups):
k1_T_slice = k1_T[:, g * k1_group_width:(
g + 1) * k1_group_width, :, :]
k2_slice = k2[g * k2_group_width:(g + 1
) * k2_group_width, :, :, :]
k_slices.append(F.conv2d(k2_slice, k1_T_slice))
b_slices.append(
(k2_slice *
b1[g * k1_group_width:(g + 1) * k1_group_width].reshape(
(1, -1, 1, 1))).sum((1, 2, 3)))
k = paddle.concat(k_slices)
b_hat = paddle.concat(b_slices)
return k, b_hat + b2
def _fuse_avg(self, channels, kernel_size, groups):
input_dim = channels // groups
k = paddle.zeros((channels, input_dim, kernel_size, kernel_size))
k[np.arange(channels),
np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size**2
return k
    # This has not been tested with non-square kernels (kernel.shape[2] != kernel.shape[3]) nor with even-sized kernels
def _fuse_multiscale(self, kernel, target_kernel_size):
H_pixels_to_pad = (target_kernel_size - kernel.shape[2]) // 2
W_pixels_to_pad = (target_kernel_size - kernel.shape[3]) // 2
return F.pad(kernel, [
H_pixels_to_pad, H_pixels_to_pad, W_pixels_to_pad, W_pixels_to_pad
])
def _get_equivalent_kernel_bias(self):
k_origin, b_origin = self._fuse_bn(self.dbb_origin.conv.weight,
self.dbb_origin.bn)
if hasattr(self, 'dbb_1x1'):
k_1x1, b_1x1 = self._fuse_bn(self.dbb_1x1.conv.weight,
self.dbb_1x1.bn)
k_1x1 = self._fuse_multiscale(k_1x1, self.kernel_size)
else:
k_1x1, b_1x1 = 0, 0
if hasattr(self.dbb_1x1_kxk, 'idconv1'):
k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
else:
k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
k_1x1_kxk_first, b_1x1_kxk_first = self._fuse_bn(
k_1x1_kxk_first, self.dbb_1x1_kxk.bn1)
k_1x1_kxk_second, b_1x1_kxk_second = self._fuse_bn(
self.dbb_1x1_kxk.conv2.weight, self.dbb_1x1_kxk.bn2)
k_1x1_kxk_merged, b_1x1_kxk_merged = self._fuse_1x1_kxk(
k_1x1_kxk_first,
b_1x1_kxk_first,
k_1x1_kxk_second,
b_1x1_kxk_second,
groups=self.groups)
k_avg = self._fuse_avg(self.out_channels, self.kernel_size, self.groups)
k_1x1_avg_second, b_1x1_avg_second = self._fuse_bn(
k_avg, self.dbb_avg.avgbn)
if hasattr(self.dbb_avg, 'conv'):
k_1x1_avg_first, b_1x1_avg_first = self._fuse_bn(
self.dbb_avg.conv.weight, self.dbb_avg.bn)
k_1x1_avg_merged, b_1x1_avg_merged = self._fuse_1x1_kxk(
k_1x1_avg_first,
b_1x1_avg_first,
k_1x1_avg_second,
b_1x1_avg_second,
groups=self.groups)
else:
k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second
return sum([k_origin, k_1x1, k_1x1_kxk_merged, k_1x1_avg_merged]), sum(
[b_origin, b_1x1, b_1x1_kxk_merged, b_1x1_avg_merged])
def convert_to_deploy(self):
if hasattr(self, 'dbb_reparam'):
return
kernel, bias = self._get_equivalent_kernel_bias()
self.dbb_reparam = nn.Conv2D(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding,
groups=self.groups,
bias_attr=True)
self.dbb_reparam.weight.set_value(kernel)
self.dbb_reparam.bias.set_value(bias)
self.__delattr__('dbb_origin')
self.__delattr__('dbb_avg')
if hasattr(self, 'dbb_1x1'):
self.__delattr__('dbb_1x1')
self.__delattr__('dbb_1x1_kxk')
def forward(self, inputs):
if hasattr(self, 'dbb_reparam'):
return self.dbb_reparam(inputs)
out = self.dbb_origin(inputs)
if hasattr(self, 'dbb_1x1'):
out += self.dbb_1x1(inputs)
out += self.dbb_avg(inputs)
out += self.dbb_1x1_kxk(inputs)
return out
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle import ParamAttr
from paddle.regularizer import L2Decay
import paddle.nn as nn
from .base import BaseConv2DReper, ConvBNLayer
__all__ = ["RepVGGBlock"]
class RepVGGBlock(BaseConv2DReper):
"""
An instance of the RepVGGBlock module, which replaces the conv-bn layer in the network.
    Reference paper: https://arxiv.org/abs/2101.03697.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
padding=None):
super(RepVGGBlock, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
groups=groups,
padding=padding)
# Re-parameterizable skip connection
self.rbr_skip = nn.BatchNorm2D(
num_features=in_channels,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0))
) if in_channels == out_channels and self.stride == 1 else None
# Re-parameterizable conv branches
self.rbr_conv = ConvBNLayer(
self.in_channels,
self.out_channels,
self.kernel_size,
stride=self.stride,
groups=self.groups)
# Re-parameterizable scale branch
self.rbr_scale = None
if kernel_size > 1:
self.rbr_scale = ConvBNLayer(
self.in_channels,
self.out_channels,
1,
stride=self.stride,
groups=self.groups)
def forward(self, x):
# Inference mode forward pass.
if hasattr(self, "reparam_conv"):
return self.reparam_conv(x)
# Multi-branched train-time forward pass.
# Skip branch output
identity_out = 0
if self.rbr_skip is not None:
identity_out = self.rbr_skip(x)
# Scale branch output
scale_out = 0
if self.rbr_scale is not None:
scale_out = self.rbr_scale(x)
# Other branches
out = scale_out + identity_out
out += self.rbr_conv(x)
return out
def convert_to_deploy(self):
"""
Re-parameterize multi-branched architecture used at training
time to obtain a plain CNN-like structure for inference.
"""
if hasattr(self, 'reparam_conv'):
return
kernel, bias = self._get_kernel_bias()
self.reparam_conv = nn.Conv2D(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=(self.kernel_size - 1) // 2,
groups=self.groups)
self.reparam_conv.weight.set_value(kernel)
self.reparam_conv.bias.set_value(bias)
# Delete un-used branches
self.__delattr__('rbr_conv')
if hasattr(self, 'rbr_scale'):
self.__delattr__('rbr_scale')
if hasattr(self, 'rbr_skip'):
self.__delattr__('rbr_skip')
def _get_kernel_bias(self):
"""
Method to obtain re-parameterized kernel and bias.
"""
# get weights and bias of scale branch
kernel_scale = 0
bias_scale = 0
if self.rbr_scale is not None:
kernel_scale, bias_scale = self._fuse_bn_tensor(self.rbr_scale)
# Pad scale branch kernel to match conv branch kernel size. 1x1->3x3
padding_size = self.kernel_size // 2
kernel_scale = paddle.nn.functional.pad(kernel_scale, [
padding_size, padding_size, padding_size, padding_size
])
# get weights and bias of skip branch
kernel_identity = 0
bias_identity = 0
if self.rbr_skip is not None:
kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
# get weights and bias of conv branches
kernel_conv, bias_conv = self._fuse_bn_tensor(self.rbr_conv)
kernel_final = kernel_conv + kernel_scale + kernel_identity
bias_final = bias_conv + bias_scale + bias_identity
return kernel_final, bias_final
def _fuse_bn_tensor(self, branch):
if branch is None:
return 0, 0
if isinstance(branch, nn.LayerList):
fused_kernels = []
fused_bias = []
for block in branch:
kernel = block.conv.weight
running_mean = block.bn._mean
running_var = block.bn._variance
gamma = block.bn.weight
beta = block.bn.bias
eps = block.bn._epsilon
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
fused_kernels.append(kernel * t)
fused_bias.append(beta - running_mean * gamma / std)
return sum(fused_kernels), sum(fused_bias)
elif isinstance(branch, ConvBNLayer):
kernel = branch.conv.weight
running_mean = branch.bn._mean
running_var = branch.bn._variance
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn._epsilon
else:
assert isinstance(branch, nn.BatchNorm2D)
input_dim = self.in_channels if self.kernel_size == 1 else 1
kernel_value = paddle.zeros(
shape=[
self.in_channels, input_dim, self.kernel_size,
self.kernel_size
],
dtype='float32')
if self.kernel_size > 1:
for i in range(self.in_channels):
kernel_value[i, i % input_dim, (self.kernel_size - 1) // 2,
(self.kernel_size - 1) // 2] = 1
elif self.kernel_size == 1:
for i in range(self.in_channels):
kernel_value[i, i % input_dim, 0, 0] = 1
else:
raise ValueError("Invalid kernel size recieved!")
kernel = paddle.to_tensor(kernel_value, place=branch.weight.place)
running_mean = branch._mean
running_var = branch._variance
gamma = branch.weight
beta = branch.bias
eps = branch._epsilon
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
return kernel * t, beta - running_mean * gamma / std
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle import ParamAttr
from paddle.regularizer import L2Decay
import paddle.nn as nn
from .base import BaseConv2DReper, ConvBNLayer
__all__ = ["SlimRepBlock"]
class SlimRepBlock(BaseConv2DReper):
"""
An instance of the SlimRepBlock module, which replaces the conv-bn layer in the network.
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
groups=1,
padding=None):
super(SlimRepBlock, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
groups=groups,
padding=padding)
self.num_conv_branches = 1
if not self.padding:
self.padding = self.kernel_size // 2
if self.padding - self.kernel_size // 2 >= 0:
self.crop = 0
            # Compared to the KxK layer, the padding of the 1xK and Kx1 layers should be adjusted to align the sliding windows (Fig. 2 in the paper)
hor_padding = [self.padding - self.kernel_size // 2, self.padding]
ver_padding = [self.padding, self.padding - self.kernel_size // 2]
else:
# A negative "padding" (padding - kernel_size//2 < 0, which is not a common use case) is cropping.
# Since nn.Conv2D does not support negative padding, we implement it manually
self.crop = self.kernel_size // 2 - self.padding
hor_padding = [0, self.padding]
ver_padding = [self.padding, 0]
# Re-parameterizable skip connection
self.rbr_skip = nn.BatchNorm2D(
num_features=in_channels,
weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
bias_attr=ParamAttr(regularizer=L2Decay(0.0))
) if in_channels == out_channels and self.stride == 1 else None
# Re-parameterizable conv branches
self.rbr_conv = nn.LayerList()
for _ in range(self.num_conv_branches):
for kernel_size in range(self.kernel_size, 0, -2):
self.rbr_conv.append(
ConvBNLayer(
self.in_channels,
self.out_channels,
kernel_size,
stride=self.stride,
groups=self.groups))
# kx1 vertical branch
self.ver_branch = ConvBNLayer(
self.in_channels,
self.out_channels, (self.kernel_size, 1),
self.stride,
groups=self.groups,
padding=ver_padding)
# 1xk horizontal branch
self.hor_branch = ConvBNLayer(
self.in_channels,
self.out_channels, (1, self.kernel_size),
self.stride,
groups=self.groups,
padding=hor_padding)
def forward(self, x):
# Inference mode forward pass.
if hasattr(self, "reparam_conv"):
return self.reparam_conv(x)
# Multi-branched train-time forward pass.
out = 0
for rbr_conv in self.rbr_conv:
out += rbr_conv(x)
# Skip branch output
if self.rbr_skip is not None:
out += self.rbr_skip(x)
if self.crop > 0:
ver_input = x[:, :, :, self.crop:-self.crop]
hor_input = x[:, :, self.crop:-self.crop, :]
else:
ver_input = x
hor_input = x
out += self.ver_branch(ver_input)
out += self.hor_branch(hor_input)
return out
def convert_to_deploy(self):
"""
Re-parameterize multi-branched architecture used at training
time to obtain a plain CNN-like structure for inference.
"""
if hasattr(self, 'reparam_conv'):
return
kernel, bias = self._get_kernel_bias()
self.reparam_conv = nn.Conv2D(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
stride=self.stride,
padding=(self.kernel_size - 1) // 2,
groups=self.groups)
self.reparam_conv.weight.set_value(kernel)
self.reparam_conv.bias.set_value(bias)
# Delete un-used branches
self.__delattr__('rbr_conv')
if hasattr(self, 'rbr_skip'):
self.__delattr__('rbr_skip')
self.__delattr__('ver_branch')
self.__delattr__('hor_branch')
def _get_kernel_bias(self):
"""
Method to obtain re-parameterized kernel and bias.
"""
# get weights and bias of conv branches
kernel_conv = 0
bias_conv = 0
        # Fuse every conv branch created in __init__ (one per kernel size per branch).
        for ix in range(len(self.rbr_conv)):
_kernel, _bias = self._fuse_bn_tensor(self.rbr_conv[ix])
_kernel = self._pad_tensor(_kernel, to_size=self.kernel_size)
kernel_conv += _kernel
bias_conv += _bias
# get weights and bias of skip branch
kernel_identity = 0
bias_identity = 0
if self.rbr_skip is not None:
kernel_identity, bias_identity = self._fuse_bn_tensor(self.rbr_skip)
kernel_final = kernel_conv + kernel_identity
bias_final = bias_conv + bias_identity
# get kx1 1xk branch
hor_k, hor_b = self._fuse_bn_tensor(self.hor_branch)
ver_k, ver_b = self._fuse_bn_tensor(self.ver_branch)
self._add_to_square_kernel(kernel_final, hor_k)
self._add_to_square_kernel(kernel_final, ver_k)
bias_final += hor_b + ver_b
return kernel_final, bias_final
def _add_to_square_kernel(self, square_kernel, asym_kernel):
asym_h = asym_kernel.shape[2]
asym_w = asym_kernel.shape[3]
square_h = square_kernel.shape[2]
square_w = square_kernel.shape[3]
square_kernel[:, :, square_h // 2 - asym_h // 2:square_h // 2 -
asym_h // 2 + asym_h, square_w // 2 - asym_w // 2:
square_w // 2 - asym_w // 2 + asym_w] += asym_kernel
def _pad_tensor(self, tensor, to_size):
from_size = tensor.shape[-1]
if from_size == to_size:
return tensor
pad = (to_size - from_size) // 2
return paddle.nn.functional.pad(tensor, [pad, pad, pad, pad])
def _fuse_bn_tensor(self, branch):
if branch is None:
return 0, 0
if isinstance(branch, nn.LayerList):
fused_kernels = []
fused_bias = []
for block in branch:
kernel = block.conv.weight
running_mean = block.bn._mean
running_var = block.bn._variance
gamma = block.bn.weight
beta = block.bn.bias
eps = block.bn._epsilon
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
fused_kernels.append(kernel * t)
fused_bias.append(beta - running_mean * gamma / std)
return sum(fused_kernels), sum(fused_bias)
elif isinstance(branch, ConvBNLayer):
kernel = branch.conv.weight
running_mean = branch.bn._mean
running_var = branch.bn._variance
gamma = branch.bn.weight
beta = branch.bn.bias
eps = branch.bn._epsilon
else:
assert isinstance(branch, nn.BatchNorm2D)
input_dim = self.in_channels if self.kernel_size == 1 else 1
kernel_value = paddle.zeros(
shape=[
self.in_channels, input_dim, self.kernel_size,
self.kernel_size
],
dtype='float32')
if self.kernel_size > 1:
for i in range(self.in_channels):
kernel_value[i, i % input_dim, (self.kernel_size - 1) // 2,
(self.kernel_size - 1) // 2] = 1
elif self.kernel_size == 1:
for i in range(self.in_channels):
kernel_value[i, i % input_dim, 0, 0] = 1
else:
raise ValueError("Invalid kernel size recieved!")
kernel = paddle.to_tensor(kernel_value, place=branch.weight.place)
running_mean = branch._mean
running_var = branch._variance
gamma = branch.weight
beta = branch.bias
eps = branch._epsilon
std = (running_var + eps).sqrt()
t = (gamma / std).reshape((-1, 1, 1, 1))
return kernel * t, beta - running_mean * gamma / std
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import sys
sys.path.append("../../")
import unittest
import logging
import paddle
from paddleslim.common import get_logger
from paddleslim.dygraph.rep import Reparameter, DBBRepConfig, ACBRepConfig
_logger = get_logger(__name__, level=logging.INFO)
class ImperativeLenet(paddle.nn.Layer):
def __init__(self, num_classes=10):
super(ImperativeLenet, self).__init__()
self.features = paddle.nn.Sequential(
paddle.nn.Conv2D(
in_channels=1,
out_channels=6,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False),
paddle.nn.BatchNorm2D(6),
paddle.nn.ReLU(),
paddle.nn.MaxPool2D(kernel_size=2, stride=2),
paddle.nn.Conv2D(
in_channels=6,
out_channels=16,
kernel_size=5,
stride=1,
padding=2,
bias_attr=False),
paddle.nn.BatchNorm2D(16),
paddle.nn.PReLU(), paddle.nn.MaxPool2D(kernel_size=2, stride=2))
self.fc = paddle.nn.Sequential(
paddle.nn.Linear(in_features=784, out_features=120),
paddle.nn.LeakyReLU(),
paddle.nn.Linear(in_features=120, out_features=84),
paddle.nn.Sigmoid(),
paddle.nn.Linear(in_features=84, out_features=num_classes),
paddle.nn.Softmax())
def forward(self, inputs):
x = self.features(inputs)
x = paddle.flatten(x, 1)
x = self.fc(x)
return x
class TestRep(unittest.TestCase):
"""
Test dygraph reparameterization.
"""
def model_test(self, model, test_reader):
model.eval()
avg_acc = [[], []]
for batch_id, data in enumerate(test_reader):
img = paddle.to_tensor(data[0])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.to_tensor(data[1])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
avg_acc[0].append(acc_top1.numpy())
avg_acc[1].append(acc_top5.numpy())
if batch_id % 100 == 0:
_logger.info("Test | step {}: acc1 = {:}, acc5 = {:}".format(
batch_id, acc_top1.numpy(), acc_top5.numpy()))
_logger.info("Test |Average: acc_top1 {}, acc_top5 {}".format(
np.mean(avg_acc[0]), np.mean(avg_acc[1])))
return np.mean(avg_acc[0]), np.mean(avg_acc[1])
def model_train(self, model, train_reader):
adam = paddle.optimizer.Adam(
learning_rate=0.0001, parameters=model.parameters())
epoch_num = 1
for epoch in range(epoch_num):
model.train()
for batch_id, data in enumerate(train_reader):
img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc = paddle.metric.accuracy(out, label)
loss = paddle.nn.functional.loss.cross_entropy(out, label)
avg_loss = paddle.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
model.clear_gradients()
if batch_id % 100 == 0:
_logger.info(
"Train | At epoch {} step {}: loss = {:}, acc= {:}".
format(epoch, batch_id, avg_loss.numpy(), acc.numpy()))
def test_dbb(self):
seed = 1
np.random.seed(seed)
paddle.static.default_main_program().random_seed = seed
paddle.static.default_startup_program().random_seed = seed
_logger.info("create the fp32 model")
fp32_lenet = ImperativeLenet()
_logger.info("prepare data")
batch_size = 64
transform = paddle.vision.transforms.Compose([
paddle.vision.transforms.Transpose(),
paddle.vision.transforms.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform)
val_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform)
place = paddle.CUDAPlace(0) \
if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
train_reader = paddle.io.DataLoader(
train_dataset,
drop_last=True,
places=place,
batch_size=batch_size,
return_list=True)
test_reader = paddle.io.DataLoader(
val_dataset, places=place, batch_size=batch_size, return_list=True)
_logger.info("train the fp32 model")
self.model_train(fp32_lenet, train_reader)
_logger.info("test fp32 model")
fp32_top1, fp32_top5 = self.model_test(fp32_lenet, test_reader)
rep_config = DBBRepConfig()
reper = Reparameter(rep_config)
reper.prepare(fp32_lenet)
_logger.info("train the DBB reparameterization model")
self.model_train(fp32_lenet, train_reader)
rep_top1, rep_top5 = self.model_test(fp32_lenet, test_reader)
_logger.info("save and test the DBB reparameterization model")
reper.convert(fp32_lenet)
save_path = "./tmp/model"
input_spec = paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
paddle.jit.save(fp32_lenet, save_path, input_spec=[input_spec])
_logger.info(
"FP32 acc: top1: {}, top5: {}".format(fp32_top1, fp32_top5))
_logger.info("Int acc: top1: {}, top5: {}".format(rep_top1, rep_top5))
diff = 0.005
self.assertTrue(
fp32_top1 - rep_top1 < diff,
msg="The acc of rep model is too lower than fp32 model")
def test_acb(self):
seed = 1
np.random.seed(seed)
paddle.static.default_main_program().random_seed = seed
paddle.static.default_startup_program().random_seed = seed
_logger.info("create the fp32 model")
fp32_lenet = ImperativeLenet()
_logger.info("prepare data")
batch_size = 64
transform = paddle.vision.transforms.Compose([
paddle.vision.transforms.Transpose(),
paddle.vision.transforms.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform)
val_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform)
place = paddle.CUDAPlace(0) \
if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
train_reader = paddle.io.DataLoader(
train_dataset,
drop_last=True,
places=place,
batch_size=batch_size,
return_list=True)
test_reader = paddle.io.DataLoader(
val_dataset, places=place, batch_size=batch_size, return_list=True)
_logger.info("train the fp32 model")
self.model_train(fp32_lenet, train_reader)
_logger.info("test fp32 model")
fp32_top1, fp32_top5 = self.model_test(fp32_lenet, test_reader)
rep_config = ACBRepConfig()
reper = Reparameter(rep_config)
reper.prepare(fp32_lenet)
_logger.info("train the ACB reparameterization model")
self.model_train(fp32_lenet, train_reader)
rep_top1, rep_top5 = self.model_test(fp32_lenet, test_reader)
_logger.info("save and test the ACB reparameterization model")
reper.convert(fp32_lenet)
save_path = "./tmp/model"
input_spec = paddle.static.InputSpec(
shape=[None, 1, 28, 28], dtype='float32')
paddle.jit.save(fp32_lenet, save_path, input_spec=[input_spec])
_logger.info(
"FP32 acc: top1: {}, top5: {}".format(fp32_top1, fp32_top5))
_logger.info("Int acc: top1: {}, top5: {}".format(rep_top1, rep_top5))
diff = 0.005
self.assertTrue(
fp32_top1 - rep_top1 < diff,
msg="The acc of rep model is too lower than fp32 model")
if __name__ == '__main__':
unittest.main()