提交 a68d90c5 编写于 作者: W wqz960

fix interpolation

上级 5c20b55e
......@@ -47,7 +47,7 @@ def forward(self, inputs):
python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \
-c channel_num -p pretrained model \
--show whether to show \
--save whether to save \
--interpolation interpolation method\
--save_path where to save \
--use_gpu whether to use gpu
```
......@@ -56,6 +56,7 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
+ `-c`:特征图维度,如 `./resnet50_vd/model`
+ `-p`:权重文件路径,如 `./ResNet50_pretrained/`
+ `--show`:是否展示图片,默认值 False
+ `--interpolation`: 图像插值方式, 默认值 1
+ `--save_path`:保存路径,如:`./tools/`
+ `--use_gpu`:是否使用 GPU 预测,默认值:True
......
......@@ -28,19 +28,20 @@ def parse_args():
parser.add_argument("-c", "--channel_num", type=int)
parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("--show", type=str2bool, default=False)
parser.add_argument("--interpolation", type=int, default=1)
parser.add_argument("--save_path", type=str)
parser.add_argument("--use_gpu", type=str2bool, default=True)
return parser.parse_args()
def create_operators():
def create_operators(interpolation=1):
size = 224
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
img_scale = 1.0 / 255.0
decode_op = utils.DecodeImage()
resize_op = utils.ResizeImage(resize_short=256)
resize_op = utils.ResizeImage(resize_short=256, interpolation=interpolation)
crop_op = utils.CropImage(size=(size, size))
normalize_op = utils.NormalizeImage(
scale=img_scale, mean=img_mean, std=img_std)
......@@ -58,7 +59,7 @@ def preprocess(fname, ops):
def main():
args = parse_args()
operators = create_operators()
operators = create_operators(args.interpolation)
# assign the place
if args.use_gpu:
gpu_id = fluid.dygraph.parallel.Env().dev_id
......@@ -66,7 +67,7 @@ def main():
else:
place = fluid.CPUPlace()
pre_weights_dict = fluid.load_program_state(args.pretrained_model)
#pre_weights_dict = fluid.load_program_state(args.pretrained_model)
with fluid.dygraph.guard(place):
net = ResNet50()
data = preprocess(args.image_file, operators)
......
......@@ -32,15 +32,16 @@ class DecodeImage(object):
class ResizeImage(object):
def __init__(self, resize_short=None):
def __init__(self, resize_short=None, interpolation=1):
self.resize_short = resize_short
self.interpolation = interpolation
def __call__(self, img):
img_h, img_w = img.shape[:2]
percent = float(self.resize_short) / min(img_w, img_h)
w = int(round(img_w * percent))
h = int(round(img_h * percent))
return cv2.resize(img, (w, h))
return cv2.resize(img, (w, h), interpolation=self.interpolation)
class CropImage(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册