PaddlePaddle / PaddleClas
Commit 184bdd76 (unverified)
Authored July 20, 2020 by littletomatodonkey
Committed July 20, 2020 by GitHub
Merge pull request #198 from wqz960/PaddleClas_74
add feature maps visualization
Parents: 5d3fe63f, a68d90c5
Showing 7 changed files with 466 additions and 0 deletions.
docs/images/feature_maps/feature_visualization_input.jpg (+0 -0)
docs/images/feature_maps/feature_visualization_output.jpg (+0 -0)
docs/zh_CN/feature_visiualization/get_started.md (+70 -0)
tools/feature_maps_visualization/download_resnet50_pretrained.sh (+2 -0)
tools/feature_maps_visualization/fm_vis.py (+94 -0)
tools/feature_maps_visualization/resnet.py (+215 -0)
tools/feature_maps_visualization/utils.py (+85 -0)
docs/images/feature_maps/feature_visualization_input.jpg (new file, mode 100644, 30.5 KB)
docs/images/feature_maps/feature_visualization_output.jpg (new file, mode 100644, 10.5 KB)
docs/zh_CN/feature_visiualization/get_started.md (new file, mode 100644)
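# Feature Map Visualization Guide
## 1. Overview
A feature map is the representation of an input image inside a convolutional network. Studying feature maps helps in understanding and designing models, so this tool visualizes them using dynamic-graph (dygraph) mode.
## 2. Preparation
First choose the model to study; this guide uses ResNet50. Copy resnet.py from the [model library](../../../ppcls/modeling/architecture/) into the current directory. Then open the [pretrained models](../../zh_CN/models/models_intro) page, copy the link for the resnet50 model, and download and extract the pretrained model with the commands below.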
```bash
wget The Link for Pretrained Model
tar -xf Downloaded Pretrained Model
```
Taking resnet50 as an example:
```bash
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
tar -xf ResNet50_pretrained.tar
```
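## 3. Modifying the Model
Locate the feature map you want to inspect and assign it to self.fm so that it can be fetched. This guide uses the feature map right after the stem layer of resnet50 as an example.
In fm_vis.py, change the model name to match the model you are using.
Define self.fm in the __init__ function of ResNet50:
```python
self.fm = None
```
Specify the feature map in the forward function of ResNet50: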
```python
def forward(self, inputs):
    y = self.conv(inputs)
    self.fm = y
    y = self.pool2d_max(y)
    for bottleneck_block in self.bottleneck_block_list:
        y = bottleneck_block(y)
    y = self.pool2d_avg(y)
    y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
    y = self.out(y)
    return y, self.fm
```
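The stash-and-return pattern used in this forward function does not depend on Paddle. The following is a minimal, framework-free sketch of the same idea. `TinyNet`, its two stages, and the list-based "tensors" are hypothetical stand-ins introduced only for illustration; they are not part of this commit.

```python
class TinyNet:
    """Toy two-stage network; plain lists stand in for tensors."""

    def __init__(self):
        # Holds the intermediate result of the most recent forward pass.
        self.fm = None

    def stem(self, inputs):
        # Stand-in for the stem convolution: double every value.
        return [2 * x for x in inputs]

    def head(self, features):
        # Stand-in for the rest of the network: reduce to one number.
        return sum(features)

    def forward(self, inputs):
        y = self.stem(inputs)
        self.fm = y          # keep a reference to the stem output
        y = self.head(y)
        return y, self.fm    # return the prediction and the feature map


net = TinyNet()
out, fm = net.forward([1, 2, 3])
print(out, fm)
```

Because forward now returns a tuple, every caller has to unpack two values instead of one.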
Run the script:
```bash
python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \
        -c channel_num -p pretrained model \
        --show whether to show \
        --interpolation interpolation method \
        --save_path where to save \
        --use_gpu whether to use gpu
```
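Parameters:
+ `-i`: path of the image file to run inference on, e.g. `./test.jpeg`
+ `-c`: index of the feature-map channel to visualize, given as an integer (fm_vis.py declares it with `type=int`)
+ `-p`: path of the weights file, e.g. `./ResNet50_pretrained/`
+ `--show`: whether to display the image, default False
+ `--interpolation`: image interpolation method, default 1
+ `--save_path`: where to save the result, e.g. `./tools/`
+ `--use_gpu`: whether to use the GPU for inference, default True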
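The two boolean flags, `--show` and `--use_gpu`, are parsed in fm_vis.py by a small `str2bool` helper that treats `true`, `t`, and `1` (case-insensitive) as True and any other string as False. The sketch below reproduces only that rule so it can be tried in isolation; `build_parser` is a hypothetical wrapper name that does not exist in the commit.

```python
import argparse


def str2bool(v):
    # Same rule as the helper in fm_vis.py.
    return v.lower() in ("true", "t", "1")


def build_parser():
    # Hypothetical wrapper; only the two boolean flags are reproduced here.
    parser = argparse.ArgumentParser()
    parser.add_argument("--show", type=str2bool, default=False)
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    return parser


parser = build_parser()
print(parser.parse_args(["--show", "True"]).show)
print(parser.parse_args(["--use_gpu", "no"]).use_gpu)
print(parser.parse_args([]).use_gpu)
```

Under this rule a value such as `yes` or `on` parses as False, so spelling the flag as `true`, `t`, or `1` is the safe way to enable it.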
## 4. Results
Input image:
![](../../../tools/feature_maps_visualization/test.jpg)
Output feature map:
![](../../../tools/feature_maps_visualization/fm.jpg)
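fm_vis.py turns the selected channel into an image by multiplying it by 255 and casting to uint8, which fits the 0 to 255 range only when the activations lie in [0, 1]. One possible alternative, shown here as a sketch and not as what the commit does, is to min-max scale the channel first. `to_uint8_range` is a hypothetical helper name, and nested lists stand in for the 2-D array.

```python
def to_uint8_range(channel):
    """Min-max scale a 2-D list of floats to integers in [0, 255]."""
    flat = [v for row in channel for v in row]
    lo, hi = min(flat), max(flat)
    if hi == lo:
        # A constant channel carries no contrast; map it to all zeros.
        return [[0 for _ in row] for row in channel]
    scale = 255.0 / (hi - lo)
    return [[int(round((v - lo) * scale)) for v in row] for row in channel]


channel = [[0.0, 1.5], [3.0, 6.0]]
scaled = to_uint8_range(channel)
print(scaled)
```

Min-max scaling always uses the full grey range, at the cost of making brightness incomparable between channels, since each channel is stretched by its own minimum and maximum.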
tools/feature_maps_visualization/download_resnet50_pretrained.sh (new file, mode 100644)
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
tar -xf ResNet50_pretrained.tar
\ No newline at end of file
tools/feature_maps_visualization/fm_vis.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from resnet import ResNet50
import paddle.fluid as fluid
import numpy as np
import cv2
import utils
import argparse


def parse_args():
    def str2bool(v):
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--image_file", type=str)
    parser.add_argument("-c", "--channel_num", type=int)
    parser.add_argument("-p", "--pretrained_model", type=str)
    parser.add_argument("--show", type=str2bool, default=False)
    parser.add_argument("--interpolation", type=int, default=1)
    parser.add_argument("--save_path", type=str)
    parser.add_argument("--use_gpu", type=str2bool, default=True)

    return parser.parse_args()


def create_operators(interpolation=1):
    size = 224
    img_mean = [0.485, 0.456, 0.406]
    img_std = [0.229, 0.224, 0.225]
    img_scale = 1.0 / 255.0

    decode_op = utils.DecodeImage()
    resize_op = utils.ResizeImage(
        resize_short=256, interpolation=interpolation)
    crop_op = utils.CropImage(size=(size, size))
    normalize_op = utils.NormalizeImage(
        scale=img_scale, mean=img_mean, std=img_std)
    totensor_op = utils.ToTensor()

    return [decode_op, resize_op, crop_op, normalize_op, totensor_op]


def preprocess(fname, ops):
    data = open(fname, 'rb').read()
    for op in ops:
        data = op(data)

    return data


def main():
    args = parse_args()
    operators = create_operators(args.interpolation)
    # assign the place
    if args.use_gpu:
        gpu_id = fluid.dygraph.parallel.Env().dev_id
        place = fluid.CUDAPlace(gpu_id)
    else:
        place = fluid.CPUPlace()

    # Load the pretrained weights so that pre_weights_dict is defined below.
    pre_weights_dict = fluid.load_program_state(args.pretrained_model)
    with fluid.dygraph.guard(place):
        net = ResNet50()
        data = preprocess(args.image_file, operators)
        data = np.expand_dims(data, axis=0)
        data = fluid.dygraph.to_variable(data)
        # Map each dygraph parameter to the pretrained weight of the same name.
        dy_weights_dict = net.state_dict()
        pre_weights_dict_new = {}
        for key in dy_weights_dict:
            weights_name = dy_weights_dict[key].name
            pre_weights_dict_new[key] = pre_weights_dict[weights_name]
        net.set_dict(pre_weights_dict_new)
        net.eval()
        _, fm = net(data)
        # Valid channel indices are 0 .. fm.shape[1] - 1.
        assert args.channel_num >= 0 and args.channel_num < fm.shape[1], \
            "the channel is out of the range, should be in {} but got {}".format(
                [0, fm.shape[1] - 1], args.channel_num)
        fm = (np.squeeze(fm[0][args.channel_num].numpy()) * 255).astype(
            np.uint8)
        if fm is not None:
            # The parser defines --save_path, not --save.
            if args.save_path:
                cv2.imwrite(args.save_path, fm)
            if args.show:
                # OpenCV displays images with imshow(window_name, image).
                cv2.imshow("feature map", fm)
                cv2.waitKey(0)


if __name__ == "__main__":
    main()

tools/feature_maps_visualization/resnet.py (new file, mode 100644)
import numpy as np
import argparse
import ast
import paddle
import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable

from paddle.fluid import framework

import math
import sys
import time


class ConvBNLayer(fluid.dygraph.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None,
                 name=None):
        super(ConvBNLayer, self).__init__()

        self._conv = Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        if name == "conv1":
            bn_name = "bn_" + name
        else:
            bn_name = "bn" + name[3:]
        self._batch_norm = BatchNorm(
            num_filters,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)
        return y


class BottleneckBlock(fluid.dygraph.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 shortcut=True,
                 name=None):
        super(BottleneckBlock, self).__init__()

        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=1,
            act='relu',
            name=name + "_branch2a")
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2b")
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters * 4,
            filter_size=1,
            act=None,
            name=name + "_branch2c")

        if not shortcut:
            self.short = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters * 4,
                filter_size=1,
                stride=stride,
                name=name + "_branch1")

        self.shortcut = shortcut

        self._num_channels_out = num_filters * 4

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)

        y = fluid.layers.elementwise_add(x=short, y=conv2)

        layer_helper = LayerHelper(self.full_name(), act='relu')
        return layer_helper.append_activation(y)


class ResNet(fluid.dygraph.Layer):
    def __init__(self, layers=50, class_dim=1000):
        super(ResNet, self).__init__()

        self.layers = layers
        supported_layers = [50, 101, 152]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(
                supported_layers, layers)
        self.fm = None

        if layers == 50:
            depth = [3, 4, 6, 3]
        elif layers == 101:
            depth = [3, 4, 23, 3]
        elif layers == 152:
            depth = [3, 8, 36, 3]
        num_channels = [64, 256, 512, 1024]
        num_filters = [64, 128, 256, 512]

        self.conv = ConvBNLayer(
            num_channels=3,
            num_filters=64,
            filter_size=7,
            stride=2,
            act='relu',
            name="conv1")
        self.pool2d_max = Pool2D(
            pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')

        self.bottleneck_block_list = []
        for block in range(len(depth)):
            shortcut = False
            for i in range(depth[block]):
                if layers in [101, 152] and block == 2:
                    if i == 0:
                        conv_name = "res" + str(block + 2) + "a"
                    else:
                        conv_name = "res" + str(block + 2) + "b" + str(i)
                else:
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                bottleneck_block = self.add_sublayer(
                    'bb_%d_%d' % (block, i),
                    BottleneckBlock(
                        num_channels=num_channels[block]
                        if i == 0 else num_filters[block] * 4,
                        num_filters=num_filters[block],
                        stride=2 if i == 0 and block != 0 else 1,
                        shortcut=shortcut,
                        name=conv_name))
                self.bottleneck_block_list.append(bottleneck_block)
                shortcut = True

        self.pool2d_avg = Pool2D(
            pool_size=7, pool_type='avg', global_pooling=True)

        self.pool2d_avg_output = num_filters[len(num_filters) - 1] * 4 * 1 * 1

        stdv = 1.0 / math.sqrt(2048 * 1.0)

        self.out = Linear(
            self.pool2d_avg_output,
            class_dim,
            param_attr=ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name="fc_0.w_0"),
            bias_attr=ParamAttr(name="fc_0.b_0"))

    def forward(self, inputs):
        y = self.conv(inputs)
        y = self.pool2d_max(y)
        self.fm = y
        for bottleneck_block in self.bottleneck_block_list:
            y = bottleneck_block(y)
        y = self.pool2d_avg(y)
        y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
        y = self.out(y)
        return y, self.fm


def ResNet50(**args):
    model = ResNet(layers=50, **args)
    return model


def ResNet101(**args):
    model = ResNet(layers=101, **args)
    return model


def ResNet152(**args):
    model = ResNet(layers=152, **args)
    return model


if __name__ == "__main__":
    import numpy as np
    place = fluid.CPUPlace()
    with fluid.dygraph.guard(place):
        model = ResNet50()
        img = np.random.uniform(0, 255, [1, 3, 224, 224]).astype('float32')
        img = fluid.dygraph.to_variable(img)
        # forward returns (logits, feature map); unpack before reading shapes.
        res, fm = model(img)
        print(res.shape)
tools/feature_maps_visualization/utils.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import cv2
import numpy as np


class DecodeImage(object):
    def __init__(self, to_rgb=True):
        self.to_rgb = to_rgb

    def __call__(self, img):
        # Decode raw file bytes into an HxWx3 BGR array.
        data = np.frombuffer(img, dtype='uint8')
        img = cv2.imdecode(data, 1)
        if self.to_rgb:
            # Wrap the shape in a one-element tuple so that %s receives the
            # whole shape instead of unpacking it as three arguments.
            assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
                img.shape, )
            # Reverse the channel axis: BGR to RGB.
            img = img[:, :, ::-1]

        return img


class ResizeImage(object):
    def __init__(self, resize_short=None, interpolation=1):
        self.resize_short = resize_short
        self.interpolation = interpolation

    def __call__(self, img):
        img_h, img_w = img.shape[:2]
        percent = float(self.resize_short) / min(img_w, img_h)
        w = int(round(img_w * percent))
        h = int(round(img_h * percent))
        return cv2.resize(img, (w, h), interpolation=self.interpolation)


class CropImage(object):
    def __init__(self, size):
        if type(size) is int:
            self.size = (size, size)
        else:
            self.size = size

    def __call__(self, img):
        w, h = self.size
        img_h, img_w = img.shape[:2]
        w_start = (img_w - w) // 2
        h_start = (img_h - h) // 2

        w_end = w_start + w
        h_end = h_start + h
        return img[h_start:h_end, w_start:w_end, :]


class NormalizeImage(object):
    def __init__(self, scale=None, mean=None, std=None):
        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
        mean = mean if mean is not None else [0.485, 0.456, 0.406]
        std = std if std is not None else [0.229, 0.224, 0.225]

        shape = (1, 1, 3)
        self.mean = np.array(mean).reshape(shape).astype('float32')
        self.std = np.array(std).reshape(shape).astype('float32')

    def __call__(self, img):
        return (img.astype('float32') * self.scale - self.mean) / self.std


class ToTensor(object):
    def __init__(self):
        pass

    def __call__(self, img):
        img = img.transpose((2, 0, 1))
        return img
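
The geometry computed by ResizeImage and CropImage can be checked without OpenCV. The sketch below repeats only their arithmetic for a hypothetical 500x375 input; `resize_short_size` and `center_crop_box` are illustrative names that do not exist in utils.py.

```python
def resize_short_size(img_w, img_h, resize_short=256):
    # Scale so that the shorter side equals resize_short (as ResizeImage does).
    percent = float(resize_short) / min(img_w, img_h)
    return int(round(img_w * percent)), int(round(img_h * percent))


def center_crop_box(img_w, img_h, size=224):
    # Corners (left, top, right, bottom) of the centered crop (as CropImage does).
    w_start = (img_w - size) // 2
    h_start = (img_h - size) // 2
    return w_start, h_start, w_start + size, h_start + size


w, h = resize_short_size(500, 375)
box = center_crop_box(w, h)
print((w, h), box)
```

With the defaults used by create_operators in fm_vis.py (resize_short=256, crop size 224), the network always receives a 224x224 crop taken from the middle of the resized image.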