From 97a668747cd38ce2708f5e745588e3fc3de62bf5 Mon Sep 17 00:00:00 2001 From: MissPenguin Date: Tue, 22 Jun 2021 04:24:14 +0000 Subject: [PATCH] refine --- tools/program.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tools/program.py b/tools/program.py index bd17db4a..2bb34835 100755 --- a/tools/program.py +++ b/tools/program.py @@ -187,7 +187,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" model_type = config['Architecture']['model_type'] - + if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] else: @@ -338,8 +338,12 @@ def train(config, return -def eval(model, valid_dataloader, post_process_class, eval_class, - model_type, use_srn=False): +def eval(model, + valid_dataloader, + post_process_class, + eval_class, + model_type, + use_srn=False): model.eval() with paddle.no_grad(): total_frame = 0.0 @@ -352,7 +356,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class, break images = batch[0] start = time.time() - preds = model(images, data=batch[1:]) + preds = model(images, data=batch[1:]) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods total_time += time.time() - start -- GitLab