Unverified commit 4af9f510, authored by littletomatodonkey, committed by GitHub

fix windows training (#1038)

* fix windows training

* fix typo
Parent 1837078b
@@ -21,6 +21,7 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

 import time
+import platform
 import datetime
 import argparse
 import paddle
@@ -152,10 +153,14 @@ class Trainer(object):
                 best_metric.update(metric_info)

         tic = time.time()
+        max_iter = len(self.train_dataloader) - 1 if platform.system(
+        ) == "Windows" else len(self.train_dataloader)
         for epoch_id in range(best_metric["epoch"] + 1,
                               self.config["Global"]["epochs"] + 1):
             acc = 0.0
             for iter_id, batch in enumerate(self.train_dataloader()):
+                if iter_id >= max_iter:
+                    break
                 if iter_id == 5:
                     for key in time_info:
                         time_info[key].reset()
@@ -349,7 +354,11 @@ class Trainer(object):
         metric_key = None
         tic = time.time()
+        max_iter = len(self.eval_dataloader) - 1 if platform.system(
+        ) == "Windows" else len(self.eval_dataloader)
         for iter_id, batch in enumerate(self.eval_dataloader()):
+            if iter_id >= max_iter:
+                break
             if iter_id == 5:
                 for key in time_info:
                     time_info[key].reset()
@@ -498,8 +507,12 @@ class Trainer(object):
             raise RuntimeError("Only support gallery or query dataset")

         has_unique_id = False
+        max_iter = len(dataloader) - 1 if platform.system(
+        ) == "Windows" else len(dataloader)
         for idx, batch in enumerate(dataloader(
         )): # load is very time-consuming
+            if idx >= max_iter:
+                break
             if idx % self.config["Global"]["print_batch_step"] == 0:
                 logger.info(
                     f"{name} feature calculation process: [{idx}/{len(dataloader)}]"