diff --git a/ppdet/optimizer/ema.py b/ppdet/optimizer/ema.py index 9cd9dca637998f4701bfe77fe317240fee26fe71..70d006b8fe30b6c4895a4a1c5aeee29c04550636 100644 --- a/ppdet/optimizer/ema.py +++ b/ppdet/optimizer/ema.py @@ -60,6 +60,12 @@ class ModelEMA(object): self.cycle_epoch = cycle_epoch self.ema_black_list = self._match_ema_black_list( model.state_dict().keys(), ema_black_list) + bn_states_names = get_bn_running_state_names(model) + if ema_filter_no_grad: + for n, p in model.named_parameters(): + if p.stop_gradient and n not in bn_states_names: + self.ema_black_list.add(n) + self.state_dict = dict() for k, v in model.state_dict().items(): if k in self.ema_black_list: @@ -67,12 +73,6 @@ class ModelEMA(object): else: self.state_dict[k] = paddle.zeros_like(v) - bn_states_names = get_bn_running_state_names(model) - if ema_filter_no_grad: - for n, p in model.named_parameters(): - if p.stop_gradient == True and n not in bn_states_names: - self.ema_black_list.append(n) - self._model_state = { k: weakref.ref(p) for k, p in model.state_dict().items()