diff --git a/tools/program.py b/tools/program.py index aa0d2698cf66c928f87217996c31c042e1c8aa02..9a90eda2f8855804795e6e6f33fa94f6cf87180b 100755 --- a/tools/program.py +++ b/tools/program.py @@ -255,6 +255,8 @@ def train(config, with paddle.amp.auto_cast(): if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) + elif model_type in ["kie", 'vqa']: + preds = model(batch) else: preds = model(images) else: @@ -307,7 +309,8 @@ def train(config, train_stats.update(stats) if log_writer is not None and dist.get_rank() == 0: - log_writer.log_metrics(metrics=train_stats.get(), prefix="TRAIN", step=global_step) + log_writer.log_metrics( + metrics=train_stats.get(), prefix="TRAIN", step=global_step) if dist.get_rank() == 0 and ( (global_step > 0 and global_step % print_batch_step == 0) or @@ -354,7 +357,8 @@ def train(config, # logger metric if log_writer is not None: - log_writer.log_metrics(metrics=cur_metric, prefix="EVAL", step=global_step) + log_writer.log_metrics( + metrics=cur_metric, prefix="EVAL", step=global_step) if cur_metric[main_indicator] >= best_model_dict[ main_indicator]: @@ -377,11 +381,18 @@ def train(config, logger.info(best_str) # logger best metric if log_writer is not None: - log_writer.log_metrics(metrics={ - "best_{}".format(main_indicator): best_model_dict[main_indicator] - }, prefix="EVAL", step=global_step) - - log_writer.log_model(is_best=True, prefix="best_accuracy", metadata=best_model_dict) + log_writer.log_metrics( + metrics={ + "best_{}".format(main_indicator): + best_model_dict[main_indicator] + }, + prefix="EVAL", + step=global_step) + + log_writer.log_model( + is_best=True, + prefix="best_accuracy", + metadata=best_model_dict) reader_start = time.time() if dist.get_rank() == 0: @@ -413,7 +424,8 @@ def train(config, epoch=epoch, global_step=global_step) if log_writer is not None: - log_writer.log_model(is_best=False, prefix='iter_epoch_{}'.format(epoch)) + log_writer.log_model( + is_best=False, prefix='iter_epoch_{}'.format(epoch)) best_str = 'best metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in best_model_dict.items()])) @@ -585,7 +597,8 @@ def preprocess(is_train=False): vdl_writer_path = '{}/vdl/'.format(save_model_dir) log_writer = VDLLogger(save_model_dir) loggers.append(log_writer) - if ('use_wandb' in config['Global'] and config['Global']['use_wandb']) or 'wandb' in config: + if ('use_wandb' in config['Global'] and + config['Global']['use_wandb']) or 'wandb' in config: save_dir = config['Global']['save_model_dir'] wandb_writer_path = "{}/wandb".format(save_dir) if "wandb" in config: