提交 a68d90c5 编写于 作者: W wqz960

fix interpolation

上级 5c20b55e
...@@ -47,7 +47,7 @@ def forward(self, inputs): ...@@ -47,7 +47,7 @@ def forward(self, inputs):
python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \
-c channel_num -p pretrained model \ -c channel_num -p pretrained model \
--show whether to show \ --show whether to show \
--save whether to save \ --interpolation interpolation method\
--save_path where to save \ --save_path where to save \
--use_gpu whether to use gpu --use_gpu whether to use gpu
``` ```
...@@ -56,6 +56,7 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test ...@@ -56,6 +56,7 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
+ `-c`:特征图维度,如 `./resnet50_vd/model` + `-c`:特征图维度,如 `./resnet50_vd/model`
+ `-p`:权重文件路径,如 `./ResNet50_pretrained/` + `-p`:权重文件路径,如 `./ResNet50_pretrained/`
+ `--show`:是否展示图片,默认值 False + `--show`:是否展示图片,默认值 False
+ `--interpolation`: 图像插值方式, 默认值 1
+ `--save_path`:保存路径,如:`./tools/` + `--save_path`:保存路径,如:`./tools/`
+ `--use_gpu`:是否使用 GPU 预测,默认值:True + `--use_gpu`:是否使用 GPU 预测,默认值:True
......
...@@ -28,19 +28,20 @@ def parse_args(): ...@@ -28,19 +28,20 @@ def parse_args():
parser.add_argument("-c", "--channel_num", type=int) parser.add_argument("-c", "--channel_num", type=int)
parser.add_argument("-p", "--pretrained_model", type=str) parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("--show", type=str2bool, default=False) parser.add_argument("--show", type=str2bool, default=False)
parser.add_argument("--interpolation", type=int, default=1)
parser.add_argument("--save_path", type=str) parser.add_argument("--save_path", type=str)
parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
return parser.parse_args() return parser.parse_args()
def create_operators(): def create_operators(interpolation=1):
size = 224 size = 224
img_mean = [0.485, 0.456, 0.406] img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225] img_std = [0.229, 0.224, 0.225]
img_scale = 1.0 / 255.0 img_scale = 1.0 / 255.0
decode_op = utils.DecodeImage() decode_op = utils.DecodeImage()
resize_op = utils.ResizeImage(resize_short=256) resize_op = utils.ResizeImage(resize_short=256, interpolation=interpolation)
crop_op = utils.CropImage(size=(size, size)) crop_op = utils.CropImage(size=(size, size))
normalize_op = utils.NormalizeImage( normalize_op = utils.NormalizeImage(
scale=img_scale, mean=img_mean, std=img_std) scale=img_scale, mean=img_mean, std=img_std)
...@@ -58,7 +59,7 @@ def preprocess(fname, ops): ...@@ -58,7 +59,7 @@ def preprocess(fname, ops):
def main(): def main():
args = parse_args() args = parse_args()
operators = create_operators() operators = create_operators(args.interpolation)
# assign the place # assign the place
if args.use_gpu: if args.use_gpu:
gpu_id = fluid.dygraph.parallel.Env().dev_id gpu_id = fluid.dygraph.parallel.Env().dev_id
...@@ -66,7 +67,7 @@ def main(): ...@@ -66,7 +67,7 @@ def main():
else: else:
place = fluid.CPUPlace() place = fluid.CPUPlace()
pre_weights_dict = fluid.load_program_state(args.pretrained_model) #pre_weights_dict = fluid.load_program_state(args.pretrained_model)
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
net = ResNet50() net = ResNet50()
data = preprocess(args.image_file, operators) data = preprocess(args.image_file, operators)
......
...@@ -32,15 +32,16 @@ class DecodeImage(object): ...@@ -32,15 +32,16 @@ class DecodeImage(object):
class ResizeImage(object): class ResizeImage(object):
def __init__(self, resize_short=None): def __init__(self, resize_short=None, interpolation=1):
self.resize_short = resize_short self.resize_short = resize_short
self.interpolation = interpolation
def __call__(self, img): def __call__(self, img):
img_h, img_w = img.shape[:2] img_h, img_w = img.shape[:2]
percent = float(self.resize_short) / min(img_w, img_h) percent = float(self.resize_short) / min(img_w, img_h)
w = int(round(img_w * percent)) w = int(round(img_w * percent))
h = int(round(img_h * percent)) h = int(round(img_h * percent))
return cv2.resize(img, (w, h)) return cv2.resize(img, (w, h), interpolation=self.interpolation)
class CropImage(object): class CropImage(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册