未验证 提交 c731d54e 编写于 作者: X xiaoting 提交者: GitHub

fix sr typo (#7262)

* fix sr typo

* rm eval visual
上级 d26da4b6
......@@ -63,14 +63,14 @@ def main():
elif op_name in ['SRResize']:
op[op_name]['infer_mode'] = True
elif op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['imge_lr']
op[op_name]['keep_keys'] = ['img_lr']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
save_res_path = config['Global'].get('save_res_path', "./infer_result")
if not os.path.exists(os.path.dirname(save_res_path)):
os.makedirs(os.path.dirname(save_res_path))
save_visual_path = config['Global'].get('save_visual', "infer_result/")
if not os.path.exists(os.path.dirname(save_visual_path)):
os.makedirs(os.path.dirname(save_visual_path))
model.eval()
for file in get_image_file_list(config['Global']['infer_img']):
......@@ -87,7 +87,7 @@ def main():
fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8)
img_name_pure = os.path.split(file)[-1]
cv2.imwrite("infer_result/sr_{}".format(img_name_pure),
cv2.imwrite("{}/sr_{}".format(save_visual_path, img_name_pure),
fm_sr[:, :, ::-1])
logger.info("The visualized image saved in infer_result/sr_{}".format(
img_name_pure))
......
......@@ -231,7 +231,8 @@ def train(config,
use_srn = config['Architecture']['algorithm'] == "SRN"
extra_input_models = [
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", "RobustScanner"
"SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
"RobustScanner"
]
extra_input = False
if config['Architecture']['algorithm'] == 'Distillation':
......@@ -503,16 +504,6 @@ def eval(model,
preds = model(batch)
sr_img = preds["sr_img"]
lr_img = preds["lr_img"]
for i in (range(sr_img.shape[0])):
fm_sr = (sr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
cv2.imwrite("output/images/{}_{}_sr.jpg".format(
sum_images, i), fm_sr)
cv2.imwrite("output/images/{}_{}_lr.jpg".format(
sum_images, i), fm_lr)
else:
preds = model(images)
preds = to_float32(preds)
......@@ -525,16 +516,6 @@ def eval(model,
preds = model(batch)
sr_img = preds["sr_img"]
lr_img = preds["lr_img"]
for i in (range(sr_img.shape[0])):
fm_sr = (sr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
fm_lr = (lr_img[i].numpy() * 255).transpose(
1, 2, 0).astype(np.uint8)
cv2.imwrite("output/images/{}_{}_sr.jpg".format(
sum_images, i), fm_sr)
cv2.imwrite("output/images/{}_{}_lr.jpg".format(
sum_images, i), fm_lr)
else:
preds = model(images)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册