diff --git a/tools/program.py b/tools/program.py index c1547efbcebc5ee8522aa7f190c44d602b595880..538d0f62fd9fdcbf9f99bffebf477098899b9e6b 100755 --- a/tools/program.py +++ b/tools/program.py @@ -227,10 +227,6 @@ def train(config, images = batch[0] if use_srn: model_average = True - if model_type == 'table' or extra_input: - preds = model(images, data=batch[1:]) - if model_type == "kie": - preds = model(batch) train_start = time.time() # use amp @@ -243,6 +239,8 @@ def train(config, else: if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) + if model_type == "kie": + preds = model(batch) else: preds = model(images) loss = loss_class(preds, batch)