Unverified commit fbd18e99, authored by jm_12138, committed by GitHub

update format (#2101)

Parent 350407c7
# ginet_resnet50vd_voc
|Module Name|ginet_resnet50vd_voc|
| :--- | :---: |
|Category|Image-Image Segmentation|
|Network|ginet_resnet50vd|
|Dataset|PascalVOC2012|
@@ -11,7 +11,7 @@
|Latest update date|2021-12-14|
## I. Basic Information
- Sample results:
<p align="center">
<img src="https://user-images.githubusercontent.com/35907364/145925887-bf9e62d3-8c6d-43c2-8062-6cb6ba59ec0e.jpg" width = "420" height = "505" hspace='10'/> <img src="https://user-images.githubusercontent.com/35907364/145925692-badb21d1-10e7-4a5d-82f5-1177d10a7681.png" width = "420" height = "505" hspace='10'/>
@@ -37,7 +37,7 @@
```
- If you encounter problems during installation, please refer to: [Windows quickstart](../../../../docs/docs_ch/get_start/windows_quickstart.md)
| [Linux quickstart](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [MacOS quickstart](../../../../docs/docs_ch/get_start/mac_quickstart.md)
## III. Module API Prediction
@@ -104,7 +104,7 @@
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_seg', use_gpu=True)
trainer.train(train_reader, epochs=10, batch_size=4, log_interval=10, save_interval=4)
```
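For context, a minimal sketch of the full fine-tuning pipeline around the `Trainer` call above, assuming PaddleHub's standard segmentation workflow (`OpticDiscSeg` is the sample dataset from the PaddleHub tutorials; the transform pipeline is illustrative, not part of this diff):

```python
import paddle
import paddlehub as hub
import paddlehub.vision.segmentation_transforms as T
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets import OpticDiscSeg

# Illustrative preprocessing: resize to a fixed size, then normalize.
transform = T.Compose([T.Resize(target_size=(512, 512)), T.Normalize()])

# Sample dataset assumed from the PaddleHub docs; swap in your own segmentation dataset.
train_reader = OpticDiscSeg(transform, mode='train')

model = hub.Module(name='ginet_resnet50vd_voc', num_classes=2, pretrained=None)
optimizer = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_seg', use_gpu=True)
trainer.train(train_reader, epochs=10, batch_size=4, log_interval=10, save_interval=4)
```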
- Model prediction
......
# ginet_resnet50vd_voc
|Module Name|ginet_resnet50vd_voc|
| :--- | :---: |
|Category|Image Segmentation|
|Network|ginet_resnet50vd|
|Dataset|PascalVOC2012|
@@ -10,8 +10,8 @@
|Data indicators|-|
|Latest update date|2021-12-14|
## I. Basic Information
- ### Application Effect Display
- Sample results:
<p align="center">
@@ -108,7 +108,7 @@
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_seg', use_gpu=True)
trainer.train(train_reader, epochs=10, batch_size=4, log_interval=10, save_interval=4)
```
- Model prediction
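The prediction snippet itself falls in the elided context; a minimal sketch, assuming the standard `ImageSegmentationModule.predict` interface (the image path is a placeholder):

```python
import cv2
import paddlehub as hub

model = hub.Module(name='ginet_resnet50vd_voc')
img = cv2.imread('/PATH/TO/IMAGE')  # placeholder path
# visualization=True writes the colorized mask to disk; results come back as numpy arrays.
results = model.predict(images=[img], visualization=True)
```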
......
@@ -11,12 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import AvgPool2D
from paddle.nn import Conv2D
from paddle.nn.layer import activation
def SyncBatchNorm(*args, **kwargs):
@@ -30,31 +30,28 @@ def SyncBatchNorm(*args, **kwargs):
class ConvBNLayer(nn.Layer):
"""Basic conv bn relu layer."""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
is_vd_mode: bool = False,
act: str = None,
name: str = None):
super(ConvBNLayer, self).__init__()
self.is_vd_mode = is_vd_mode
self._pool2d_avg = AvgPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
self._conv = Conv2D(in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=(kernel_size - 1) // 2 if dilation == 1 else 0,
dilation=dilation,
groups=groups,
bias_attr=False)
self._batch_norm = SyncBatchNorm(out_channels)
self._act_op = Activation(act=act)
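# The forward pass (in the elided hunk below) presumably applies, in order: an
# optional 2x2 average pool when is_vd_mode is set (the ResNet-D downsampling
# trick), then self._conv -> self._batch_norm -> self._act_op.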
@@ -82,38 +79,34 @@ class BottleneckBlock(nn.Layer):
name: str = None):
super(BottleneckBlock, self).__init__()
self.conv0 = ConvBNLayer(in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
act='relu',
name=name + "_branch2a")
self.dilation = dilation
self.conv1 = ConvBNLayer(in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
dilation=dilation,
name=name + "_branch2b")
self.conv2 = ConvBNLayer(in_channels=out_channels,
out_channels=out_channels * 4,
kernel_size=1,
act=None,
name=name + "_branch2c")
if not shortcut:
self.short = ConvBNLayer(in_channels=in_channels,
out_channels=out_channels * 4,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first or stride == 1 else True,
name=name + "_branch1")
self.shortcut = shortcut
@@ -139,22 +132,15 @@ class BottleneckBlock(nn.Layer):
class SeparableConvBNReLU(nn.Layer):
"""Depthwise Separable Convolution."""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: str = 'same', **kwargs: dict):
super(SeparableConvBNReLU, self).__init__()
self.depthwise_conv = ConvBN(in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
padding=padding,
groups=in_channels,
**kwargs)
self.piontwise_conv = ConvBNReLU(in_channels, out_channels, kernel_size=1, groups=1)
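# Factorizing a dense KxK conv into a depthwise conv (groups=in_channels) plus a
# 1x1 pointwise conv cuts parameters from K*K*C_in*C_out to K*K*C_in + C_in*C_out:
# for K=3 and C_in=C_out=256, that is 3*3*256 + 256*256 = 67,840 vs 589,824.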
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
x = self.depthwise_conv(x)
@@ -165,15 +151,9 @@ class SeparableConvBNReLU(nn.Layer):
class ConvBN(nn.Layer):
"""Basic conv bn layer"""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: str = 'same', **kwargs: dict):
super(ConvBN, self).__init__()
self._conv = Conv2D(in_channels, out_channels, kernel_size, padding=padding, **kwargs)
self._batch_norm = SyncBatchNorm(out_channels)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
@@ -185,16 +165,10 @@ class ConvBN(nn.Layer):
class ConvBNReLU(nn.Layer):
"""Basic conv bn relu layer."""
def __init__(self, in_channels: int, out_channels: int, kernel_size: int, padding: str = 'same', **kwargs: dict):
super(ConvBNReLU, self).__init__()
self._conv = Conv2D(in_channels, out_channels, kernel_size, padding=padding, **kwargs)
self._batch_norm = SyncBatchNorm(out_channels)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
@@ -251,8 +225,7 @@ class Activation(nn.Layer):
act_name = act_dict[act]
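# act_dict maps a lowercase name such as 'relu' to the class name in
# paddle.nn.layer.activation; eval then instantiates it, e.g. activation.ReLU().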
self.act_func = eval("activation.{}()".format(act_name))
else:
raise KeyError("{} does not exist in the current {}".format(act, act_dict.keys()))
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
@@ -281,7 +254,7 @@ class ASPPModule(nn.Layer):
in_channels: int,
out_channels: int,
align_corners: bool,
use_sep_conv: bool = False,
image_pooling: bool = False):
super().__init__()
@@ -294,27 +267,22 @@ class ASPPModule(nn.Layer):
else:
conv_func = ConvBNReLU
block = conv_func(in_channels=in_channels,
out_channels=out_channels,
kernel_size=1 if ratio == 1 else 3,
dilation=ratio,
padding=0 if ratio == 1 else ratio)
self.aspp_blocks.append(block)
out_size = len(self.aspp_blocks)
if image_pooling:
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2D(output_size=(1, 1)),
ConvBNReLU(in_channels, out_channels, kernel_size=1, bias_attr=False))
out_size += 1
self.image_pooling = image_pooling
self.conv_bn_relu = ConvBNReLU(in_channels=out_channels * out_size, out_channels=out_channels, kernel_size=1)
self.dropout = nn.Dropout(p=0.1) # drop rate
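# Hypothetical instantiation (rates follow DeepLabV3): aspp_ratios=(1, 6, 12, 18)
# builds a 1x1 conv branch plus three 3x3 atrous branches with dilations 6/12/18;
# with image_pooling=True a global-average branch is added, and all branches are
# concatenated and fused by the 1x1 conv_bn_relu above.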
@@ -322,20 +290,12 @@ class ASPPModule(nn.Layer):
outputs = []
for block in self.aspp_blocks:
y = block(x)
y = F.interpolate(y, x.shape[2:], mode='bilinear', align_corners=self.align_corners)
outputs.append(y)
if self.image_pooling:
img_avg = self.global_avg_pool(x)
img_avg = F.interpolate(img_avg, x.shape[2:], mode='bilinear', align_corners=self.align_corners)
outputs.append(img_avg)
x = paddle.concat(outputs, axis=1)
......
@@ -11,31 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List
from typing import Tuple
from typing import Union
import numpy as np
import paddle
import paddle.nn.functional as F
import paddlehub.vision.segmentation_transforms as T
from ginet_resnet50vd_voc.resnet import ResNet50_vd
from paddle import nn
from paddlehub.module.cv_module import ImageSegmentationModule
from paddlehub.module.module import moduleinfo
from paddleseg.models import layers
from paddleseg.utils import utils
@moduleinfo(name="ginet_resnet50vd_voc",
type="CV/semantic_segmentation",
author="paddlepaddle",
author_email="",
summary="GINetResnet50 is a segmentation model.",
version="1.0.0",
meta=ImageSegmentationModule)
class GINetResNet50(nn.Layer):
"""
The GINetResNet50 implementation based on PaddlePaddle.
@@ -55,8 +54,8 @@ class GINetResNet50(nn.Layer):
def __init__(self,
num_classes: int = 21,
backbone_indices: Tuple[int] = (0, 1, 2, 3),
enable_auxiliary_loss: bool = True,
align_corners: bool = True,
jpu: bool = True,
pretrained: str = None):
@@ -74,8 +73,7 @@ class GINetResNet50(nn.Layer):
self.head = GIHead(in_channels=2048, nclass=num_classes)
if self.aux:
self.auxlayer = layers.AuxLayer(1024, 1024 // 4, num_classes, bias_attr=False)
if pretrained is not None:
model_dict = paddle.load(pretrained)
@@ -113,12 +111,7 @@ class GINetResNet50(nn.Layer):
logit_list.append(auxout)
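# Every logit map is upsampled back to the input size (h, w) with bilinear
# interpolation before being returned.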
return [F.interpolate(logit, (h, w), mode='bilinear', align_corners=self.align_corners) for logit in logit_list]
class GIHead(nn.Layer):
@@ -129,30 +122,16 @@ class GIHead(nn.Layer):
self.nclass = nclass
inter_channels = in_channels // 4
self.inp = paddle.zeros(shape=(nclass, 300), dtype='float32')
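# A (nclass, 300) node-feature matrix (one 300-d embedding per class),
# promoted to a learnable parameter below.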
self.inp = paddle.create_parameter(shape=self.inp.shape,
dtype=str(self.inp.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(self.inp))
self.fc1 = nn.Sequential(nn.Linear(300, 128), nn.BatchNorm1D(128), nn.ReLU())
self.fc2 = nn.Sequential(nn.Linear(128, 256), nn.BatchNorm1D(256), nn.ReLU())
self.conv5 = layers.ConvBNReLU(in_channels, inter_channels, 3, padding=1, bias_attr=False, stride=1)
self.gloru = GlobalReasonUnit(in_channels=inter_channels, num_state=256, num_node=84, nclass=nclass)
self.conv6 = nn.Sequential(nn.Dropout(0.1), nn.Conv2D(inter_channels, nclass, 1))
def forward(self, x: paddle.Tensor) -> List[paddle.Tensor]:
B, C, H, W = x.shape
@@ -178,13 +157,10 @@ class GlobalReasonUnit(nn.Layer):
def __init__(self, in_channels: int, num_state: int = 256, num_node: int = 84, nclass: int = 59):
super().__init__()
self.num_state = num_state
self.conv_theta = nn.Conv2D(in_channels, num_node, kernel_size=1, stride=1, padding=0)
self.conv_phi = nn.Conv2D(in_channels, num_state, kernel_size=1, stride=1, padding=0)
self.graph = GraphLayer(num_state, num_node, nclass)
self.extend_dim = nn.Conv2D(num_state, in_channels, kernel_size=1, bias_attr=False)
self.bn = layers.SyncBatchNorm(in_channels)
@@ -199,8 +175,7 @@ class GlobalReasonUnit(nn.Layer):
.transpose((0, 2, 1))
V = paddle.bmm(B, x_reduce).transpose((0, 2, 1))
V = paddle.divide(V, paddle.to_tensor([sizex[2] * sizex[3]], dtype='float32'))
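# Normalize the aggregated node features by the spatial size H*W of the feature map.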
class_node, new_V = self.graph(inp, V)
D = B.reshape((sizeB[0], -1, sizeB[2] * sizeB[3])).transpose((0, 2, 1))
@@ -215,6 +190,7 @@ class GlobalReasonUnit(nn.Layer):
class GraphLayer(nn.Layer):
def __init__(self, num_state: int, num_node: int, num_class: int):
super().__init__()
self.vis_gcn = GCN(num_state, num_node)
@@ -222,14 +198,12 @@ class GraphLayer(nn.Layer):
self.transfer = GraphTransfer(num_state)
self.gamma_vis = paddle.zeros([num_node])
self.gamma_word = paddle.zeros([num_class])
self.gamma_vis = paddle.create_parameter(shape=self.gamma_vis.shape,
dtype=str(self.gamma_vis.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(self.gamma_vis))
self.gamma_word = paddle.create_parameter(shape=self.gamma_word.shape,
dtype=str(self.gamma_word.numpy().dtype),
default_initializer=paddle.nn.initializer.Assign(self.gamma_word))
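# gamma_vis and gamma_word are zero-initialized learnable gates, so the graph
# branches contribute nothing at the start of training and their influence is learned.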
def forward(self, inp: paddle.Tensor, vis_node: paddle.Tensor) -> List[paddle.Tensor]:
inp = self.word_gcn(inp)
@@ -242,6 +216,7 @@ class GraphLayer(nn.Layer):
class GCN(nn.Layer):
def __init__(self, num_state: int = 128, num_node: int = 64, bias: bool = False):
super().__init__()
self.conv1 = nn.Conv1D(
@@ -253,14 +228,7 @@ class GCN(nn.Layer):
groups=1,
)
self.relu = nn.ReLU()
self.conv2 = nn.Conv1D(num_state, num_state, kernel_size=1, padding=0, stride=1, groups=1, bias_attr=bias)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
h = self.conv1(x.transpose((0, 2, 1))).transpose((0, 2, 1))
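# The transposes move the node axis into the Conv1D channel position, so conv1
# propagates information between graph nodes; conv2 then transforms each node's
# state vector.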
@@ -276,14 +244,10 @@ class GraphTransfer(nn.Layer):
def __init__(self, in_dim: int):
super().__init__()
self.channle_in = in_dim
self.query_conv = nn.Conv1D(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1)
self.key_conv = nn.Conv1D(in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1)
self.value_conv_vis = nn.Conv1D(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.value_conv_word = nn.Conv1D(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.softmax_vis = nn.Softmax(axis=-1)
self.softmax_word = nn.Softmax(axis=-2)
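# Softmax over axis=-1 vs axis=-2 yields two complementary attention maps from
# the shared energy matrix, one for each transfer direction (visual-to-word and
# word-to-visual).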
@@ -299,11 +263,9 @@ class GraphTransfer(nn.Layer):
attention_vis = self.softmax_vis(energy).transpose((0, 2, 1))
attention_word = self.softmax_word(energy)
proj_value_vis = self.value_conv_vis(vis_node).reshape((m_batchsize, -1, Nn))
proj_value_word = self.value_conv_word(word).reshape((m_batchsize, -1, Nc))
class_out = paddle.bmm(proj_value_vis, attention_vis)
node_out = paddle.bmm(proj_value_word, attention_word)
return class_out, node_out
@@ -11,14 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ginet_resnet50vd_voc.layers as L
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class BasicBlock(nn.Layer):
def __init__(self,
in_channels: int,
out_channels: int,
@@ -28,28 +27,25 @@ class BasicBlock(nn.Layer):
name: str = None):
super(BasicBlock, self).__init__()
self.stride = stride
self.conv0 = L.ConvBNLayer(in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
act='relu',
name=name + "_branch2a")
self.conv1 = L.ConvBNLayer(in_channels=out_channels,
out_channels=out_channels,
kernel_size=3,
act=None,
name=name + "_branch2b")
if not shortcut:
self.short = L.ConvBNLayer(in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
stride=1,
is_vd_mode=False if if_first else True,
name=name + "_branch1")
self.shortcut = shortcut
@@ -67,35 +63,32 @@ class BasicBlock(nn.Layer):
class ResNet50_vd(nn.Layer):
def __init__(self, multi_grid: tuple = (1, 2, 4)):
super(ResNet50_vd, self).__init__()
depth = [3, 4, 6, 3]
num_channels = [64, 256, 512, 1024]
num_filters = [64, 128, 256, 512]
self.feat_channels = [c * 4 for c in num_filters]
dilation_dict = {2: 2, 3: 4}
self.conv1_1 = L.ConvBNLayer(in_channels=3,
out_channels=32,
kernel_size=3,
stride=2,
act='relu',
name="conv1_1")
self.conv1_2 = L.ConvBNLayer(in_channels=32,
out_channels=32,
kernel_size=3,
stride=1,
act='relu',
name="conv1_2")
self.conv1_3 = L.ConvBNLayer(in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
act='relu',
name="conv1_3")
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
self.stage_list = []
@@ -104,22 +97,18 @@ class ResNet50_vd(nn.Layer):
block_list = []
for i in range(depth[block]):
conv_name = "res" + str(block + 2) + chr(97 + i)
dilation_rate = dilation_dict[block] if dilation_dict and block in dilation_dict else 1
if block == 3:
dilation_rate = dilation_rate * multi_grid[i]
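# Stages 3 and 4 trade stride for dilation (dilation_dict = {2: 2, 3: 4}) to keep
# spatial resolution; the final stage additionally multiplies in the multi-grid
# rates (1, 2, 4).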
bottleneck_block = self.add_sublayer(
'bb_%d_%d' % (block, i),
L.BottleneckBlock(in_channels=num_channels[block] if i == 0 else num_filters[block] * 4,
out_channels=num_filters[block],
stride=2 if i == 0 and block != 0 and dilation_rate == 1 else 1,
shortcut=shortcut,
if_first=block == i == 0,
name=conv_name,
dilation=dilation_rate))
block_list.append(bottleneck_block)
shortcut = True
self.stage_list.append(block_list)
@@ -134,4 +123,4 @@ class ResNet50_vd(nn.Layer):
for block in stage:
y = block(y)
feat_list.append(y)
return feat_list