diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index 3e83ca6cd6ab58c108605daa7f8f14858d27a6d0..98da6c47772c4edff38224afbc6a2faea0e09f7c 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -187,12 +187,15 @@ class Trainer(object):
         ema_decay_type = self.cfg.get('ema_decay_type', 'threshold')
         cycle_epoch = self.cfg.get('cycle_epoch', -1)
         ema_black_list = self.cfg.get('ema_black_list', None)
+        # Opt-in switch: also exclude frozen (no-grad) params from EMA tracking.
+        ema_filter_no_grad = self.cfg.get('ema_filter_no_grad', False)
         self.ema = ModelEMA(
             self.model,
             decay=ema_decay,
             ema_decay_type=ema_decay_type,
             cycle_epoch=cycle_epoch,
-            ema_black_list=ema_black_list)
+            ema_black_list=ema_black_list,
+            ema_filter_no_grad=ema_filter_no_grad)
 
         self._nranks = dist.get_world_size()
         self._local_rank = dist.get_rank()
@@ -1040,6 +1043,7 @@ class Trainer(object):
             start = end
 
         return results
+
     def _get_save_image_name(self, output_dir, image_path):
         """
         Get save image name from source image path.
diff --git a/ppdet/optimizer/ema.py b/ppdet/optimizer/ema.py
index 2fade4dcf63e96b551ee1c06c5ccb338a659d3fa..9cd9dca637998f4701bfe77fe317240fee26fe71 100644
--- a/ppdet/optimizer/ema.py
+++ b/ppdet/optimizer/ema.py
@@ -21,6 +21,8 @@ import paddle
 import weakref
 from copy import deepcopy
 
+from .utils import get_bn_running_state_names
+
 __all__ = ['ModelEMA', 'SimpleModelEMA']
 
 
@@ -49,7 +51,8 @@ class ModelEMA(object):
                  decay=0.9998,
                  ema_decay_type='threshold',
                  cycle_epoch=-1,
-                 ema_black_list=None):
+                 ema_black_list=None,
+                 ema_filter_no_grad=False):
         self.step = 0
         self.epoch = 0
         self.decay = decay
@@ -64,6 +67,17 @@ class ModelEMA(object):
             else:
                 self.state_dict[k] = paddle.zeros_like(v)
 
+        bn_states_names = get_bn_running_state_names(model)
+        if ema_filter_no_grad:
+            # Black-list frozen (stop_gradient) params, but never BN running
+            # statistics: those are non-trainable yet must still be averaged.
+            for n, p in model.named_parameters():
+                if p.stop_gradient and n not in bn_states_names:
+                    self.ema_black_list.append(n)
+                    # The loop above already zero-initialized this entry;
+                    # restore the live parameter so black-listed entries keep
+                    # their real values (presumably used as-is by apply()).
+                    self.state_dict[n] = p
+
         self._model_state = {
             k: weakref.ref(p)
             for k, p in model.state_dict().items()
diff --git a/ppdet/optimizer/utils.py
b/ppdet/optimizer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce2de49bf5973ee0b69a9ecc62028cca67f4d1e0
--- /dev/null
+++ b/ppdet/optimizer/utils.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+
+from typing import List
+
+
+def get_bn_running_state_names(model: nn.Layer) -> List[str]:
+    """Get full names of all BN running-state tensors in ``model``.
+
+    Collects the state-dict keys of the running mean (``_mean``) and running
+    variance (``_variance``) of every batch-norm sublayer, so that callers
+    (e.g. EMA black-list filtering) can tell these non-trainable statistics
+    apart from genuinely frozen parameters.
+
+    Args:
+        model (nn.Layer): the model whose sublayers are scanned.
+
+    Returns:
+        List[str]: names like ``'backbone.bn._mean'``, ``'backbone.bn._variance'``.
+    """
+    # Cover every BN flavour that keeps running statistics, not only the 2-D
+    # case: 1-D/3-D and legacy nn.BatchNorm also hold _mean/_variance.
+    bn_types = (nn.BatchNorm, nn.BatchNorm1D, nn.BatchNorm2D, nn.BatchNorm3D,
+                nn.SyncBatchNorm)
+    names = []
+    for n, m in model.named_sublayers():
+        if isinstance(m, bn_types):
+            assert hasattr(m, '_mean'), f'assert {m} has _mean'
+            assert hasattr(m, '_variance'), f'assert {m} has _variance'
+            names.extend([f'{n}._mean', f'{n}._variance'])
+
+    return names