diff --git a/ppcls/engine/train/classification.py b/ppcls/engine/train/classification.py index 9074a6b4913fd01a00cbb0f7727f4ca0ba9d1dae..6f84133b302cb1ea8136a25a37e9dd60936c1400 100644 --- a/ppcls/engine/train/classification.py +++ b/ppcls/engine/train/classification.py @@ -258,7 +258,7 @@ class ClassTrainer(object): return None def _build_ema_model(self): - if "EMA" in self.config and self.mode == "train": + if "EMA" in self.config: model_ema = ExponentialMovingAverage( self.model, self.config['EMA'].get("decay", 0.9999)) self.best_metric["metric_ema"] = 0