提交 2aa92e6a 编写于 作者: W WenmuZhou

fix bug

上级 9467b754
...@@ -152,7 +152,6 @@ def train(config, ...@@ -152,7 +152,6 @@ def train(config,
pre_best_model_dict, pre_best_model_dict,
logger, logger,
vdl_writer=None): vdl_writer=None):
cal_metric_during_train = config['Global'].get('cal_metric_during_train', cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False) False)
log_smooth_window = config['Global']['log_smooth_window'] log_smooth_window = config['Global']['log_smooth_window']
...@@ -185,14 +184,13 @@ def train(config, ...@@ -185,14 +184,13 @@ def train(config,
for epoch in range(start_epoch, epoch_num): for epoch in range(start_epoch, epoch_num):
if epoch > 0: if epoch > 0:
train_loader = build_dataloader(config, 'Train', device) train_dataloader = build_dataloader(config, 'Train', device, logger)
for idx, batch in enumerate(train_dataloader): for idx, batch in enumerate(train_dataloader):
if idx >= len(train_dataloader): if idx >= len(train_dataloader):
break break
lr = optimizer.get_lr() lr = optimizer.get_lr()
t1 = time.time() t1 = time.time()
batch = [paddle.to_tensor(x) for x in batch]
images = batch[0] images = batch[0]
preds = model(images) preds = model(images)
loss = loss_class(preds, batch) loss = loss_class(preds, batch)
...@@ -301,11 +299,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger, ...@@ -301,11 +299,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
total_time = 0.0 total_time = 0.0
# pbar = tqdm(total=len(valid_dataloader), desc='eval model:') pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
for idx, batch in enumerate(valid_dataloader): for idx, batch in enumerate(valid_dataloader):
if idx >= len(valid_dataloader): if idx >= len(valid_dataloader):
break break
images = paddle.to_tensor(batch[0]) images = batch[0]
start = time.time() start = time.time()
preds = model(images) preds = model(images)
...@@ -315,15 +313,15 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger, ...@@ -315,15 +313,15 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger,
total_time += time.time() - start total_time += time.time() - start
# Evaluate the results of the current batch # Evaluate the results of the current batch
eval_class(post_result, batch) eval_class(post_result, batch)
# pbar.update(1) pbar.update(1)
total_frame += len(images) total_frame += len(images)
if idx % print_batch_step == 0 and dist.get_rank() == 0: # if idx % print_batch_step == 0 and dist.get_rank() == 0:
logger.info('tackling images for eval: {}/{}'.format( # logger.info('tackling images for eval: {}/{}'.format(
idx, len(valid_dataloader))) # idx, len(valid_dataloader)))
# Get final metirc,eg. acc or hmean # Get final metirc,eg. acc or hmean
metirc = eval_class.get_metric() metirc = eval_class.get_metric()
# pbar.close() pbar.close()
model.train() model.train()
metirc['fps'] = total_frame / total_time metirc['fps'] = total_frame / total_time
return metirc return metirc
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册