diff --git a/ppgan/engine/trainer.py b/ppgan/engine/trainer.py index 417db0a37fe2d1035802ab5d54b7b31deba63c6e..fa2dfa0d645d95999038f61f025ad540fcaed478 100755 --- a/ppgan/engine/trainer.py +++ b/ppgan/engine/trainer.py @@ -13,6 +13,7 @@ # limitations under the License. import os +import sys import time import copy @@ -35,6 +36,7 @@ class IterLoader: self._dataloader = dataloader self.iter_loader = iter(self._dataloader) self._epoch = 1 + self._inner_iter = 0 @property def epoch(self): @@ -42,12 +44,17 @@ class IterLoader: def __next__(self): try: + if sys.platform == "Windows" and self._inner_iter == len( + self._dataloader) - 1: + self._inner_iter = 0 + raise StopIteration data = next(self.iter_loader) except StopIteration: self._epoch += 1 self.iter_loader = iter(self._dataloader) data = next(self.iter_loader) + self._inner_iter += 1 return data def __len__(self):