From dc7d4b00cd9f1c3c6d7d30f89a14258a26f29438 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Sat, 18 Dec 2021 08:04:10 +0000 Subject: [PATCH] fix train --- tools/program.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tools/program.py b/tools/program.py index c1547efb..538d0f62 100755 --- a/tools/program.py +++ b/tools/program.py @@ -227,10 +227,6 @@ def train(config, images = batch[0] if use_srn: model_average = True - if model_type == 'table' or extra_input: - preds = model(images, data=batch[1:]) - if model_type == "kie": - preds = model(batch) train_start = time.time() # use amp @@ -243,6 +239,8 @@ def train(config, else: if model_type == 'table' or extra_input: preds = model(images, data=batch[1:]) + if model_type == "kie": + preds = model(batch) else: preds = model(images) loss = loss_class(preds, batch) -- GitLab