提交 d6b920b4 编写于 作者: G gaotingquan

fix: adapt to release 2.3

上级 ea5b7254
...@@ -6,43 +6,45 @@ ...@@ -6,43 +6,45 @@
## 二、准备工作 ## 二、准备工作
首先需要选定研究的模型,本文设定ResNet50作为研究模型,将resnet.py从[模型库](../../../ppcls/arch/architecture/)拷贝到当前目录下,并下载预训练模型[预训练模型](../../zh_CN/models/models_intro), 复制resnet50的模型链接,使用下列命令下载并解压预训练模型 首先需要选定研究的模型,本文设定ResNet50作为研究模型,将模型组网代码[resnet.py](../../../ppcls/arch/backbone/legendary_models/resnet.py)拷贝到[目录](../../../ppcls/utils/feature_maps_visualization/)下,并下载[ResNet50预训练模型](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams),或使用以下命令下载
```bash ```bash
wget The Link for Pretrained Model wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams
tar -xf Downloaded Pretrained Model
``` ```
以resnet50为例: 其他模型网络结构代码及预训练模型请自行下载:[模型库](../../../ppcls/arch/backbone/)[预训练模型](../models/models_intro.md)
```bash
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
tar -xf ResNet50_pretrained.tar
```
## 三、修改模型 ## 三、修改模型
找到我们所需要的特征图位置,设置self.fm将其fetch出来,本文以resnet50中的stem层之后的特征图为例。 找到我们所需要的特征图位置,设置self.fm将其fetch出来,本文以resnet50中的stem层之后的特征图为例。
fm_vis.py中修改模型的名字。 ResNet50的forward函数中指定要可视化的特征图
在ResNet50的__init__函数中定义self.fm
```python ```python
self.fm = None def forward(self, x):
with paddle.static.amp.fp16_guard():
if self.data_format == "NHWC":
x = paddle.transpose(x, [0, 2, 3, 1])
x.stop_gradient = True
x = self.stem(x)
fm = x
x = self.max_pool(x)
x = self.blocks(x)
x = self.avg_pool(x)
x = self.flatten(x)
x = self.fc(x)
return x, fm
``` ```
在ResNet50的forward函数中指定特征图
然后修改代码[fm_vis.py](../../../ppcls/utils/feature_maps_visualization/fm_vis.py),引入 `ResNet50`,实例化 `net` 对象:
```python ```python
def forward(self, inputs): from resnet import ResNet50
y = self.conv(inputs) net = ResNet50()
self.fm = y
y = self.pool2d_max(y)
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
y = self.avg_pool(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
y = self.out(y)
return y, self.fm
``` ```
执行函数
最后执行函数
```bash ```bash
python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \
-c channel_num -p pretrained model \ -c channel_num -p pretrained model \
...@@ -51,9 +53,10 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test ...@@ -51,9 +53,10 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
--save_path where to save \ --save_path where to save \
--use_gpu whether to use gpu --use_gpu whether to use gpu
``` ```
参数说明: 参数说明:
+ `-i`:待预测的图片文件路径,如 `./test.jpeg` + `-i`:待预测的图片文件路径,如 `./test.jpeg`
+ `-c`:特征图维度,如 `./resnet50_vd/model` + `-c`:特征图维度,如 `5`
+ `-p`:权重文件路径,如 `./ResNet50_pretrained/` + `-p`:权重文件路径,如 `./ResNet50_pretrained/`
+ `--interpolation`: 图像插值方式, 默认值 1 + `--interpolation`: 图像插值方式, 默认值 1
+ `--save_path`:保存路径,如:`./tools/` + `--save_path`:保存路径,如:`./tools/`
...@@ -63,7 +66,7 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test ...@@ -63,7 +66,7 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
* 输入图片: * 输入图片:
![](../../../docs/images/feature_maps/feature_visualization_input.jpg) ![](../../images/feature_maps/feature_visualization_input.jpg)
* 运行下面的特征图可视化脚本 * 运行下面的特征图可视化脚本
...@@ -75,10 +78,9 @@ python tools/feature_maps_visualization/fm_vis.py \ ...@@ -75,10 +78,9 @@ python tools/feature_maps_visualization/fm_vis.py \
--show=True \ --show=True \
--interpolation=1 \ --interpolation=1 \
--save_path="./output.png" \ --save_path="./output.png" \
--use_gpu=False \ --use_gpu=False
--load_static_weights=True
``` ```
* 输出特征图保存为`output.png`,如下所示。 * 输出特征图保存为`output.png`,如下所示。
![](../../../docs/images/feature_maps/feature_visualization_output.jpg) ![](../../images/feature_maps/feature_visualization_output.jpg)
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
tar -xf ResNet50_pretrained.tar
\ No newline at end of file
...@@ -19,7 +19,7 @@ import os ...@@ -19,7 +19,7 @@ import os
import sys import sys
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../../..')))
import paddle import paddle
from paddle.distributed import ParallelEnv from paddle.distributed import ParallelEnv
...@@ -33,18 +33,13 @@ def parse_args(): ...@@ -33,18 +33,13 @@ def parse_args():
return v.lower() in ("true", "t", "1") return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image_file", type=str) parser.add_argument("-i", "--image_file", required=True, type=str)
parser.add_argument("-c", "--channel_num", type=int) parser.add_argument("-c", "--channel_num", type=int)
parser.add_argument("-p", "--pretrained_model", type=str) parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("--show", type=str2bool, default=False) parser.add_argument("--show", type=str2bool, default=False)
parser.add_argument("--interpolation", type=int, default=1) parser.add_argument("--interpolation", type=int, default=1)
parser.add_argument("--save_path", type=str, default=None) parser.add_argument("--save_path", type=str, default=None)
parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument(
"--load_static_weights",
type=str2bool,
default=False,
help='Whether to load the pretrained weights saved in static mode')
return parser.parse_args() return parser.parse_args()
...@@ -79,7 +74,7 @@ def main(): ...@@ -79,7 +74,7 @@ def main():
place = paddle.set_device(place) place = paddle.set_device(place)
net = ResNet50() net = ResNet50()
load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights) load_dygraph_pretrain(net, args.pretrained_model)
img = cv2.imread(args.image_file, cv2.IMREAD_COLOR) img = cv2.imread(args.image_file, cv2.IMREAD_COLOR)
data = preprocess(img, operators) data = preprocess(img, operators)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册