diff --git a/tools/program.py b/tools/program.py index 694d64152f05ffd5e9329885149891f75a98ed84..f3ba49450a21f600589b6888710a2420ccdaa321 100755 --- a/tools/program.py +++ b/tools/program.py @@ -326,9 +326,12 @@ def eval(model, valid_dataloader, post_process_class, eval_class): if idx >= len(valid_dataloader): break images = batch[0] - others = batch[-4:] start = time.time() - preds = model(images, others) + if "SRN" in str(model.head): + others = batch[-4:] + preds = model(images, others) + else: + preds = model(images) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods