提交 ad15a645 编写于 作者: T tink2123

polish code for srn eval

上级 a6146ffc
......@@ -177,6 +177,8 @@ def train(config,
model_average = False
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch']
else:
......@@ -195,7 +197,7 @@ def train(config,
break
lr = optimizer.get_lr()
images = batch[0]
if config['Architecture']['algorithm'] == "SRN":
if use_srn:
others = batch[-4:]
preds = model(images, others)
model_average = True
......@@ -251,8 +253,12 @@ def train(config,
min_average_window=10000,
max_average_window=15625)
Model_Average.apply()
cur_metric = eval(model, valid_dataloader, post_process_class,
eval_class)
cur_metric = eval(
model,
valid_dataloader,
post_process_class,
eval_class,
use_srn=use_srn)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str)
......@@ -316,7 +322,8 @@ def train(config,
return
def eval(model, valid_dataloader, post_process_class, eval_class):
def eval(model, valid_dataloader, post_process_class, eval_class,
use_srn=False):
model.eval()
with paddle.no_grad():
total_frame = 0.0
......@@ -327,7 +334,8 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
break
images = batch[0]
start = time.time()
if "SRN" in str(model.head):
if use_srn:
others = batch[-4:]
preds = model(images, others)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册