From a65904d8c0110b2ff8abe12b5b0f3e1c277b901a Mon Sep 17 00:00:00 2001
From: Wenyu
Date: Wed, 8 Feb 2023 15:57:31 +0800
Subject: [PATCH] add ema filter no_grad (#7691)

* add ema filter no_grad

* update file name and default value
---
 ppdet/engine/trainer.py  |  5 ++++-
 ppdet/optimizer/ema.py   | 11 ++++++++++-
 ppdet/optimizer/utils.py | 37 +++++++++++++++++++++++++++++++++++++
 3 files changed, 51 insertions(+), 2 deletions(-)
 create mode 100644 ppdet/optimizer/utils.py

diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index 3e83ca6cd..98da6c477 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -187,12 +187,14 @@ class Trainer(object):
             ema_decay_type = self.cfg.get('ema_decay_type', 'threshold')
             cycle_epoch = self.cfg.get('cycle_epoch', -1)
             ema_black_list = self.cfg.get('ema_black_list', None)
+            ema_filter_no_grad = self.cfg.get('ema_filter_no_grad', False)
             self.ema = ModelEMA(
                 self.model,
                 decay=ema_decay,
                 ema_decay_type=ema_decay_type,
                 cycle_epoch=cycle_epoch,
-                ema_black_list=ema_black_list)
+                ema_black_list=ema_black_list,
+                ema_filter_no_grad=ema_filter_no_grad)

         self._nranks = dist.get_world_size()
         self._local_rank = dist.get_rank()
@@ -1040,6 +1042,7 @@ class Trainer(object):
             start = end

         return results
+
     def _get_save_image_name(self, output_dir, image_path):
         """
         Get save image name from source image path.
diff --git a/ppdet/optimizer/ema.py b/ppdet/optimizer/ema.py
index 2fade4dcf..9cd9dca63 100644
--- a/ppdet/optimizer/ema.py
+++ b/ppdet/optimizer/ema.py
@@ -21,6 +21,8 @@ import paddle
 import weakref
 from copy import deepcopy

+from .utils import get_bn_running_state_names
+
 __all__ = ['ModelEMA', 'SimpleModelEMA']


@@ -49,7 +51,8 @@ class ModelEMA(object):
                  decay=0.9998,
                  ema_decay_type='threshold',
                  cycle_epoch=-1,
-                 ema_black_list=None):
+                 ema_black_list=None,
+                 ema_filter_no_grad=False):
         self.step = 0
         self.epoch = 0
         self.decay = decay
@@ -64,6 +67,12 @@ class ModelEMA(object):
             else:
                 self.state_dict[k] = paddle.zeros_like(v)

+        bn_states_names = get_bn_running_state_names(model)
+        if ema_filter_no_grad:
+            for n, p in model.named_parameters():
+                if p.stop_gradient and n not in bn_states_names:
+                    self.ema_black_list.append(n)
+
         self._model_state = {
             k: weakref.ref(p)
             for k, p in model.state_dict().items()
diff --git a/ppdet/optimizer/utils.py b/ppdet/optimizer/utils.py
new file mode 100644
index 000000000..ce2de49bf
--- /dev/null
+++ b/ppdet/optimizer/utils.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+
+from typing import List
+
+
+def get_bn_running_state_names(model: nn.Layer) -> List[str]:
+    """Get the full names of all BN running states (running mean and variance).
+    """
+    names = []
+    for n, m in model.named_sublayers():
+        if isinstance(m, (nn.BatchNorm2D, nn.SyncBatchNorm)):
+            assert hasattr(m, '_mean'), f'{m} has no attribute _mean'
+            assert hasattr(m, '_variance'), f'{m} has no attribute _variance'
+            running_mean = f'{n}._mean'
+            running_var = f'{n}._variance'
+            names.extend([running_mean, running_var])
+
+    return names
-- 
GitLab