提交 ad15a645 编写于 作者: T tink2123

polish code for srn eval

上级 a6146ffc
...@@ -177,6 +177,8 @@ def train(config, ...@@ -177,6 +177,8 @@ def train(config,
model_average = False model_average = False
model.train() model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
if 'start_epoch' in best_model_dict: if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch'] start_epoch = best_model_dict['start_epoch']
else: else:
...@@ -195,7 +197,7 @@ def train(config, ...@@ -195,7 +197,7 @@ def train(config,
break break
lr = optimizer.get_lr() lr = optimizer.get_lr()
images = batch[0] images = batch[0]
if config['Architecture']['algorithm'] == "SRN": if use_srn:
others = batch[-4:] others = batch[-4:]
preds = model(images, others) preds = model(images, others)
model_average = True model_average = True
...@@ -251,8 +253,12 @@ def train(config, ...@@ -251,8 +253,12 @@ def train(config,
min_average_window=10000, min_average_window=10000,
max_average_window=15625) max_average_window=15625)
Model_Average.apply() Model_Average.apply()
cur_metric = eval(model, valid_dataloader, post_process_class, cur_metric = eval(
eval_class) model,
valid_dataloader,
post_process_class,
eval_class,
use_srn=use_srn)
cur_metric_str = 'cur metric, {}'.format(', '.join( cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()])) ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str) logger.info(cur_metric_str)
...@@ -316,7 +322,8 @@ def train(config, ...@@ -316,7 +322,8 @@ def train(config,
return return
def eval(model, valid_dataloader, post_process_class, eval_class): def eval(model, valid_dataloader, post_process_class, eval_class,
use_srn=False):
model.eval() model.eval()
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
...@@ -327,7 +334,8 @@ def eval(model, valid_dataloader, post_process_class, eval_class): ...@@ -327,7 +334,8 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
break break
images = batch[0] images = batch[0]
start = time.time() start = time.time()
if "SRN" in str(model.head):
if use_srn:
others = batch[-4:] others = batch[-4:]
preds = model(images, others) preds = model(images, others)
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册