diff --git a/tools/program.py b/tools/program.py index 333e8ed9770cad08ba5e9aa47edec850a74a1808..1dfd06af8dcc7d260cdd42d95d7ca0a747703cfb 100755 --- a/tools/program.py +++ b/tools/program.py @@ -266,7 +266,7 @@ def train(config, stats['lr'] = lr train_stats.update(stats) - if cal_metric_during_train: # only rec and cls need + if cal_metric_during_train and model_type is not "det": # only rec and cls need batch = [item.numpy() for item in batch] if model_type in ['table', 'kie']: eval_class(preds, batch)