diff --git a/tools/program.py b/tools/program.py index 34d484d8aa6240401c5b6890a854930d03900f42..6277d7475868912ad86e1475c2ea6ac930cd6297 100755 --- a/tools/program.py +++ b/tools/program.py @@ -182,6 +182,8 @@ def train(config, model_average = False model.train() + use_srn = config['Architecture']['algorithm'] == "SRN" + if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] else: @@ -200,7 +202,7 @@ def train(config, break lr = optimizer.get_lr() images = batch[0] - if config['Architecture']['algorithm'] == "SRN": + if use_srn: others = batch[-4:] preds = model(images, others) model_average = True @@ -256,8 +258,12 @@ def train(config, min_average_window=10000, max_average_window=15625) Model_Average.apply() - cur_metric = eval(model, valid_dataloader, post_process_class, - eval_class) + cur_metric = eval( + model, + valid_dataloader, + post_process_class, + eval_class, + use_srn=use_srn) cur_metric_str = 'cur metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metric.items()])) logger.info(cur_metric_str) @@ -321,7 +327,8 @@ def train(config, return -def eval(model, valid_dataloader, post_process_class, eval_class): +def eval(model, valid_dataloader, post_process_class, eval_class, + use_srn=False): model.eval() with paddle.no_grad(): total_frame = 0.0 @@ -332,7 +339,8 @@ def eval(model, valid_dataloader, post_process_class, eval_class): break images = batch[0] start = time.time() - if "SRN" in str(model.head): + + if use_srn: others = batch[-4:] preds = model(images, others) else: