diff --git a/tools/program.py b/tools/program.py index 696700c031eda416212e3fd96747badd24cf2c1d..33bf1adc98fc48289e9873d073501d76547c011d 100755 --- a/tools/program.py +++ b/tools/program.py @@ -152,7 +152,6 @@ def train(config, pre_best_model_dict, logger, vdl_writer=None): - cal_metric_during_train = config['Global'].get('cal_metric_during_train', False) log_smooth_window = config['Global']['log_smooth_window'] @@ -185,14 +184,13 @@ def train(config, for epoch in range(start_epoch, epoch_num): if epoch > 0: - train_loader = build_dataloader(config, 'Train', device) + train_dataloader = build_dataloader(config, 'Train', device, logger) for idx, batch in enumerate(train_dataloader): if idx >= len(train_dataloader): break lr = optimizer.get_lr() t1 = time.time() - batch = [paddle.to_tensor(x) for x in batch] images = batch[0] preds = model(images) loss = loss_class(preds, batch) @@ -301,11 +299,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger, with paddle.no_grad(): total_frame = 0.0 total_time = 0.0 - # pbar = tqdm(total=len(valid_dataloader), desc='eval model:') + pbar = tqdm(total=len(valid_dataloader), desc='eval model:') for idx, batch in enumerate(valid_dataloader): if idx >= len(valid_dataloader): break - images = paddle.to_tensor(batch[0]) + images = batch[0] start = time.time() preds = model(images) @@ -315,15 +313,15 @@ def eval(model, valid_dataloader, post_process_class, eval_class, logger, total_time += time.time() - start # Evaluate the results of the current batch eval_class(post_result, batch) - # pbar.update(1) + pbar.update(1) total_frame += len(images) - if idx % print_batch_step == 0 and dist.get_rank() == 0: - logger.info('tackling images for eval: {}/{}'.format( - idx, len(valid_dataloader))) + # if idx % print_batch_step == 0 and dist.get_rank() == 0: + # logger.info('tackling images for eval: {}/{}'.format( + # idx, len(valid_dataloader))) # Get final metirc,eg. acc or hmean metirc = eval_class.get_metric() -# pbar.close() + pbar.close() model.train() metirc['fps'] = total_frame / total_time return metirc