未验证 提交 c731d54e 编写于 作者: X xiaoting 提交者: GitHub

fix sr typo (#7262)

* fix sr typo

* rm eval visual
上级 d26da4b6
...@@ -63,14 +63,14 @@ def main(): ...@@ -63,14 +63,14 @@ def main():
elif op_name in ['SRResize']: elif op_name in ['SRResize']:
op[op_name]['infer_mode'] = True op[op_name]['infer_mode'] = True
elif op_name == 'KeepKeys': elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['imge_lr'] op[op_name]['keep_keys'] = ['img_lr']
transforms.append(op) transforms.append(op)
global_config['infer_mode'] = True global_config['infer_mode'] = True
ops = create_operators(transforms, global_config) ops = create_operators(transforms, global_config)
save_res_path = config['Global'].get('save_res_path', "./infer_result") save_visual_path = config['Global'].get('save_visual', "infer_result/")
if not os.path.exists(os.path.dirname(save_res_path)): if not os.path.exists(os.path.dirname(save_visual_path)):
os.makedirs(os.path.dirname(save_res_path)) os.makedirs(os.path.dirname(save_visual_path))
model.eval() model.eval()
for file in get_image_file_list(config['Global']['infer_img']): for file in get_image_file_list(config['Global']['infer_img']):
...@@ -87,7 +87,7 @@ def main(): ...@@ -87,7 +87,7 @@ def main():
fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8) fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8) fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
img_name_pure = os.path.split(file)[-1] img_name_pure = os.path.split(file)[-1]
cv2.imwrite("infer_result/sr_{}".format(img_name_pure), cv2.imwrite("{}/sr_{}".format(save_visual_path, img_name_pure),
fm_sr[:, :, ::-1]) fm_sr[:, :, ::-1])
logger.info("The visualized image saved in infer_result/sr_{}".format( logger.info("The visualized image saved in infer_result/sr_{}".format(
img_name_pure)) img_name_pure))
......
...@@ -231,7 +231,8 @@ def train(config, ...@@ -231,7 +231,8 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN" use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = [ extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", "RobustScanner" "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
"RobustScanner"
] ]
extra_input = False extra_input = False
if config['Architecture']['algorithm'] == 'Distillation': if config['Architecture']['algorithm'] == 'Distillation':
...@@ -503,16 +504,6 @@ def eval(model, ...@@ -503,16 +504,6 @@ def eval(model,
preds = model(batch) preds = model(batch)
sr_img = preds["sr_img"] sr_img = preds["sr_img"]
lr_img = preds["lr_img"] lr_img = preds["lr_img"]
for i in (range(sr_img.shape[0])):
fm_sr = (sr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
cv2.imwrite("output/images/{}_{}_sr.jpg".format(
sum_images, i), fm_sr)
cv2.imwrite("output/images/{}_{}_lr.jpg".format(
sum_images, i), fm_lr)
else: else:
preds = model(images) preds = model(images)
preds = to_float32(preds) preds = to_float32(preds)
...@@ -525,16 +516,6 @@ def eval(model, ...@@ -525,16 +516,6 @@ def eval(model,
preds = model(batch) preds = model(batch)
sr_img = preds["sr_img"] sr_img = preds["sr_img"]
lr_img = preds["lr_img"] lr_img = preds["lr_img"]
for i in (range(sr_img.shape[0])):
fm_sr = (sr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
cv2.imwrite("output/images/{}_{}_sr.jpg".format(
sum_images, i), fm_sr)
cv2.imwrite("output/images/{}_{}_lr.jpg".format(
sum_images, i), fm_lr)
else: else:
preds = model(images) preds = model(images)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册