diff --git a/tools/program.py b/tools/program.py index 538d0f62fd9fdcbf9f99bffebf477098899b9e6b..84dc6c1d02e9b6dcb569deea61953bd4ad6fbd54 100755 --- a/tools/program.py +++ b/tools/program.py @@ -239,7 +239,7 @@ def train(config, else: if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) - if model_type == "kie": + elif model_type == "kie": preds = model(batch) else: preds = model(images)