diff --git a/tools/program.py b/tools/program.py index 99a374326ceb71a21001f67534990c4f37effeac..3dc8550094fbc79397396d9f6d552ffa05630300 100755 --- a/tools/program.py +++ b/tools/program.py @@ -177,6 +177,8 @@ def train(config, model_average = False model.train() + use_srn = config['Architecture']['algorithm'] == "SRN" + if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] else: @@ -195,7 +197,7 @@ def train(config, break lr = optimizer.get_lr() images = batch[0] - if config['Architecture']['algorithm'] == "SRN": + if use_srn: others = batch[-4:] preds = model(images, others) model_average = True @@ -251,8 +253,12 @@ def train(config, min_average_window=10000, max_average_window=15625) Model_Average.apply() - cur_metric = eval(model, valid_dataloader, post_process_class, - eval_class) + cur_metric = eval( + model, + valid_dataloader, + post_process_class, + eval_class, + use_srn=use_srn) cur_metric_str = 'cur metric, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metric.items()])) logger.info(cur_metric_str) @@ -316,7 +322,8 @@ def train(config, return -def eval(model, valid_dataloader, post_process_class, eval_class): +def eval(model, valid_dataloader, post_process_class, eval_class, + use_srn=False): model.eval() with paddle.no_grad(): total_frame = 0.0 @@ -327,7 +334,8 @@ def eval(model, valid_dataloader, post_process_class, eval_class): break images = batch[0] start = time.time() - if "SRN" in str(model.head): + + if use_srn: others = batch[-4:] preds = model(images, others) else: