From a68d90c53bbad1614696bd8d5441fb26c2768095 Mon Sep 17 00:00:00 2001 From: wqz960 <362379625@qq.com> Date: Mon, 20 Jul 2020 04:16:02 +0000 Subject: [PATCH] fix interpolation --- docs/zh_CN/feature_visiualization/get_started.md | 3 ++- tools/feature_maps_visualization/fm_vis.py | 9 +++++---- tools/feature_maps_visualization/utils.py | 5 +++-- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/docs/zh_CN/feature_visiualization/get_started.md b/docs/zh_CN/feature_visiualization/get_started.md index a287aad8..f80a2f84 100644 --- a/docs/zh_CN/feature_visiualization/get_started.md +++ b/docs/zh_CN/feature_visiualization/get_started.md @@ -47,7 +47,7 @@ def forward(self, inputs): python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \ -c channel_num -p pretrained model \ --show whether to show \ - --save whether to save \ + --interpolation interpolation method\ --save_path where to save \ --use_gpu whether to use gpu ``` @@ -56,6 +56,7 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test + `-c`:特征图维度,如 `./resnet50_vd/model` + `-p`:权重文件路径,如 `./ResNet50_pretrained/` + `--show`:是否展示图片,默认值 False ++ `--interpolation`: 图像插值方式, 默认值 1 + `--save_path`:保存路径,如:`./tools/` + `--use_gpu`:是否使用 GPU 预测,默认值:True diff --git a/tools/feature_maps_visualization/fm_vis.py b/tools/feature_maps_visualization/fm_vis.py index 1731313b..b389d833 100644 --- a/tools/feature_maps_visualization/fm_vis.py +++ b/tools/feature_maps_visualization/fm_vis.py @@ -28,19 +28,20 @@ def parse_args(): parser.add_argument("-c", "--channel_num", type=int) parser.add_argument("-p", "--pretrained_model", type=str) parser.add_argument("--show", type=str2bool, default=False) + parser.add_argument("--interpolation", type=int, default=1) parser.add_argument("--save_path", type=str) parser.add_argument("--use_gpu", type=str2bool, default=True) return parser.parse_args() -def create_operators(): +def create_operators(interpolation=1): size = 224 img_mean = [0.485, 0.456, 0.406] img_std = [0.229, 0.224, 0.225] img_scale = 1.0 / 255.0 decode_op = utils.DecodeImage() - resize_op = utils.ResizeImage(resize_short=256) + resize_op = utils.ResizeImage(resize_short=256, interpolation=interpolation) crop_op = utils.CropImage(size=(size, size)) normalize_op = utils.NormalizeImage( scale=img_scale, mean=img_mean, std=img_std) @@ -58,7 +59,7 @@ def preprocess(fname, ops): def main(): args = parse_args() - operators = create_operators() + operators = create_operators(args.interpolation) # assign the place if args.use_gpu: gpu_id = fluid.dygraph.parallel.Env().dev_id @@ -66,7 +67,7 @@ def main(): else: place = fluid.CPUPlace() - pre_weights_dict = fluid.load_program_state(args.pretrained_model) + #pre_weights_dict = fluid.load_program_state(args.pretrained_model) with fluid.dygraph.guard(place): net = ResNet50() data = preprocess(args.image_file, operators) diff --git a/tools/feature_maps_visualization/utils.py b/tools/feature_maps_visualization/utils.py index 6c4a75e1..7c701493 100644 --- a/tools/feature_maps_visualization/utils.py +++ b/tools/feature_maps_visualization/utils.py @@ -32,15 +32,16 @@ class DecodeImage(object): class ResizeImage(object): - def __init__(self, resize_short=None): + def __init__(self, resize_short=None, interpolation=1): self.resize_short = resize_short + self.interpolation = interpolation def __call__(self, img): img_h, img_w = img.shape[:2] percent = float(self.resize_short) / min(img_w, img_h) w = int(round(img_w * percent)) h = int(round(img_h * percent)) - return cv2.resize(img, (w, h)) + return cv2.resize(img, (w, h), interpolation=self.interpolation) class CropImage(object): -- GitLab