未验证 提交 08c5b1d1 编写于 作者: S shangliang Xu 提交者: GitHub

fix bug for num_iters in fit/evaluate (#34059)

上级 edb9aff5
...@@ -1707,7 +1707,8 @@ class Model(object): ...@@ -1707,7 +1707,8 @@ class Model(object):
steps = self._len_data_loader(train_loader) steps = self._len_data_loader(train_loader)
self.num_iters = num_iters self.num_iters = num_iters
if num_iters is not None and isinstance(num_iters, int): if num_iters is not None and isinstance(num_iters, int) and isinstance(
steps, int):
assert num_iters > 0, "num_iters must be greater than 0!" assert num_iters > 0, "num_iters must be greater than 0!"
epochs = (num_iters // steps) + 1 epochs = (num_iters // steps) + 1
steps = min(num_iters, steps) steps = min(num_iters, steps)
...@@ -1830,7 +1831,8 @@ class Model(object): ...@@ -1830,7 +1831,8 @@ class Model(object):
eval_steps = self._len_data_loader(eval_loader) eval_steps = self._len_data_loader(eval_loader)
self.num_iters = num_iters self.num_iters = num_iters
if num_iters is not None and isinstance(num_iters, int): if num_iters is not None and isinstance(num_iters, int) and isinstance(
eval_steps, int):
assert num_iters > 0, "num_iters must be greater than 0!" assert num_iters > 0, "num_iters must be greater than 0!"
eval_steps = min(num_iters, eval_steps) eval_steps = min(num_iters, eval_steps)
self.num_iters = eval_steps self.num_iters = eval_steps
...@@ -2092,7 +2094,9 @@ class Model(object): ...@@ -2092,7 +2094,9 @@ class Model(object):
callbacks.on_batch_end(mode, step, logs) callbacks.on_batch_end(mode, step, logs)
if hasattr(self, 'num_iters') and self.num_iters is not None: if hasattr(self, 'num_iters') and self.num_iters is not None:
self.num_iters -= 1 self.num_iters -= 1
if self.num_iters == 0: if self.num_iters <= 0:
self.stop_training = True
del self.num_iters
break break
self._reset_metrics() self._reset_metrics()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册