提交 8ac51f74 编写于 作者: G gaotingquan

fix: adapt to release 2.3

上级 677f6aea
...@@ -6,43 +6,45 @@ ...@@ -6,43 +6,45 @@
## 二、准备工作 ## 二、准备工作
首先需要选定研究的模型,本文设定ResNet50作为研究模型,将resnet.py从[模型库](../../../ppcls/arch/architecture/)拷贝到当前目录下,并下载预训练模型[预训练模型](../../zh_CN/models/models_intro), 复制resnet50的模型链接,使用下列命令下载并解压预训练模型 首先需要选定研究的模型,本文设定ResNet50作为研究模型,将模型组网代码[resnet.py](../../../ppcls/arch/backbone/legendary_models/resnet.py)拷贝到[目录](../../../ppcls/utils/feature_maps_visualization/)下,并下载[ResNet50预训练模型](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams),或使用以下命令下载
```bash ```bash
wget The Link for Pretrained Model wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_pretrained.pdparams
tar -xf Downloaded Pretrained Model
``` ```
以resnet50为例: 其他模型网络结构代码及预训练模型请自行下载:[模型库](../../../ppcls/arch/backbone/)[预训练模型](../models/models_intro.md)
```bash
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
tar -xf ResNet50_pretrained.tar
```
## 三、修改模型 ## 三、修改模型
找到我们所需要的特征图位置,设置self.fm将其fetch出来,本文以resnet50中的stem层之后的特征图为例。 找到我们所需要的特征图位置,设置self.fm将其fetch出来,本文以resnet50中的stem层之后的特征图为例。
fm_vis.py中修改模型的名字。 ResNet50的forward函数中指定要可视化的特征图
在ResNet50的__init__函数中定义self.fm
```python ```python
self.fm = None def forward(self, x):
with paddle.static.amp.fp16_guard():
if self.data_format == "NHWC":
x = paddle.transpose(x, [0, 2, 3, 1])
x.stop_gradient = True
x = self.stem(x)
fm = x
x = self.max_pool(x)
x = self.blocks(x)
x = self.avg_pool(x)
x = self.flatten(x)
x = self.fc(x)
return x, fm
``` ```
在ResNet50的forward函数中指定特征图
然后修改代码[fm_vis.py](../../../ppcls/utils/feature_maps_visualization/fm_vis.py),引入 `ResNet50`,实例化 `net` 对象:
```python ```python
def forward(self, inputs): from resnet import ResNet50
y = self.conv(inputs) net = ResNet50()
self.fm = y
y = self.pool2d_max(y)
for bottleneck_block in self.bottleneck_block_list:
y = bottleneck_block(y)
y = self.avg_pool(y)
y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
y = self.out(y)
return y, self.fm
``` ```
执行函数
最后执行函数
```bash ```bash
python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test \
-c channel_num -p pretrained model \ -c channel_num -p pretrained model \
...@@ -51,9 +53,10 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test ...@@ -51,9 +53,10 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
--save_path where to save \ --save_path where to save \
--use_gpu whether to use gpu --use_gpu whether to use gpu
``` ```
参数说明: 参数说明:
+ `-i`:待预测的图片文件路径,如 `./test.jpeg` + `-i`:待预测的图片文件路径,如 `./test.jpeg`
+ `-c`:特征图维度,如 `./resnet50_vd/model` + `-c`:特征图维度,如 `5`
+ `-p`:权重文件路径,如 `./ResNet50_pretrained/` + `-p`:权重文件路径,如 `./ResNet50_pretrained/`
+ `--interpolation`: 图像插值方式, 默认值 1 + `--interpolation`: 图像插值方式, 默认值 1
+ `--save_path`:保存路径,如:`./tools/` + `--save_path`:保存路径,如:`./tools/`
...@@ -63,7 +66,7 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test ...@@ -63,7 +66,7 @@ python tools/feature_maps_visualization/fm_vis.py -i the image you want to test
* 输入图片: * 输入图片:
![](../../../docs/images/feature_maps/feature_visualization_input.jpg) ![](../../images/feature_maps/feature_visualization_input.jpg)
* 运行下面的特征图可视化脚本 * 运行下面的特征图可视化脚本
...@@ -75,10 +78,9 @@ python tools/feature_maps_visualization/fm_vis.py \ ...@@ -75,10 +78,9 @@ python tools/feature_maps_visualization/fm_vis.py \
--show=True \ --show=True \
--interpolation=1 \ --interpolation=1 \
--save_path="./output.png" \ --save_path="./output.png" \
--use_gpu=False \ --use_gpu=False
--load_static_weights=True
``` ```
* 输出特征图保存为`output.png`,如下所示。 * 输出特征图保存为`output.png`,如下所示。
![](../../../docs/images/feature_maps/feature_visualization_output.jpg) ![](../../images/feature_maps/feature_visualization_output.jpg)
wget https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_pretrained.tar
tar -xf ResNet50_pretrained.tar
\ No newline at end of file
...@@ -19,7 +19,7 @@ import os ...@@ -19,7 +19,7 @@ import os
import sys import sys
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '../../..')))
import paddle import paddle
from paddle.distributed import ParallelEnv from paddle.distributed import ParallelEnv
...@@ -33,18 +33,13 @@ def parse_args(): ...@@ -33,18 +33,13 @@ def parse_args():
return v.lower() in ("true", "t", "1") return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image_file", type=str) parser.add_argument("-i", "--image_file", required=True, type=str)
parser.add_argument("-c", "--channel_num", type=int) parser.add_argument("-c", "--channel_num", type=int)
parser.add_argument("-p", "--pretrained_model", type=str) parser.add_argument("-p", "--pretrained_model", type=str)
parser.add_argument("--show", type=str2bool, default=False) parser.add_argument("--show", type=str2bool, default=False)
parser.add_argument("--interpolation", type=int, default=1) parser.add_argument("--interpolation", type=int, default=1)
parser.add_argument("--save_path", type=str, default=None) parser.add_argument("--save_path", type=str, default=None)
parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument(
"--load_static_weights",
type=str2bool,
default=False,
help='Whether to load the pretrained weights saved in static mode')
return parser.parse_args() return parser.parse_args()
...@@ -79,7 +74,7 @@ def main(): ...@@ -79,7 +74,7 @@ def main():
place = paddle.set_device(place) place = paddle.set_device(place)
net = ResNet50() net = ResNet50()
load_dygraph_pretrain(net, args.pretrained_model, args.load_static_weights) load_dygraph_pretrain(net, args.pretrained_model)
img = cv2.imread(args.image_file, cv2.IMREAD_COLOR) img = cv2.imread(args.image_file, cv2.IMREAD_COLOR)
data = preprocess(img, operators) data = preprocess(img, operators)
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,126 +12,204 @@ ...@@ -12,126 +12,204 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import numpy as np import numpy as np
import paddle import paddle
from paddle import ParamAttr from paddle import ParamAttr
import paddle.nn as nn import paddle.nn as nn
import paddle.nn.functional as F from paddle.nn import Conv2D, BatchNorm, Linear
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import Uniform from paddle.nn.initializer import Uniform
import math import math
__all__ = ["ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"] from ppcls.arch.backbone.base.theseus_layer import TheseusLayer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
class ConvBNLayer(nn.Layer): MODEL_URLS = {
"ResNet18":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_pretrained.pdparams",
"ResNet18_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet18_vd_pretrained.pdparams",
"ResNet34":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_pretrained.pdparams",
"ResNet34_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet34_vd_pretrained.pdparams",
"ResNet50":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_pretrained.pdparams",
"ResNet50_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet50_vd_pretrained.pdparams",
"ResNet101":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_pretrained.pdparams",
"ResNet101_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet101_vd_pretrained.pdparams",
"ResNet152":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_pretrained.pdparams",
"ResNet152_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet152_vd_pretrained.pdparams",
"ResNet200_vd":
"https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/legendary_models/ResNet200_vd_pretrained.pdparams",
}
__all__ = MODEL_URLS.keys()
'''
ResNet config: dict.
key: depth of ResNet.
values: config's dict of specific model.
keys:
block_type: Two different blocks in ResNet, BasicBlock and BottleneckBlock are optional.
block_depth: The number of blocks in different stages in ResNet.
num_channels: The number of channels to enter the next stage.
'''
NET_CONFIG = {
"18": {
"block_type": "BasicBlock",
"block_depth": [2, 2, 2, 2],
"num_channels": [64, 64, 128, 256]
},
"34": {
"block_type": "BasicBlock",
"block_depth": [3, 4, 6, 3],
"num_channels": [64, 64, 128, 256]
},
"50": {
"block_type": "BottleneckBlock",
"block_depth": [3, 4, 6, 3],
"num_channels": [64, 256, 512, 1024]
},
"101": {
"block_type": "BottleneckBlock",
"block_depth": [3, 4, 23, 3],
"num_channels": [64, 256, 512, 1024]
},
"152": {
"block_type": "BottleneckBlock",
"block_depth": [3, 8, 36, 3],
"num_channels": [64, 256, 512, 1024]
},
"200": {
"block_type": "BottleneckBlock",
"block_depth": [3, 12, 48, 3],
"num_channels": [64, 256, 512, 1024]
},
}
class ConvBNLayer(TheseusLayer):
def __init__(self, def __init__(self,
num_channels, num_channels,
num_filters, num_filters,
filter_size, filter_size,
stride=1, stride=1,
groups=1, groups=1,
is_vd_mode=False,
act=None, act=None,
name=None): lr_mult=1.0,
super(ConvBNLayer, self).__init__() data_format="NCHW"):
super().__init__()
self._conv = Conv2D( self.is_vd_mode = is_vd_mode
self.act = act
self.avg_pool = AvgPool2D(
kernel_size=2, stride=2, padding=0, ceil_mode=True)
self.conv = Conv2D(
in_channels=num_channels, in_channels=num_channels,
out_channels=num_filters, out_channels=num_filters,
kernel_size=filter_size, kernel_size=filter_size,
stride=stride, stride=stride,
padding=(filter_size - 1) // 2, padding=(filter_size - 1) // 2,
groups=groups, groups=groups,
weight_attr=ParamAttr(name=name + "_weights"), weight_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=False) bias_attr=False,
if name == "conv1": data_format=data_format)
bn_name = "bn_" + name self.bn = BatchNorm(
else:
bn_name = "bn" + name[3:]
self._batch_norm = BatchNorm(
num_filters, num_filters,
act=act, param_attr=ParamAttr(learning_rate=lr_mult),
param_attr=ParamAttr(name=bn_name + "_scale"), bias_attr=ParamAttr(learning_rate=lr_mult),
bias_attr=ParamAttr(bn_name + "_offset"), data_layout=data_format)
moving_mean_name=bn_name + "_mean", self.relu = nn.ReLU()
moving_variance_name=bn_name + "_variance")
def forward(self, x):
def forward(self, inputs): if self.is_vd_mode:
y = self._conv(inputs) x = self.avg_pool(x)
y = self._batch_norm(y) x = self.conv(x)
return y x = self.bn(x)
if self.act:
x = self.relu(x)
class BottleneckBlock(nn.Layer): return x
class BottleneckBlock(TheseusLayer):
def __init__(self, def __init__(self,
num_channels, num_channels,
num_filters, num_filters,
stride, stride,
shortcut=True, shortcut=True,
name=None): if_first=False,
super(BottleneckBlock, self).__init__() lr_mult=1.0,
data_format="NCHW"):
super().__init__()
self.conv0 = ConvBNLayer( self.conv0 = ConvBNLayer(
num_channels=num_channels, num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=1, filter_size=1,
act="relu", act="relu",
name=name + "_branch2a") lr_mult=lr_mult,
data_format=data_format)
self.conv1 = ConvBNLayer( self.conv1 = ConvBNLayer(
num_channels=num_filters, num_channels=num_filters,
num_filters=num_filters, num_filters=num_filters,
filter_size=3, filter_size=3,
stride=stride, stride=stride,
act="relu", act="relu",
name=name + "_branch2b") lr_mult=lr_mult,
data_format=data_format)
self.conv2 = ConvBNLayer( self.conv2 = ConvBNLayer(
num_channels=num_filters, num_channels=num_filters,
num_filters=num_filters * 4, num_filters=num_filters * 4,
filter_size=1, filter_size=1,
act=None, act=None,
name=name + "_branch2c") lr_mult=lr_mult,
data_format=data_format)
if not shortcut: if not shortcut:
self.short = ConvBNLayer( self.short = ConvBNLayer(
num_channels=num_channels, num_channels=num_channels,
num_filters=num_filters * 4, num_filters=num_filters * 4,
filter_size=1, filter_size=1,
stride=stride, stride=stride if if_first else 1,
name=name + "_branch1") is_vd_mode=False if if_first else True,
lr_mult=lr_mult,
data_format=data_format)
self.relu = nn.ReLU()
self.shortcut = shortcut self.shortcut = shortcut
self._num_channels_out = num_filters * 4 def forward(self, x):
identity = x
def forward(self, inputs): x = self.conv0(x)
y = self.conv0(inputs) x = self.conv1(x)
conv1 = self.conv1(y) x = self.conv2(x)
conv2 = self.conv2(conv1)
if self.shortcut: if self.shortcut:
short = inputs short = identity
else: else:
short = self.short(inputs) short = self.short(identity)
x = paddle.add(x=x, y=short)
y = paddle.add(x=short, y=conv2) x = self.relu(x)
y = F.relu(y) return x
return y
class BasicBlock(nn.Layer): class BasicBlock(TheseusLayer):
def __init__(self, def __init__(self,
num_channels, num_channels,
num_filters, num_filters,
stride, stride,
shortcut=True, shortcut=True,
name=None): if_first=False,
super(BasicBlock, self).__init__() lr_mult=1.0,
data_format="NCHW"):
super().__init__()
self.stride = stride self.stride = stride
self.conv0 = ConvBNLayer( self.conv0 = ConvBNLayer(
num_channels=num_channels, num_channels=num_channels,
...@@ -139,155 +217,319 @@ class BasicBlock(nn.Layer): ...@@ -139,155 +217,319 @@ class BasicBlock(nn.Layer):
filter_size=3, filter_size=3,
stride=stride, stride=stride,
act="relu", act="relu",
name=name + "_branch2a") lr_mult=lr_mult,
data_format=data_format)
self.conv1 = ConvBNLayer( self.conv1 = ConvBNLayer(
num_channels=num_filters, num_channels=num_filters,
num_filters=num_filters, num_filters=num_filters,
filter_size=3, filter_size=3,
act=None, act=None,
name=name + "_branch2b") lr_mult=lr_mult,
data_format=data_format)
if not shortcut: if not shortcut:
self.short = ConvBNLayer( self.short = ConvBNLayer(
num_channels=num_channels, num_channels=num_channels,
num_filters=num_filters, num_filters=num_filters,
filter_size=1, filter_size=1,
stride=stride, stride=stride if if_first else 1,
name=name + "_branch1") is_vd_mode=False if if_first else True,
lr_mult=lr_mult,
data_format=data_format)
self.shortcut = shortcut self.shortcut = shortcut
self.relu = nn.ReLU()
def forward(self, inputs): def forward(self, x):
y = self.conv0(inputs) identity = x
conv1 = self.conv1(y) x = self.conv0(x)
x = self.conv1(x)
if self.shortcut: if self.shortcut:
short = inputs short = identity
else: else:
short = self.short(inputs) short = self.short(identity)
y = paddle.add(x=short, y=conv1) x = paddle.add(x=x, y=short)
y = F.relu(y) x = self.relu(x)
return y return x
class ResNet(nn.Layer): class ResNet(TheseusLayer):
def __init__(self, layers=50, class_dim=1000): """
super(ResNet, self).__init__() ResNet
Args:
self.layers = layers config: dict. config of ResNet.
supported_layers = [18, 34, 50, 101, 152] version: str="vb". Different version of ResNet, version vd can perform better.
assert layers in supported_layers, \ class_num: int=1000. The number of classes.
"supported layers are {} but input layer is {}".format( lr_mult_list: list. Control the learning rate of different stages.
supported_layers, layers) Returns:
model: nn.Layer. Specific ResNet model depends on args.
if layers == 18: """
depth = [2, 2, 2, 2]
elif layers == 34 or layers == 50:
depth = [3, 4, 6, 3]
elif layers == 101:
depth = [3, 4, 23, 3]
elif layers == 152:
depth = [3, 8, 36, 3]
num_channels = [64, 256, 512,
1024] if layers >= 50 else [64, 64, 128, 256]
num_filters = [64, 128, 256, 512]
self.feature_map = None
self.conv = ConvBNLayer(
num_channels=3,
num_filters=64,
filter_size=7,
stride=2,
act="relu",
name="conv1")
self.pool2d_max = MaxPool2D(kernel_size=3, stride=2, padding=1)
self.block_list = [] def __init__(self,
if layers >= 50: config,
for block in range(len(depth)): version="vb",
class_num=1000,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
data_format="NCHW",
input_image_channel=3,
return_patterns=None):
super().__init__()
self.cfg = config
self.lr_mult_list = lr_mult_list
self.is_vd_mode = version == "vd"
self.class_num = class_num
self.num_filters = [64, 128, 256, 512]
self.block_depth = self.cfg["block_depth"]
self.block_type = self.cfg["block_type"]
self.num_channels = self.cfg["num_channels"]
self.channels_mult = 1 if self.num_channels[-1] == 256 else 4
assert isinstance(self.lr_mult_list, (
list, tuple
)), "lr_mult_list should be in (list, tuple) but got {}".format(
type(self.lr_mult_list))
assert len(self.lr_mult_list
) == 5, "lr_mult_list length should be 5 but got {}".format(
len(self.lr_mult_list))
self.stem_cfg = {
#num_channels, num_filters, filter_size, stride
"vb": [[input_image_channel, 64, 7, 2]],
"vd":
[[input_image_channel, 32, 3, 2], [32, 32, 3, 1], [32, 64, 3, 1]]
}
self.stem = nn.Sequential(* [
ConvBNLayer(
num_channels=in_c,
num_filters=out_c,
filter_size=k,
stride=s,
act="relu",
lr_mult=self.lr_mult_list[0],
data_format=data_format)
for in_c, out_c, k, s in self.stem_cfg[version]
])
self.max_pool = MaxPool2D(
kernel_size=3, stride=2, padding=1, data_format=data_format)
block_list = []
for block_idx in range(len(self.block_depth)):
shortcut = False shortcut = False
for i in range(depth[block]): for i in range(self.block_depth[block_idx]):
if layers in [101, 152] and block == 2: block_list.append(globals()[self.block_type](
if i == 0: num_channels=self.num_channels[block_idx] if i == 0 else
conv_name = "res" + str(block + 2) + "a" self.num_filters[block_idx] * self.channels_mult,
else: num_filters=self.num_filters[block_idx],
conv_name = "res" + str(block + 2) + "b" + str(i) stride=2 if i == 0 and block_idx != 0 else 1,
else:
conv_name = "res" + str(block + 2) + chr(97 + i)
bottleneck_block = self.add_sublayer(
conv_name,
BottleneckBlock(
num_channels=num_channels[block]
if i == 0 else num_filters[block] * 4,
num_filters=num_filters[block],
stride=2 if i == 0 and block != 0 else 1,
shortcut=shortcut, shortcut=shortcut,
name=conv_name)) if_first=block_idx == i == 0 if version == "vd" else True,
self.block_list.append(bottleneck_block) lr_mult=self.lr_mult_list[block_idx + 1],
data_format=data_format))
shortcut = True shortcut = True
self.blocks = nn.Sequential(*block_list)
self.avg_pool = AdaptiveAvgPool2D(1, data_format=data_format)
self.flatten = nn.Flatten()
self.avg_pool_channels = self.num_channels[-1] * 2
stdv = 1.0 / math.sqrt(self.avg_pool_channels * 1.0)
self.fc = Linear(
self.avg_pool_channels,
self.class_num,
weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
self.data_format = data_format
if return_patterns is not None:
self.update_res(return_patterns)
self.register_forward_post_hook(self._return_dict_hook)
def forward(self, x):
with paddle.static.amp.fp16_guard():
if self.data_format == "NHWC":
x = paddle.transpose(x, [0, 2, 3, 1])
x.stop_gradient = True
x = self.stem(x)
fm = x
x = self.max_pool(x)
x = self.blocks(x)
x = self.avg_pool(x)
x = self.flatten(x)
x = self.fc(x)
return x, fm
def _load_pretrained(pretrained, model, model_url, use_ssld):
if pretrained is False:
pass
elif pretrained is True:
load_dygraph_pretrain_from_url(model, model_url, use_ssld=use_ssld)
elif isinstance(pretrained, str):
load_dygraph_pretrain(model, pretrained)
else: else:
for block in range(len(depth)): raise RuntimeError(
shortcut = False "pretrained type is not available. Please use `string` or `boolean` type."
for i in range(depth[block]): )
conv_name = "res" + str(block + 2) + chr(97 + i)
basic_block = self.add_sublayer(
conv_name, def ResNet18(pretrained=False, use_ssld=False, **kwargs):
BasicBlock( """
num_channels=num_channels[block] ResNet18
if i == 0 else num_filters[block], Args:
num_filters=num_filters[block], pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
stride=2 if i == 0 and block != 0 else 1, If str, means the path of the pretrained model.
shortcut=shortcut, use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
name=conv_name)) Returns:
self.block_list.append(basic_block) model: nn.Layer. Specific `ResNet18` model depends on args.
shortcut = True """
model = ResNet(config=NET_CONFIG["18"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet18"], use_ssld)
return model
def ResNet18_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet18_vd
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet18_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["18"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet18_vd"], use_ssld)
return model
self.pool2d_avg = AdaptiveAvgPool2D(1)
self.pool2d_avg_channels = num_channels[-1] * 2 def ResNet34(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet34
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet34` model depends on args.
"""
model = ResNet(config=NET_CONFIG["34"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet34"], use_ssld)
return model
stdv = 1.0 / math.sqrt(self.pool2d_avg_channels * 1.0)
self.out = Linear( def ResNet34_vd(pretrained=False, use_ssld=False, **kwargs):
self.pool2d_avg_channels, """
class_dim, ResNet34_vd
weight_attr=ParamAttr( Args:
initializer=Uniform(-stdv, stdv), name="fc_0.w_0"), pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
bias_attr=ParamAttr(name="fc_0.b_0")) If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet34_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["34"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet34_vd"], use_ssld)
return model
def forward(self, inputs): def ResNet50(pretrained=False, use_ssld=False, **kwargs):
y = self.conv(inputs) """
y = self.pool2d_max(y) ResNet50
self.feature_map = y Args:
for block in self.block_list: pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
y = block(y) If str, means the path of the pretrained model.
y = self.pool2d_avg(y) use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
y = paddle.reshape(y, shape=[-1, self.pool2d_avg_channels]) Returns:
y = self.out(y) model: nn.Layer. Specific `ResNet50` model depends on args.
return y, self.feature_map """
model = ResNet(config=NET_CONFIG["50"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50"], use_ssld)
return model
def ResNet50_vd(pretrained=False, use_ssld=False, **kwargs):
"""
ResNet50_vd
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet50_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["50"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet50_vd"], use_ssld)
return model
def ResNet18(**args): def ResNet101(pretrained=False, use_ssld=False, **kwargs):
model = ResNet(layers=18, **args) """
ResNet101
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet101` model depends on args.
"""
model = ResNet(config=NET_CONFIG["101"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet101"], use_ssld)
return model return model
def ResNet34(**args): def ResNet101_vd(pretrained=False, use_ssld=False, **kwargs):
model = ResNet(layers=34, **args) """
ResNet101_vd
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet101_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["101"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet101_vd"], use_ssld)
return model return model
def ResNet50(**args): def ResNet152(pretrained=False, use_ssld=False, **kwargs):
model = ResNet(layers=50, **args) """
ResNet152
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet152` model depends on args.
"""
model = ResNet(config=NET_CONFIG["152"], version="vb", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet152"], use_ssld)
return model return model
def ResNet101(**args): def ResNet152_vd(pretrained=False, use_ssld=False, **kwargs):
model = ResNet(layers=101, **args) """
ResNet152_vd
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet152_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["152"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet152_vd"], use_ssld)
return model return model
def ResNet152(**args): def ResNet200_vd(pretrained=False, use_ssld=False, **kwargs):
model = ResNet(layers=152, **args) """
ResNet200_vd
Args:
pretrained: bool=False or str. If `True` load pretrained parameters, `False` otherwise.
If str, means the path of the pretrained model.
use_ssld: bool=False. Whether using distillation pretrained model when pretrained=True.
Returns:
model: nn.Layer. Specific `ResNet200_vd` model depends on args.
"""
model = ResNet(config=NET_CONFIG["200"], version="vd", **kwargs)
_load_pretrained(pretrained, model, MODEL_URLS["ResNet200_vd"], use_ssld)
return model return model
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册