未验证 提交 974e98bc 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Fix the bug for None labels (#46987)

上级 48bb2c0a
......@@ -535,8 +535,8 @@ class Engine:
outputs = []
losses = []
metrics = []
inputs = self._inputs
labels = self._labels
inputs = self._inputs if self._inputs else []
labels = self._labels if self._labels else []
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
if not self._skip_build:
......@@ -848,12 +848,12 @@ class Engine:
history = self._prepare_history(outs, fetch_indices,
self._mode)
# if valid_data and epoch % valid_freq == 0:
# self.evaluate(valid_data, valid_sample_split, batch_size,
# valid_steps, collate_fn, callbacks)
# self._switch_mode("train")
# else:
# self._reset_metrics()
if valid_data and epoch % valid_freq == 0:
self.evaluate(valid_data, valid_sample_split, batch_size,
valid_steps, collate_fn, callbacks)
self._switch_mode("train")
else:
self._reset_metrics()
return history
def evaluate(self,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册