diff --git a/tools/program.py b/tools/program.py index bd17db4afec459468b6428611308cd8c41920ca5..2bb34835269d913b0ef773d9233a65b6ccb9f2d5 100755 --- a/tools/program.py +++ b/tools/program.py @@ -187,7 +187,7 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" model_type = config['Architecture']['model_type'] - + if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] else: @@ -338,8 +338,12 @@ def train(config, return -def eval(model, valid_dataloader, post_process_class, eval_class, - model_type, use_srn=False): +def eval(model, + valid_dataloader, + post_process_class, + eval_class, + model_type, + use_srn=False): model.eval() with paddle.no_grad(): total_frame = 0.0 @@ -352,7 +356,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class, break images = batch[0] start = time.time() - preds = model(images, data=batch[1:]) + preds = model(images, data=batch[1:]) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods total_time += time.time() - start