未验证 提交 f921b9bd 编写于 作者: L littletomatodonkey 提交者: GitHub

fix feature map visualization (#377)

fix feature map visualization
上级 c9f8e8c6
...@@ -55,16 +55,30 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test ...@@ -55,16 +55,30 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
+ `-i`:待预测的图片文件路径,如 `./test.jpeg` + `-i`:待预测的图片文件路径,如 `./test.jpeg`
+ `-c`:待可视化的特征图通道序号(整数),如 `5` + `-c`:待可视化的特征图通道序号(整数),如 `5`
+ `-p`:权重文件路径,如 `./ResNet50_pretrained/` + `-p`:权重文件路径,如 `./ResNet50_pretrained/`
+ `--show`:是否展示图片,默认值 False
+ `--interpolation`:图像插值方式,默认值 1 + `--interpolation`:图像插值方式,默认值 1
+ `--save_path`:特征图保存路径(需包含文件名及扩展名),如:`./output.png` + `--save_path`:特征图保存路径(需包含文件名及扩展名),如:`./output.png`
+ `--use_gpu`:是否使用 GPU 预测,默认值:True + `--use_gpu`:是否使用 GPU 预测,默认值:True
## 四、结果 ## 四、结果
输入图片:
![](../../../tools/feature_maps_visualization/test.jpg) * 输入图片:
输出特征图: ![](../../../docs/images/feature_maps/feature_visualization_input.jpg)
![](../../../tools/feature_maps_visualization/fm.jpg) * 运行下面的特征图可视化脚本
```
python tools/feature_maps_visualization/fm_vis.py \
-i ./docs/images/feature_maps/feature_visualization_input.jpg \
-c 5 \
-p pretrained/ResNet50_pretrained/ \
--show=True \
--interpolation=1 \
--save_path="./output.png" \
--use_gpu=False \
--load_static_weights=True
```
* 输出特征图保存为`output.png`,如下所示。
![](../../../docs/images/feature_maps/feature_visualization_output.jpg)
...@@ -11,84 +11,92 @@ ...@@ -11,84 +11,92 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
from resnet import ResNet50 import cv2
import paddle.fluid as fluid
import numpy as np
import cv2
import utils import utils
import argparse import argparse
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
import paddle
from paddle.distributed import ParallelEnv
from resnet import ResNet50
from ppcls.utils.save_load import load_dygraph_pretrain
def parse_args(): def parse_args():
def str2bool(v): def str2bool(v):
return v.lower() in ("true", "t", "1") return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image_file", type=str) parser.add_argument("-i", "--image_file", type=str)
parser.add_argument("-c", "--channel_num", type=int) parser.add_argument("-c", "--channel_num", type=int)
parser.add_argument("-p", "--pretrained_model", type=str) parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("--show", type=str2bool, default=False) parser.add_argument("--show", type=str2bool, default=False)
parser.add_argument("--interpolation", type=int, default=1) parser.add_argument("--interpolation", type=int, default=1)
parser.add_argument("--save_path", type=str) parser.add_argument("--save_path", type=str, default=None)
parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument(
"--load_static_weights",
type=str2bool,
default=False,
help='Whether to load the pretrained weights saved in static mode')
return parser.parse_args() return parser.parse_args()
def create_operators(interpolation=1): def create_operators(interpolation=1):
size = 224 size = 224
img_mean = [0.485, 0.456, 0.406] img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225] img_std = [0.229, 0.224, 0.225]
img_scale = 1.0 / 255.0 img_scale = 1.0 / 255.0
decode_op = utils.DecodeImage() resize_op = utils.ResizeImage(
resize_op = utils.ResizeImage(resize_short=256, interpolation=interpolation) resize_short=256, interpolation=interpolation)
crop_op = utils.CropImage(size=(size, size)) crop_op = utils.CropImage(size=(size, size))
normalize_op = utils.NormalizeImage( normalize_op = utils.NormalizeImage(
scale=img_scale, mean=img_mean, std=img_std) scale=img_scale, mean=img_mean, std=img_std)
totensor_op = utils.ToTensor() totensor_op = utils.ToTensor()
return [decode_op, resize_op, crop_op, normalize_op, totensor_op] return [resize_op, crop_op, normalize_op, totensor_op]
def preprocess(fname, ops): def preprocess(data, ops):
data = open(fname, 'rb').read()
for op in ops: for op in ops:
data = op(data) data = op(data)
return data return data
def main(): def main():
args = parse_args() args = parse_args()
operators = create_operators(args.interpolation) operators = create_operators(args.interpolation)
# assign the place # assign the place
if args.use_gpu: place = 'gpu:{}'.format(ParallelEnv().dev_id) if args.use_gpu else 'cpu'
gpu_id = fluid.dygraph.parallel.Env().dev_id place = paddle.set_device(place)
place = fluid.CUDAPlace(gpu_id)
else: net = ResNet50()
place = fluid.CPUPlace() load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights)
#pre_weights_dict = fluid.load_program_state(args.pretrained_model) img = cv2.imread(args.image_file, cv2.IMREAD_COLOR)
with fluid.dygraph.guard(place): data = preprocess(img, operators)
net = ResNet50() data = np.expand_dims(data, axis=0)
data = preprocess(args.image_file, operators) data = paddle.to_tensor(data)
data = np.expand_dims(data, axis=0) net.eval()
data = fluid.dygraph.to_variable(data) _, fm = net(data)
dy_weights_dict = net.state_dict() assert args.channel_num >= 0 and args.channel_num <= fm.shape[
pre_weights_dict_new = {} 1], "the channel is out of the range, should be in {} but got {}".format(
for key in dy_weights_dict: [0, fm.shape[1]], args.channel_num)
weights_name = dy_weights_dict[key].name
pre_weights_dict_new[key] = pre_weights_dict[weights_name] fm = (np.squeeze(fm[0][args.channel_num].numpy()) * 255).astype(np.uint8)
net.set_dict(pre_weights_dict_new) fm = cv2.resize(fm, (img.shape[1], img.shape[0]))
net.eval() if args.save_path is not None:
_, fm = net(data) print("the feature map is saved in path: {}".format(args.save_path))
assert args.channel_num >= 0 and args.channel_num <= fm.shape[1], "the channel is out of the range, should be in {} but got {}".format([0, fm.shape[1]], args.channel_num) cv2.imwrite(args.save_path, fm)
fm = (np.squeeze(fm[0][args.channel_num].numpy())*255).astype(np.uint8)
if fm is not None:
if args.save:
cv2.imwrite(args.save_path, fm)
if args.show:
cv2.show(fm)
cv2.waitKey(0)
if __name__ == "__main__": if __name__ == "__main__":
main() main()
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np import numpy as np
import argparse
import ast
import paddle import paddle
import paddle.fluid as fluid from paddle import ParamAttr
from paddle.fluid.param_attr import ParamAttr import paddle.nn as nn
from paddle.fluid.layer_helper import LayerHelper import paddle.nn.functional as F
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.fluid.dygraph.base import to_variable from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform
from paddle.fluid import framework
import math import math
import sys
import time
class ConvBNLayer(fluid.dygraph.Layer): __all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
class ConvBNLayer(nn.Layer):
def __init__(self, def __init__(self,
num_channels, num_channels,
num_filters, num_filters,
...@@ -26,25 +42,25 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -26,25 +42,25 @@ class ConvBNLayer(fluid.dygraph.Layer):
super(ConvBNLayer, self).__init__() super(ConvBNLayer, self).__init__()
self._conv = Conv2D( self._conv = Conv2D(
num_channels=num_channels, in_channels=num_channels,
num_filters=num_filters, out_channels=num_filters,
filter_size=filter_size, kernel_size=filter_size,
stride=stride, stride=stride,
padding=(filter_size - 1) // 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
act=None, weight_attr=ParamAttr(name=name + "_weights"),
param_attr=ParamAttr(name=name + "_weights"),
bias_attr=False) bias_attr=False)
if name == "conv1": if name == "conv1":
bn_name = "bn_" + name bn_name = "bn_" + name
else: else:
bn_name = "bn" + name[3:] bn_name = "bn" + name[3:]
self._batch_norm = BatchNorm(num_filters, self._batch_norm = BatchNorm(
act=act, num_filters,
param_attr=ParamAttr(name=bn_name + '_scale'), act=act,
bias_attr=ParamAttr(bn_name + '_offset'), param_attr=ParamAttr(name=bn_name + "_scale"),
moving_mean_name=bn_name + '_mean', bias_attr=ParamAttr(bn_name + "_offset"),
moving_variance_name=bn_name + '_variance') moving_mean_name=bn_name + "_mean",
moving_variance_name=bn_name + "_variance")
def forward(self, inputs): def forward(self, inputs):
y = self._conv(inputs) y = self._conv(inputs)
...@@ -52,7 +68,7 @@ class ConvBNLayer(fluid.dygraph.Layer): ...@@ -52,7 +68,7 @@ class ConvBNLayer(fluid.dygraph.Layer):
return y return y
class BottleneckBlock(fluid.dygraph.Layer): class BottleneckBlock(nn.Layer):
def __init__(self, def __init__(self,
num_channels, num_channels,
num_filters, num_filters,
...@@ -65,21 +81,21 @@ class BottleneckBlock(fluid.dygraph.Layer): ...@@ -65,21 +81,21 @@ class BottleneckBlock(fluid.dygraph.Layer):
num_channels=num_channels, num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=1, filter_size=1,
act='relu', act="relu",
name=name+"_branch2a") name=name + "_branch2a")
self.conv1 = ConvBNLayer( self.conv1 = ConvBNLayer(
num_channels=num_filters, num_channels=num_filters,
num_filters=num_filters, num_filters=num_filters,
filter_size=3, filter_size=3,
stride=stride, stride=stride,
act='relu', act="relu",
name=name+"_branch2b") name=name + "_branch2b")
self.conv2 = ConvBNLayer( self.conv2 = ConvBNLayer(
num_channels=num_filters, num_channels=num_filters,
num_filters=num_filters * 4, num_filters=num_filters * 4,
filter_size=1, filter_size=1,
act=None, act=None,
name=name+"_branch2c") name=name + "_branch2c")
if not shortcut: if not shortcut:
self.short = ConvBNLayer( self.short = ConvBNLayer(
...@@ -103,90 +119,163 @@ class BottleneckBlock(fluid.dygraph.Layer): ...@@ -103,90 +119,163 @@ class BottleneckBlock(fluid.dygraph.Layer):
else: else:
short = self.short(inputs) short = self.short(inputs)
y = fluid.layers.elementwise_add(x=short, y=conv2) y = paddle.add(x=short, y=conv2)
y = F.relu(y)
return y
class BasicBlock(nn.Layer):
def __init__(self,
num_channels,
num_filters,
stride,
shortcut=True,
name=None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=3,
stride=stride,
act="relu",
name=name + "_branch2a")
self.conv1 = ConvBNLayer(
num_channels=num_filters,
num_filters=num_filters,
filter_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = ConvBNLayer(
num_channels=num_channels,
num_filters=num_filters,
filter_size=1,
stride=stride,
name=name + "_branch1")
layer_helper = LayerHelper(self.full_name(), act='relu') self.shortcut = shortcut
return layer_helper.append_activation(y)
def forward(self, inputs):
y = self.conv0(inputs)
conv1 = self.conv1(y)
class ResNet(fluid.dygraph.Layer): if self.shortcut:
short = inputs
else:
short = self.short(inputs)
y = paddle.add(x=short, y=conv1)
y = F.relu(y)
return y
class ResNet(nn.Layer):
def __init__(self, layers=50, class_dim=1000): def __init__(self, layers=50, class_dim=1000):
super(ResNet, self).__init__() super(ResNet, self).__init__()
self.layers = layers self.layers = layers
supported_layers = [50, 101, 152] supported_layers = [18, 34, 50, 101, 152]
assert layers in supported_layers, \ assert layers in supported_layers, \
"supported layers are {} but input layer is {}".format(supported_layers, layers) "supported layers are {} but input layer is {}".format(
self.fm = None supported_layers, layers)
if layers == 50: if layers == 18:
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3] depth = [3, 4, 6, 3]
elif layers == 101: elif layers == 101:
depth = [3, 4, 23, 3] depth = [3, 4, 23, 3]
elif layers == 152: elif layers == 152:
depth = [3, 8, 36, 3] depth = [3, 8, 36, 3]
num_channels = [64, 256, 512, 1024] num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512] num_filters = [64, 128, 256, 512]
self.feature_map = None
self.conv = ConvBNLayer( self.conv = ConvBNLayer(
num_channels=3, num_channels=3,
num_filters=64, num_filters=64,
filter_size=7, filter_size=7,
stride=2, stride=2,
act='relu', act="relu",
name="conv1") name="conv1")
self.pool2d_max = Pool2D( self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)
pool_size=3,
pool_stride=2, self.block_list = []
pool_padding=1, if layers >= 50:
pool_type='max') for block in range(len(depth)):
shortcut = False
self.bottleneck_block_list = [] for i in range(depth[block]):
for block in range(len(depth)): if layers in [101, 152] and block == 2:
shortcut = False if i == 0:
for i in range(depth[block]): conv_name = "res" + str(block + 2) + "a"
if layers in [101, 152] and block == 2: else:
if i == 0: conv_name = "res" + str(block + 2) + "b" + str(i)
conv_name="res"+str(block+2)+"a"
else: else:
conv_name="res"+str(block+2)+"b"+str(i) conv_name = "res" + str(block + 2) + chr(97 + i)
else: bottleneck_block = self.add_sublayer(
conv_name="res"+str(block+2)+chr(97+i) conv_name,
bottleneck_block = self.add_sublayer( BottleneckBlock(
'bb_%d_%d' % (block, i), num_channels=num_channels[block]
BottleneckBlock( if i == 0 else num_filters[block] * 4,
num_channels=num_channels[block] num_filters=num_filters[block],
if i == 0 else num_filters[block] * 4, stride=2 if i == 0 and block != 0 else 1,
num_filters=num_filters[block], shortcut=shortcut,
stride=2 if i == 0 and block != 0 else 1, name=conv_name))
shortcut=shortcut, self.block_list.append(bottleneck_block)
name=conv_name)) shortcut = True
self.bottleneck_block_list.append(bottleneck_block) else:
shortcut = True for block in range(len(depth)):
shortcut = False
self.pool2d_avg = Pool2D( for i in range(depth[block]):
pool_size=7, pool_type='avg', global_pooling=True) conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
self.pool2d_avg_output = num_filters[len(num_filters) - 1] * 4 * 1 * 1 conv_name,
BasicBlock(
stdv = 1.0 / math.sqrt(2048 * 1.0) num_channels=num_channels[block]
if i == 0 else num_filters[block],
self.out = Linear(self.pool2d_avg_output, num_filters=num_filters[block],
class_dim, stride=2 if i == 0 and block != 0 else 1,
param_attr=ParamAttr( shortcut=shortcut,
initializer=fluid.initializer.Uniform(-stdv, stdv), name="fc_0.w_0"), name=conv_name))
bias_attr=ParamAttr(name="fc_0.b_0")) self.block_list.append(basic_block)
shortcut = True
self.pool2d_avg = AdaptiveAvgPool2D(1)
self.pool2d_avg_channels = num_channels[-1] * 2
stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
self.out = Linear(
self.pool2d_avg_channels,
class_dim,
weight_attr=ParamAttr(
initializer=Uniform(-stdv, stdv), name="fc_0.w_0"),
bias_attr=ParamAttr(name="fc_0.b_0"))
def forward(self, inputs): def forward(self, inputs):
y = self.conv(inputs) y = self.conv(inputs)
y = self.pool2d_max(y) y = self.pool2d_max(y)
self.fm = y self.feature_map = y
for bottleneck_block in self.bottleneck_block_list: for block in self.block_list:
y = bottleneck_block(y) y = block(y)
y = self.pool2d_avg(y) y = self.pool2d_avg(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output]) y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels])
y = self.out(y) y = self.out(y)
return y, self.fm return y, self.feature_map
def ResNet18(**args):
model = ResNet(layers=18, **args)
return model
def ResNet34(**args):
model = ResNet(layers=34, **args)
return model
def ResNet50(**args): def ResNet50(**args):
...@@ -202,14 +291,3 @@ def ResNet101(**args): ...@@ -202,14 +291,3 @@ def ResNet101(**args):
def ResNet152(**args): def ResNet152(**args):
model = ResNet(layers=152, **args) model = ResNet(layers=152, **args)
return model return model
if __name__ == "__main__":
import numpy as np
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
model = ResNet50()
img = np.random.uniform(0, 255, [1, 3, 224, 224]).astype('float32')
img = fluid.dygraph.to_variable(img)
res = model(img)
print(res.shape)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册