From 2f343f5aa332b75a15c38f024cf90ac83c413b0b Mon Sep 17 00:00:00 2001 From: shangliang Xu Date: Tue, 21 Mar 2023 15:04:02 +0800 Subject: [PATCH] fix ema_filter_no_grad (#7974) --- ppdet/optimizer/ema.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ppdet/optimizer/ema.py b/ppdet/optimizer/ema.py index 9cd9dca63..70d006b8f 100644 --- a/ppdet/optimizer/ema.py +++ b/ppdet/optimizer/ema.py @@ -60,6 +60,12 @@ class ModelEMA(object): self.cycle_epoch = cycle_epoch self.ema_black_list = self._match_ema_black_list( model.state_dict().keys(), ema_black_list) + bn_states_names = get_bn_running_state_names(model) + if ema_filter_no_grad: + for n, p in model.named_parameters(): + if p.stop_gradient and n not in bn_states_names: + self.ema_black_list.add(n) + self.state_dict = dict() for k, v in model.state_dict().items(): if k in self.ema_black_list: @@ -67,12 +73,6 @@ class ModelEMA(object): else: self.state_dict[k] = paddle.zeros_like(v) - bn_states_names = get_bn_running_state_names(model) - if ema_filter_no_grad: - for n, p in model.named_parameters(): - if p.stop_gradient == True and n not in bn_states_names: - self.ema_black_list.append(n) - self._model_state = { k: weakref.ref(p) for k, p in model.state_dict().items() -- GitLab