From c731d54e18d544cd14e4a92f462722f1976ef402 Mon Sep 17 00:00:00 2001 From: xiaoting <31891223+tink2123@users.noreply.github.com> Date: Fri, 19 Aug 2022 14:49:35 +0800 Subject: [PATCH] fix sr typo (#7262) * fix sr typo * rm eval visual --- tools/infer_sr.py | 10 +++++----- tools/program.py | 23 ++--------------------- 2 files changed, 7 insertions(+), 26 deletions(-) diff --git a/tools/infer_sr.py b/tools/infer_sr.py index 0bc2f6aa..df4334f3 100755 --- a/tools/infer_sr.py +++ b/tools/infer_sr.py @@ -63,14 +63,14 @@ def main(): elif op_name in ['SRResize']: op[op_name]['infer_mode'] = True elif op_name == 'KeepKeys': - op[op_name]['keep_keys'] = ['imge_lr'] + op[op_name]['keep_keys'] = ['img_lr'] transforms.append(op) global_config['infer_mode'] = True ops = create_operators(transforms, global_config) - save_res_path = config['Global'].get('save_res_path', "./infer_result") - if not os.path.exists(os.path.dirname(save_res_path)): - os.makedirs(os.path.dirname(save_res_path)) + save_visual_path = config['Global'].get('save_visual', "infer_result/") + if not os.path.exists(os.path.dirname(save_visual_path)): + os.makedirs(os.path.dirname(save_visual_path)) model.eval() for file in get_image_file_list(config['Global']['infer_img']): @@ -87,7 +87,7 @@ def main(): fm_sr = (sr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8) fm_lr = (lr_img.numpy() * 255).transpose(1, 2, 0).astype(np.uint8) img_name_pure = os.path.split(file)[-1] - cv2.imwrite("infer_result/sr_{}".format(img_name_pure), + cv2.imwrite("{}/sr_{}".format(save_visual_path, img_name_pure), fm_sr[:, :, ::-1]) logger.info("The visualized image saved in infer_result/sr_{}".format( img_name_pure)) diff --git a/tools/program.py b/tools/program.py index 012a2c61..8de15ee0 100755 --- a/tools/program.py +++ b/tools/program.py @@ -231,7 +231,8 @@ def train(config, use_srn = config['Architecture']['algorithm'] == "SRN" extra_input_models = [ - "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", "RobustScanner" + "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN", + "RobustScanner" ] extra_input = False if config['Architecture']['algorithm'] == 'Distillation': @@ -503,16 +504,6 @@ def eval(model, preds = model(batch) sr_img = preds["sr_img"] lr_img = preds["lr_img"] - - for i in (range(sr_img.shape[0])): - fm_sr = (sr_img[i].numpy() * 255).transpose( - 1, 2, 0).astype(np.uint8) - fm_lr = (lr_img[i].numpy() * 255).transpose( - 1, 2, 0).astype(np.uint8) - cv2.imwrite("output/images/{}_{}_sr.jpg".format( - sum_images, i), fm_sr) - cv2.imwrite("output/images/{}_{}_lr.jpg".format( - sum_images, i), fm_lr) else: preds = model(images) preds = to_float32(preds) @@ -525,16 +516,6 @@ def eval(model, preds = model(batch) sr_img = preds["sr_img"] lr_img = preds["lr_img"] - - for i in (range(sr_img.shape[0])): - fm_sr = (sr_img[i].numpy() * 255).transpose( - 1, 2, 0).astype(np.uint8) - fm_lr = (lr_img[i].numpy() * 255).transpose( - 1, 2, 0).astype(np.uint8) - cv2.imwrite("output/images/{}_{}_sr.jpg".format( - sum_images, i), fm_sr) - cv2.imwrite("output/images/{}_{}_lr.jpg".format( - sum_images, i), fm_lr) else: preds = model(images) -- GitLab