Unverified commit 0910e988, authored by pk_hk, committed by GitHub

add use_checkpoint and use_alpha for cspresnet (#6428)

Parent 63dc4c4a
_BASE_: [
  '../datasets/visdrone_detection.yml',
  '../runtime.yml',
  '../ppyoloe/_base_/optimizer_300e.yml',
  '../ppyoloe/_base_/ppyoloe_crn.yml',
  '../ppyoloe/_base_/ppyoloe_reader.yml',
]
log_iter: 100
snapshot_epoch: 10
weights: output/ppyoloe_crn_s_80e_visdrone_use_checkpoint/model_final
pretrain_weights: https://paddledet.bj.bcebos.com/models/ppyoloe_crn_s_300e_coco.pdparams

depth_mult: 0.33
width_mult: 0.50

TrainReader:
  batch_size: 8

epoch: 80
LearningRate:
  base_lr: 0.01
  schedulers:
    - !CosineDecay
      max_epochs: 96
    - !LinearWarmup
      start_factor: 0.
      epochs: 1

CSPResNet:
  use_checkpoint: True
  use_alpha: True

# used together with use_checkpoint
use_fused_allreduce_gradients: True

PPYOLOEHead:
  static_assigner_epoch: -1
  nms:
    name: MultiClassNMS
    nms_top_k: 10000
    keep_top_k: 500
    score_threshold: 0.01
    nms_threshold: 0.6
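Note: the two new CSPResNet keys map one-to-one onto backbone constructor arguments added further down in this commit. A minimal sketch of constructing the backbone directly (assumed usage; in practice ppdet instantiates it from this YAML through its config workspace, and the import path below is an assumption based on the repo layout):

# Sketch only: direct construction of the backbone with the new flags.
import paddle
from ppdet.modeling.backbones.cspresnet import CSPResNet  # path assumed

backbone = CSPResNet(
    width_mult=0.50,         # matches width_mult in the config above
    depth_mult=0.33,         # matches depth_mult in the config above
    use_checkpoint=True,     # recompute each stage during backward
    use_alpha=True)          # learnable scale on the RepVgg 1x1 branch
feats = backbone({'image': paddle.randn([1, 3, 640, 640])})  # multi-scale features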
ppdet/engine/trainer.py
@@ -49,6 +49,8 @@ from ppdet.utils import profiler
 from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator, WandbCallback
 from .export_utils import _dump_infer_config, _prune_input_spec
+from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients
+
 from ppdet.utils.logger import setup_logger
 logger = setup_logger('ppdet.engine')
@@ -152,7 +154,6 @@ class Trainer(object):
         if self.cfg.get('unstructured_prune'):
             self.pruner = create('UnstructuredPruner')(self.model,
                                                        steps_per_epoch)
-
         if self.use_amp and self.amp_level == 'O2':
             self.model = paddle.amp.decorate(
                 models=self.model, level=self.amp_level)
@@ -426,6 +427,9 @@ class Trainer(object):
         self._compose_callback.on_train_begin(self.status)
 
+        use_fused_allreduce_gradients = self.cfg.get(
+            'use_fused_allreduce_gradients', False)
+
         for epoch_id in range(self.start_epoch, self.cfg.epoch):
             self.status['mode'] = 'train'
             self.status['epoch_id'] = epoch_id
@@ -441,22 +445,51 @@ class Trainer(object):
                 data['epoch_id'] = epoch_id
 
                 if self.use_amp:
-                    with paddle.amp.auto_cast(
-                            enable=self.cfg.use_gpu, level=self.amp_level):
-                        # model forward
-                        outputs = model(data)
-                        loss = outputs['loss']
-                    # model backward
-                    scaled_loss = scaler.scale(loss)
-                    scaled_loss.backward()
+                    if isinstance(
+                            model, paddle.
+                            DataParallel) and use_fused_allreduce_gradients:
+                        with model.no_sync():
+                            with paddle.amp.auto_cast(
+                                    enable=self.cfg.use_gpu,
+                                    level=self.amp_level):
+                                # model forward
+                                outputs = model(data)
+                                loss = outputs['loss']
+                            # model backward
+                            scaled_loss = scaler.scale(loss)
+                            scaled_loss.backward()
+                        fused_allreduce_gradients(
+                            list(model.parameters()), None)
+                    else:
+                        with paddle.amp.auto_cast(
+                                enable=self.cfg.use_gpu, level=self.amp_level):
+                            # model forward
+                            outputs = model(data)
+                            loss = outputs['loss']
+                        # model backward
+                        scaled_loss = scaler.scale(loss)
+                        scaled_loss.backward()
                     # in dygraph mode, optimizer.minimize is equal to optimizer.step
                     scaler.minimize(self.optimizer, scaled_loss)
                 else:
-                    # model forward
-                    outputs = model(data)
-                    loss = outputs['loss']
-                    # model backward
-                    loss.backward()
+                    if isinstance(
+                            model, paddle.
+                            DataParallel) and use_fused_allreduce_gradients:
+                        with model.no_sync():
+                            # model forward
+                            outputs = model(data)
+                            loss = outputs['loss']
+                            # model backward
+                            loss.backward()
+                        fused_allreduce_gradients(
+                            list(model.parameters()), None)
+                    else:
+                        # model forward
+                        outputs = model(data)
+                        loss = outputs['loss']
+                        # model backward
+                        loss.backward()
                     self.optimizer.step()
                 curr_lr = self.optimizer.get_lr()
                 self.lr.step()
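The trainer change above follows a standard data-parallel pattern: with use_fused_allreduce_gradients enabled, the per-parameter allreduce that DataParallel normally performs during backward is suppressed with no_sync(), and a single fused allreduce over all gradients runs once backward finishes. A minimal standalone sketch of that pattern (hypothetical train_step helper, not PaddleDetection code; assumes an initialized data-parallel environment and a model that returns a dict with a 'loss' entry):

import paddle
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

def train_step(model, data, optimizer, use_fused=True):
    if isinstance(model, paddle.DataParallel) and use_fused:
        with model.no_sync():          # skip per-parameter allreduce in backward
            loss = model(data)['loss']
            loss.backward()
        # one fused allreduce over all gradients, after backward completes
        fused_allreduce_gradients(list(model.parameters()), None)
    else:
        loss = model(data)['loss']
        loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    return loss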
ppdet/modeling/backbones/cspresnet.py
@@ -21,6 +21,7 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 from paddle import ParamAttr
 from paddle.regularizer import L2Decay
+from paddle.nn.initializer import Constant
 
 from ppdet.modeling.ops import get_act_fn
 from ppdet.core.workspace import register, serializable
@@ -65,7 +66,7 @@ class ConvBNLayer(nn.Layer):
 class RepVggBlock(nn.Layer):
-    def __init__(self, ch_in, ch_out, act='relu'):
+    def __init__(self, ch_in, ch_out, act='relu', alpha=False):
         super(RepVggBlock, self).__init__()
         self.ch_in = ch_in
         self.ch_out = ch_out
@@ -75,12 +76,22 @@ class RepVggBlock(nn.Layer):
             ch_in, ch_out, 1, stride=1, padding=0, act=None)
         self.act = get_act_fn(act) if act is None or isinstance(act, (
             str, dict)) else act
+        if alpha:
+            self.alpha = self.create_parameter(
+                shape=[1],
+                attr=ParamAttr(initializer=Constant(value=1.)),
+                dtype="float32")
+        else:
+            self.alpha = None
 
     def forward(self, x):
         if hasattr(self, 'conv'):
             y = self.conv(x)
         else:
-            y = self.conv1(x) + self.conv2(x)
+            if self.alpha is not None:
+                y = self.conv1(x) + self.alpha * self.conv2(x)
+            else:
+                y = self.conv1(x) + self.conv2(x)
         y = self.act(y)
         return y
@@ -102,8 +113,12 @@ class RepVggBlock(nn.Layer):
     def get_equivalent_kernel_bias(self):
         kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
         kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
-        return kernel3x3 + self._pad_1x1_to_3x3_tensor(
-            kernel1x1), bias3x3 + bias1x1
+        if self.alpha is not None:
+            return kernel3x3 + self.alpha * self._pad_1x1_to_3x3_tensor(
+                kernel1x1), bias3x3 + self.alpha * bias1x1
+        else:
+            return kernel3x3 + self._pad_1x1_to_3x3_tensor(
+                kernel1x1), bias3x3 + bias1x1
 
     def _pad_1x1_to_3x3_tensor(self, kernel1x1):
         if kernel1x1 is None:
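The alpha-aware fusion above works because convolution is linear: the trained two-branch sum conv3x3(x) + alpha * conv1x1(x) collapses at deploy time into a single 3x3 convolution with kernel K3 + alpha * pad(K1) and bias b3 + alpha * b1. A quick numeric check of that identity (toy single-channel numpy sketch with a hypothetical conv2d helper, not library code):

import numpy as np

def conv2d(x, k, b):
    # 'same'-padded, stride-1, single-channel cross-correlation
    pad = k.shape[0] // 2
    xp = np.pad(x, pad)
    out = np.zeros_like(x)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            out[i, j] = (xp[i:i + k.shape[0], j:j + k.shape[1]] * k).sum() + b
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))
k3, b3 = rng.standard_normal((3, 3)), 0.3
k1, b1 = rng.standard_normal((1, 1)), -0.1
alpha = 0.5

two_branch = conv2d(x, k3, b3) + alpha * conv2d(x, k1, b1)
fused_k = k3 + alpha * np.pad(k1, 1)          # pad the 1x1 kernel to 3x3
fused = conv2d(x, fused_k, b3 + alpha * b1)   # single fused convolution
assert np.allclose(two_branch, fused)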
@@ -126,11 +141,16 @@ class RepVggBlock(nn.Layer):
 class BasicBlock(nn.Layer):
-    def __init__(self, ch_in, ch_out, act='relu', shortcut=True):
+    def __init__(self,
+                 ch_in,
+                 ch_out,
+                 act='relu',
+                 shortcut=True,
+                 use_alpha=False):
         super(BasicBlock, self).__init__()
         assert ch_in == ch_out
         self.conv1 = ConvBNLayer(ch_in, ch_out, 3, stride=1, padding=1, act=act)
-        self.conv2 = RepVggBlock(ch_out, ch_out, act=act)
+        self.conv2 = RepVggBlock(ch_out, ch_out, act=act, alpha=use_alpha)
         self.shortcut = shortcut
 
     def forward(self, x):
@@ -167,7 +187,8 @@ class CSPResStage(nn.Layer):
                  n,
                  stride,
                  act='relu',
-                 attn='eca'):
+                 attn='eca',
+                 use_alpha=False):
         super(CSPResStage, self).__init__()
 
         ch_mid = (ch_in + ch_out) // 2
@@ -180,8 +201,11 @@ class CSPResStage(nn.Layer):
         self.conv2 = ConvBNLayer(ch_mid, ch_mid // 2, 1, act=act)
         self.blocks = nn.Sequential(*[
             block_fn(
-                ch_mid // 2, ch_mid // 2, act=act, shortcut=True)
-            for i in range(n)
+                ch_mid // 2,
+                ch_mid // 2,
+                act=act,
+                shortcut=True,
+                use_alpha=use_alpha) for i in range(n)
         ])
         if attn:
             self.attn = EffectiveSELayer(ch_mid, act='hardsigmoid')
@@ -216,8 +240,12 @@ class CSPResNet(nn.Layer):
                  use_large_stem=False,
                  width_mult=1.0,
                  depth_mult=1.0,
-                 trt=False):
+                 trt=False,
+                 use_checkpoint=False,
+                 use_alpha=False,
+                 **args):
         super(CSPResNet, self).__init__()
+        self.use_checkpoint = use_checkpoint
         channels = [max(round(c * width_mult), 1) for c in channels]
         layers = [max(round(l * depth_mult), 1) for l in layers]
         act = get_act_fn(
@@ -255,19 +283,30 @@ class CSPResNet(nn.Layer):
         n = len(channels) - 1
         self.stages = nn.Sequential(*[(str(i), CSPResStage(
-            BasicBlock, channels[i], channels[i + 1], layers[i], 2, act=act))
-                                      for i in range(n)])
+            BasicBlock,
+            channels[i],
+            channels[i + 1],
+            layers[i],
+            2,
+            act=act,
+            use_alpha=use_alpha)) for i in range(n)])
 
         self._out_channels = channels[1:]
-        self._out_strides = [4, 8, 16, 32]
+        self._out_strides = [4 * 2**i for i in range(n)]
         self.return_idx = return_idx
+        if use_checkpoint:
+            paddle.seed(0)
 
     def forward(self, inputs):
         x = inputs['image']
         x = self.stem(x)
         outs = []
         for idx, stage in enumerate(self.stages):
-            x = stage(x)
+            if self.use_checkpoint and self.training:
+                x = paddle.distributed.fleet.utils.recompute(
+                    stage, x, preserve_rng_state=True)
+            else:
+                x = stage(x)
             if idx in self.return_idx:
                 outs.append(x)
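The forward change above is standard gradient checkpointing: with use_checkpoint on, each stage's intermediate activations are freed after the forward pass and recomputed during backward, trading extra compute for a smaller activation-memory footprint (the paddle.seed(0) call above presumably pins RNG state so the recomputed pass stays consistent with the original, together with preserve_rng_state). A self-contained sketch of the same pattern (toy TinyNet, assumed example rather than PaddleDetection code):

import paddle
from paddle.distributed.fleet.utils import recompute

class TinyNet(paddle.nn.Layer):
    def __init__(self, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.stages = paddle.nn.LayerList(
            [paddle.nn.Linear(16, 16) for _ in range(4)])

    def forward(self, x):
        for stage in self.stages:
            if self.use_checkpoint and self.training:
                # activations inside `stage` are dropped after forward and
                # rebuilt during backward; preserve_rng_state keeps any
                # randomness reproducible across the replay
                x = recompute(stage, x, preserve_rng_state=True)
            else:
                x = stage(x)
        return x

net = TinyNet(use_checkpoint=True)
net.train()
x = paddle.randn([2, 16])
x.stop_gradient = False        # let gradients flow through the recomputed stages
loss = net(x).sum()
loss.backward()                # stages are re-run here instead of read from cache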