diff --git a/tools/program.py b/tools/program.py index 93c61324285f0e6128727578c18627ee4be43791..333e8ed9770cad08ba5e9aa47edec850a74a1808 100755 --- a/tools/program.py +++ b/tools/program.py @@ -227,10 +227,6 @@ def train(config, images = batch[0] if use_srn: model_average = True - if model_type == 'table' or extra_input: - preds = model(images, data=batch[1:]) - if model_type == "kie": - preds = model(batch) train_start = time.time() # use amp @@ -243,6 +239,8 @@ def train(config, else: if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) + elif model_type == "kie": + preds = model(batch) else: preds = model(images) loss = loss_class(preds, batch)