未验证 提交 974e98bc 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Fix the bug for None labels (#46987)

上级 48bb2c0a
...@@ -535,8 +535,8 @@ class Engine: ...@@ -535,8 +535,8 @@ class Engine:
outputs = [] outputs = []
losses = [] losses = []
metrics = [] metrics = []
inputs = self._inputs inputs = self._inputs if self._inputs else []
labels = self._labels labels = self._labels if self._labels else []
serial_main_prog = self._orig_main_prog.clone() serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone() serial_startup_prog = self._orig_startup_prog.clone()
if not self._skip_build: if not self._skip_build:
...@@ -848,12 +848,12 @@ class Engine: ...@@ -848,12 +848,12 @@ class Engine:
history = self._prepare_history(outs, fetch_indices, history = self._prepare_history(outs, fetch_indices,
self._mode) self._mode)
# if valid_data and epoch % valid_freq == 0: if valid_data and epoch % valid_freq == 0:
# self.evaluate(valid_data, valid_sample_split, batch_size, self.evaluate(valid_data, valid_sample_split, batch_size,
# valid_steps, collate_fn, callbacks) valid_steps, collate_fn, callbacks)
# self._switch_mode("train") self._switch_mode("train")
# else: else:
# self._reset_metrics() self._reset_metrics()
return history return history
def evaluate(self, def evaluate(self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册