未验证 提交 4af9f510 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix windows training (#1038)

* fix windows training

* fix typo
上级 1837078b
...@@ -21,6 +21,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) ...@@ -21,6 +21,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))
import time import time
import platform
import datetime import datetime
import argparse import argparse
import paddle import paddle
...@@ -152,10 +153,14 @@ class Trainer(object): ...@@ -152,10 +153,14 @@ class Trainer(object):
best_metric.update(metric_info) best_metric.update(metric_info)
tic = time.time() tic = time.time()
max_iter = len(self.train_dataloader) - 1 if platform.system(
) == "Windows" else len(self.train_dataloader)
for epoch_id in range(best_metric["epoch"] + 1, for epoch_id in range(best_metric["epoch"] + 1,
self.config["Global"]["epochs"] + 1): self.config["Global"]["epochs"] + 1):
acc = 0.0 acc = 0.0
for iter_id, batch in enumerate(self.train_dataloader()): for iter_id, batch in enumerate(self.train_dataloader()):
if iter_id >= max_iter:
break
if iter_id == 5: if iter_id == 5:
for key in time_info: for key in time_info:
time_info[key].reset() time_info[key].reset()
...@@ -349,7 +354,11 @@ class Trainer(object): ...@@ -349,7 +354,11 @@ class Trainer(object):
metric_key = None metric_key = None
tic = time.time() tic = time.time()
max_iter = len(self.eval_dataloader) - 1 if platform.system(
) == "Windows" else len(self.eval_dataloader)
for iter_id, batch in enumerate(self.eval_dataloader()): for iter_id, batch in enumerate(self.eval_dataloader()):
if iter_id >= max_iter:
break
if iter_id == 5: if iter_id == 5:
for key in time_info: for key in time_info:
time_info[key].reset() time_info[key].reset()
...@@ -498,8 +507,12 @@ class Trainer(object): ...@@ -498,8 +507,12 @@ class Trainer(object):
raise RuntimeError("Only support gallery or query dataset") raise RuntimeError("Only support gallery or query dataset")
has_unique_id = False has_unique_id = False
max_iter = len(dataloader) - 1 if platform.system(
) == "Windows" else len(dataloader)
for idx, batch in enumerate(dataloader( for idx, batch in enumerate(dataloader(
)): # load is very time-consuming )): # load is very time-consuming
if idx >= max_iter:
break
if idx % self.config["Global"]["print_batch_step"] == 0: if idx % self.config["Global"]["print_batch_step"] == 0:
logger.info( logger.info(
f"{name} feature calculation process: [{idx}/{len(dataloader)}]" f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册