提交 97a66874 编写于 作者: M MissPenguin

refine

上级 053cc43d
...@@ -187,7 +187,7 @@ def train(config, ...@@ -187,7 +187,7 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
model_type = config['Architecture']['model_type'] model_type = config['Architecture']['model_type']
if 'start_epoch' in best_model_dict: if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch'] start_epoch = best_model_dict['start_epoch']
else: else:
...@@ -338,8 +338,12 @@ def train(config, ...@@ -338,8 +338,12 @@ def train(config,
return return
def eval(model, valid_dataloader, post_process_class, eval_class, def eval(model,
model_type, use_srn=False): valid_dataloader,
post_process_class,
eval_class,
model_type,
use_srn=False):
model.eval() model.eval()
with paddle.no_grad(): with paddle.no_grad():
total_frame = 0.0 total_frame = 0.0
...@@ -352,7 +356,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class, ...@@ -352,7 +356,7 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
break break
images = batch[0] images = batch[0]
start = time.time() start = time.time()
preds = model(images, data=batch[1:]) preds = model(images, data=batch[1:])
batch = [item.numpy() for item in batch] batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods # Obtain usable results from post-processing methods
total_time += time.time() - start total_time += time.time() - start
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册