Unverified · Commit 6c59641e authored by shangliang Xu, committed by GitHub

[dev] add amp eval (#6400)

cast dtype in load_pretrain_weight
Parent a04d0d22
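Background on Paddle's AMP levels (my summary, not part of the diff): at O1, weights stay in float32 and only selected ops run in float16 inside paddle.amp.auto_cast; O2 is pure-fp16 mode, where paddle.amp.decorate casts the model parameters themselves to float16. That is why this commit also casts dtypes when loading pretrained weights, and why the anchor-generation code below stops inheriting feat.dtype in favor of an explicit dtype argument.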
@@ -65,6 +65,8 @@ class Trainer(object):
         self.mode = mode.lower()
         self.optimizer = None
         self.is_loaded_weights = False
+        self.use_amp = self.cfg.get('amp', False)
+        self.amp_level = self.cfg.get('amp_level', 'O1')

         # build data loader
         capital_mode = self.mode.capitalize()
@@ -124,17 +126,6 @@ class Trainer(object):
         else:
             self.model.load_meanstd(cfg['TestReader']['sample_transforms'])

-        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
-        if self.use_ema:
-            ema_decay = self.cfg.get('ema_decay', 0.9998)
-            cycle_epoch = self.cfg.get('cycle_epoch', -1)
-            ema_decay_type = self.cfg.get('ema_decay_type', 'threshold')
-            self.ema = ModelEMA(
-                self.model,
-                decay=ema_decay,
-                ema_decay_type=ema_decay_type,
-                cycle_epoch=cycle_epoch)
-
         # EvalDataset build with BatchSampler to evaluate in single device
         # TODO: multi-device evaluate
         if self.mode == 'eval':
@@ -162,6 +153,20 @@ class Trainer(object):
             self.pruner = create('UnstructuredPruner')(self.model,
                                                        steps_per_epoch)

+        if self.use_amp and self.amp_level == 'O2':
+            self.model = paddle.amp.decorate(
+                models=self.model, level=self.amp_level)
+
+        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
+        if self.use_ema:
+            ema_decay = self.cfg.get('ema_decay', 0.9998)
+            cycle_epoch = self.cfg.get('cycle_epoch', -1)
+            ema_decay_type = self.cfg.get('ema_decay_type', 'threshold')
+            self.ema = ModelEMA(
+                self.model,
+                decay=ema_decay,
+                ema_decay_type=ema_decay_type,
+                cycle_epoch=cycle_epoch)
+
         self._nranks = dist.get_world_size()
         self._local_rank = dist.get_rank()
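A minimal sketch of what the O2 decoration above does to a model, using a toy layer rather than the PaddleDetection Trainer (illustrative names; assumes a GPU device):

    import paddle
    from paddle import nn

    # Toy stand-in for self.model.
    model = nn.Linear(4, 2)
    print(model.weight.dtype)  # paddle.float32

    # Pure-fp16 (O2) decoration casts the parameters themselves to float16,
    # which is why fp32 pretrained checkpoints must be cast on load to match.
    model = paddle.amp.decorate(models=model, level='O2')
    print(model.weight.dtype)  # paddle.float16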
@@ -387,13 +392,10 @@ class Trainer(object):
             model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)

         # enable auto mixed precision mode
-        use_amp = self.cfg.get('amp', False)
-        amp_level = self.cfg.get('amp_level', 'O1')
-        if use_amp:
+        if self.use_amp:
             scaler = paddle.amp.GradScaler(
                 enable=self.cfg.use_gpu or self.cfg.use_npu,
                 init_loss_scaling=self.cfg.get('init_loss_scaling', 1024))
-            model = paddle.amp.decorate(models=model, level=amp_level)
         # get distributed model
         if self.cfg.get('fleet', False):
             model = fleet.distributed_model(model)
@@ -438,9 +440,9 @@ class Trainer(object):
                 self._compose_callback.on_step_begin(self.status)
                 data['epoch_id'] = epoch_id

-                if use_amp:
+                if self.use_amp:
                     with paddle.amp.auto_cast(
-                            enable=self.cfg.use_gpu, level=amp_level):
+                            enable=self.cfg.use_gpu, level=self.amp_level):
                         # model forward
                         outputs = model(data)
                         loss = outputs['loss']
@@ -532,7 +534,12 @@ class Trainer(object):
             self.status['step_id'] = step_id
             self._compose_callback.on_step_begin(self.status)
             # forward
-            outs = self.model(data)
+            if self.use_amp:
+                with paddle.amp.auto_cast(
+                        enable=self.cfg.use_gpu, level=self.amp_level):
+                    outs = self.model(data)
+            else:
+                outs = self.model(data)

             # update metrics
             for metric in self._metrics:
......
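The hunks above follow Paddle's standard AMP recipe. A self-contained sketch of that recipe with a toy model (illustrative names, not the Trainer's API):

    import paddle
    from paddle import nn

    model = nn.Linear(10, 1)
    opt = paddle.optimizer.SGD(parameters=model.parameters())
    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
    x = paddle.rand([4, 10])

    # Training step: fp16 forward under auto_cast, loss scaled to avoid
    # fp16 gradient underflow, then unscale + step via minimize().
    with paddle.amp.auto_cast(level='O1'):
        loss = model(x).mean()
    scaled = scaler.scale(loss)
    scaled.backward()
    scaler.minimize(opt, scaled)
    opt.clear_grad()

    # Eval step (what this commit adds): the same auto_cast context,
    # but no scaler, since no gradients flow at eval time.
    model.eval()
    with paddle.no_grad(), paddle.amp.auto_cast(level='O1'):
        outs = model(x)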
@@ -176,7 +176,8 @@ def compute_max_iou_gt(ious):
 def generate_anchors_for_grid_cell(feats,
                                    fpn_strides,
                                    grid_cell_size=5.0,
-                                   grid_cell_offset=0.5):
+                                   grid_cell_offset=0.5,
+                                   dtype='float32'):
     r"""
     Like ATSS, generate anchors based on grid size.
     Args:
@@ -206,16 +207,15 @@ def generate_anchors_for_grid_cell(feats,
                 shift_x - cell_half_size, shift_y - cell_half_size,
                 shift_x + cell_half_size, shift_y + cell_half_size
             ],
-            axis=-1).astype(feat.dtype)
-        anchor_point = paddle.stack(
-            [shift_x, shift_y], axis=-1).astype(feat.dtype)
+            axis=-1).astype(dtype)
+        anchor_point = paddle.stack([shift_x, shift_y], axis=-1).astype(dtype)

         anchors.append(anchor.reshape([-1, 4]))
         anchor_points.append(anchor_point.reshape([-1, 2]))
         num_anchors_list.append(len(anchors[-1]))
         stride_tensor.append(
             paddle.full(
-                [num_anchors_list[-1], 1], stride, dtype=feat.dtype))
+                [num_anchors_list[-1], 1], stride, dtype=dtype))
     anchors = paddle.concat(anchors)
     anchors.stop_gradient = True
     anchor_points = paddle.concat(anchor_points)
......
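Why the new dtype argument matters: under AMP the feature maps can be float16, and the old code inherited feat.dtype, so anchors and strides silently became fp16 as well. A minimal sketch of the difference (the fp16 tensor stands in for an O2 backbone output):

    import paddle

    feat = paddle.rand([1, 32, 8, 8]).astype('float16')  # pretend O2 feature map
    h, w, stride = 8, 8, 32

    # Old behaviour: dtype leaks from the feature map.
    stride_old = paddle.full([h * w, 1], stride, dtype=feat.dtype)

    # New behaviour: anchors are pinned to fp32 regardless of AMP level,
    # keeping box decoding at full precision.
    stride_new = paddle.full([h * w, 1], stride, dtype='float32')
    print(stride_old.dtype, stride_new.dtype)  # paddle.float16 paddle.float32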
@@ -160,7 +160,7 @@ class PPYOLOEHead(nn.Layer):
             num_anchors_list, stride_tensor
         ], targets)

-    def _generate_anchors(self, feats=None):
+    def _generate_anchors(self, feats=None, dtype='float32'):
         # only used at eval time
         anchor_points = []
         stride_tensor = []
@@ -175,11 +175,9 @@ class PPYOLOEHead(nn.Layer):
             shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
             anchor_point = paddle.cast(
                 paddle.stack(
-                    [shift_x, shift_y], axis=-1), dtype='float32')
+                    [shift_x, shift_y], axis=-1), dtype=dtype)
             anchor_points.append(anchor_point.reshape([-1, 2]))
-            stride_tensor.append(
-                paddle.full(
-                    [h * w, 1], stride, dtype='float32'))
+            stride_tensor.append(paddle.full([h * w, 1], stride, dtype=dtype))
         anchor_points = paddle.concat(anchor_points)
         stride_tensor = paddle.concat(stride_tensor)
         return anchor_points, stride_tensor
......
@@ -77,6 +77,12 @@ def parse_args():
         default=False,
         help='Whether to save the evaluation results only')
+    parser.add_argument(
+        "--amp",
+        action='store_true',
+        default=False,
+        help="Enable auto mixed precision eval.")
     args = parser.parse_args()
     return args
......
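With the flag in place, mixed-precision evaluation can be switched on from the command line; a sketch (the config path is illustrative):

    python tools/eval.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --amp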