diff --git a/tools/program.py b/tools/program.py index 2bb34835269d913b0ef773d9233a65b6ccb9f2d5..2d99f2968a3f0c8acc359ed0fbb199650bd7010c 100755 --- a/tools/program.py +++ b/tools/program.py @@ -210,7 +210,10 @@ def train(config, images = batch[0] if use_srn: model_average = True - preds = model(images, data=batch[1:]) + if use_srn or model_type == 'table': + preds = model(images, data=batch[1:]) + else: + preds = model(images) loss = loss_class(preds, batch) avg_loss = loss['loss'] avg_loss.backward() @@ -356,7 +359,10 @@ def eval(model, break images = batch[0] start = time.time() - preds = model(images, data=batch[1:]) + if use_srn or model_type == 'table': + preds = model(images, data=batch[1:]) + else: + preds = model(images) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods total_time += time.time() - start