From 7bcabe0f36e5a3ce287a6a041bc81fd1aafa0175 Mon Sep 17 00:00:00 2001 From: MissPenguin Date: Tue, 22 Jun 2021 12:39:43 +0000 Subject: [PATCH] refine --- tools/program.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tools/program.py b/tools/program.py index 2bb34835..2d99f296 100755 --- a/tools/program.py +++ b/tools/program.py @@ -210,7 +210,10 @@ def train(config, images = batch[0] if use_srn: model_average = True - preds = model(images, data=batch[1:]) + if use_srn or model_type == 'table': + preds = model(images, data=batch[1:]) + else: + preds = model(images) loss = loss_class(preds, batch) avg_loss = loss['loss'] avg_loss.backward() @@ -356,7 +359,10 @@ def eval(model, break images = batch[0] start = time.time() - preds = model(images, data=batch[1:]) + if use_srn or model_type == 'table': + preds = model(images, data=batch[1:]) + else: + preds = model(images) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods total_time += time.time() - start -- GitLab