diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index 76e9863bfe27265c70f0235dfab133614f6a5378..60ee7d0ba3e1421bed26650a7aee70c5f7b65429 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -535,8 +535,8 @@ class Engine: outputs = [] losses = [] metrics = [] - inputs = self._inputs - labels = self._labels + inputs = self._inputs if self._inputs else [] + labels = self._labels if self._labels else [] serial_main_prog = self._orig_main_prog.clone() serial_startup_prog = self._orig_startup_prog.clone() if not self._skip_build: @@ -848,12 +848,12 @@ class Engine: history = self._prepare_history(outs, fetch_indices, self._mode) - # if valid_data and epoch % valid_freq == 0: - # self.evaluate(valid_data, valid_sample_split, batch_size, - # valid_steps, collate_fn, callbacks) - # self._switch_mode("train") - # else: - # self._reset_metrics() + if valid_data and epoch % valid_freq == 0: + self.evaluate(valid_data, valid_sample_split, batch_size, + valid_steps, collate_fn, callbacks) + self._switch_mode("train") + else: + self._reset_metrics() return history def evaluate(self,