提交 f9f8d032 编写于 作者: W wqz960

add feature maps visualization

上级 5d3fe63f
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from resnet import ResNet50
import paddle.fluid as fluid
import numpy as np
import cv2
import utils
import argparse
import matplotlib.pyplot as plt
def parse_args():
def str2bool(v):
return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image_file", type=str)
parser.add_argument("-c", "--channel_num", type=int)
parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("--show", type=str2bool, default=False)
parser.add_argument("--save", type=str2bool, default=True)
parser.add_argument("--save_path", type=str)
parser.add_argument("--use_gpu", type=str2bool, default=True)
return parser.parse_args()
def create_operators():
size = 224
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
img_scale = 1.0 / 255.0
decode_op = utils.DecodeImage()
resize_op = utils.ResizeImage(resize_short=256)
crop_op = utils.CropImage(size=(size, size))
normalize_op = utils.NormalizeImage(
scale=img_scale, mean=img_mean, std=img_std)
totensor_op = utils.ToTensor()
return [decode_op, resize_op, crop_op, normalize_op, totensor_op]
def preprocess(fname, ops):
data = open(fname, 'rb').read()
for op in ops:
data = op(data)
return data
def main():
args = parse_args()
operators = create_operators()
# assign the place
if args.use_gpu:
gpu_id = fluid.dygraph.parallel.Env().dev_id
place = fluid.CUDAPlace(gpu_id)
else:
place = fluid.CPUPlace()
fm = None
print(args.pretrained_model)
pre_weights_dict = fluid.load_program_state(args.pretrained_model)
with fluid.dygraph.guard(place):
net = ResNet50()
data = preprocess(args.image_file, operators)
data = np.expand_dims(data, axis=0)
data = fluid.dygraph.to_variable(data)
dy_weights_dict = net.state_dict()
pre_weights_dict_new = {}
for key in dy_weights_dict:
weights_name = dy_weights_dict[key].name
pre_weights_dict_new[key] = pre_weights_dict[weights_name]
net.set_dict(pre_weights_dict_new)
net.eval()
_, fm = net(data)
assert args.channel_num >= 0 and args.channel_num <= fm.shape[1], "the channel is out of the range"
fm = (np.squeeze(fm[0][args.channel_num].numpy())*255).astype(np.uint8)
print(fm)
if fm is not None:
if args.save:
print(args.save_path)
cv2.imwrite(args.save_path, fm)
if args.show:
cv2.show(fm)
cv2.waitKey(0)
if __name__ == "__main__":
main()
# 特征图可视化指南
## 一、概述
特征图是输入图片在卷积网络中的特征表达,对特征图的研究可以有利于我们对于模型的理解与设计,所以基于动态图我们使用本工具来可视化特征图。
## 二、准备工作
首先我们需要选定研究的模型,本文设定ResNet50作为研究模型,将resnet.py从[模型库](../../ppcls/modeling/architecture/)拷贝到当前目录下,并下载预训练模型[预训练模型](../../docs/zh_CN/models/models_intro), 复制resnet50的模型链接,使用下列命令下载并解压预训练模型。
```bash
wget The Link for Pretrained Model
tar -xf Downloaded Pretrained Model
```
以resnet50为例:
```bash
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
tar -xf ResNet50_pretrained.tar
```
## 三、修改模型
找到我们所需要的特征图位置,设置self.fm将其fetch出来,本文以resnet50中的stem层之后的特征图为例。
在ResNet50的__init__函数中定义self.fm
```python
self.fm = None
```
在ResNet50的forward函数中指定特征图
```python
def forward(self, inputs):
y = self.conv(inputs)
self.fm = y
y = self.pool2d_max(y)
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
y = self.out(y)
return y, self.fm
```
执行函数
```bash
python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \
-c channel_num -p pretrained model \
--show whether to show \
--save whether to save \
--save_path where to save \
--use_gpu whether to use gpu
```
参数说明:
+ `-i`:待预测的图片文件路径,如 `./test.jpeg`
+ `-c`:特征图维度,如 `./resnet50-vd/model`
+ `-p`:权重文件路径,如 `./ResNet50_pretrained/`
+ `--show`:是否展示图片,默认值 False
+ `--save`:是否保存图片,默认值:True
+ `--save_path`:保存路径,如:`./tools/`
+ `--use_gpu`:是否使用 GPU 预测,默认值:True
## 四、结果
输入图片:
![](../../tools/feature_maps_visualization/test.jpg)
输出特征图:
![](../../tools/feature_maps_visualization/fm.jpg)
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import numpy as np
class DecodeImage(object):
def __init__(self, to_rgb=True):
self.to_rgb = to_rgb
def __call__(self, img):
data = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(data, 1)
if self.to_rgb:
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
img.shape)
img = img[:, :, ::-1]
return img
class ResizeImage(object):
def __init__(self, resize_short=None):
self.resize_short = resize_short
def __call__(self, img):
img_h, img_w = img.shape[:2]
percent = float(self.resize_short) / min(img_w, img_h)
w = int(round(img_w * percent))
h = int(round(img_h * percent))
return cv2.resize(img, (w, h))
class CropImage(object):
def __init__(self, size):
if type(size) is int:
self.size = (size, size)
else:
self.size = size
def __call__(self, img):
w, h = self.size
img_h, img_w = img.shape[:2]
w_start = (img_w - w) // 2
h_start = (img_h - h) // 2
w_end = w_start + w
h_end = h_start + h
return img[h_start:h_end, w_start:w_end, :]
class NormalizeImage(object):
def __init__(self, scale=None, mean=None, std=None):
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, img):
return (img.astype('float32') * self.scale - self.mean) / self.std
class ToTensor(object):
def __init__(self):
pass
def __call__(self, img):
img = img.transpose((2, 0, 1))
return img
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册