Unverified commit 1b5a1e26, authored by jm_12138 and committed by GitHub

update modnet_resnet50vd_matting (#2100)

* add requirements.txt

* add init

* update format
Parent 755425ce
 # modnet_resnet50vd_matting
 |Module Name|modnet_resnet50vd_matting|
 | :--- | :---: |
 |Category|Image Matting|
 |Network|modnet_resnet50vd|
 |Dataset|Baidu self-built dataset|
@@ -17,8 +17,8 @@
 - Sample results (left: original image, right: result):
 <p align="center">
 <img src="https://user-images.githubusercontent.com/35907364/144574288-28671577-8d5d-4b20-adb9-fe737015c841.jpg" width = "337" height = "505" hspace='10'/>
 <img src="https://user-images.githubusercontent.com/35907364/144779164-47146d3a-58c9-4a38-b968-3530aa9a0137.png" width = "337" height = "505" hspace='10'/>
 </p>
 - ### Module Introduction
@@ -26,9 +26,9 @@
 - Matting (fine-grained segmentation, background removal, cutout) is the technique of extracting the foreground from an image by computing its color and transparency. It can be used for background replacement, image compositing, and visual effects, and is widely used in the film industry. Each pixel in the image has a value representing its foreground transparency, called the alpha value. The set of all alpha values in an image is called the alpha matte. Extracting the part of the image covered by the matte completes the foreground separation. modnet_resnet50vd_matting can generate matting results.
 - For more details, see: [modnet_resnet50vd_matting](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.3/contrib/Matting)
 ## II. Installation
@@ -46,11 +46,11 @@
 - ```shell
 $ hub install modnet_resnet50vd_matting
 ```
 - If you run into problems during installation, see: [Windows quick start](../../../../docs/docs_ch/get_start/windows_quickstart.md)
 | [Linux quick start](../../../../docs/docs_ch/get_start/linux_quickstart.md) | [macOS quick start](../../../../docs/docs_ch/get_start/mac_quickstart.md)
 ## III. Module API Prediction
 - ### 1、Command-Line Prediction
@@ -58,9 +58,9 @@
 - ```shell
 $ hub run modnet_resnet50vd_matting --input_path "/PATH/TO/IMAGE"
 ```
 - This calls the hub module from the command line. For more, see [PaddleHub command-line instructions](../../../../docs/docs_ch/tutorial/cmd_usage.rst)
 - ### 2、Prediction Code Example
@@ -73,14 +73,14 @@
 result = model.predict(["/PATH/TO/IMAGE"])
 print(result)
 ```
 - ### 3、API
 - ```python
 def predict(self,
             image_list,
             trimap_list,
             visualization,
             save_path):
 ```
@@ -97,7 +97,7 @@
 - result (list(numpy.ndarray)): the model's segmentation results.
 ## IV. Service Deployment
 - PaddleHub Serving can deploy an online portrait matting service.
......
 # modnet_resnet50vd_matting
 |Module Name|modnet_resnet50vd_matting|
 | :--- | :---: |
 |Category|Image Matting|
 |Network|modnet_resnet50vd|
 |Dataset|Baidu self-built dataset|
@@ -17,8 +17,8 @@
 - Sample results:
 <p align="center">
 <img src="https://user-images.githubusercontent.com/35907364/144574288-28671577-8d5d-4b20-adb9-fe737015c841.jpg" width = "337" height = "505" hspace='10'/>
 <img src="https://user-images.githubusercontent.com/35907364/144779164-47146d3a-58c9-4a38-b968-3530aa9a0137.png" width = "337" height = "505" hspace='10'/>
 </p>
 - ### Module Introduction
@@ -26,9 +26,9 @@
 - Matting is the technique of extracting the foreground from an image by calculating its color and transparency. It is widely used in the film industry for background replacement, image compositing, and visual effects. Each pixel in the image has a value that represents its foreground transparency, called alpha. The set of all alpha values in an image is called the alpha matte. Extracting the part of the image covered by the matte completes the foreground separation.
 - For more information, please refer to: [modnet_resnet50vd_matting](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.3/contrib/Matting)
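The matte described above is usually applied with the standard compositing equation `I = alpha * F + (1 - alpha) * B`. The following is a minimal NumPy sketch of that step, assuming 8-bit images and a float matte in [0, 1]. The function name `composite` and the toy arrays are illustrative only and are not part of the module.

```python
import numpy as np


def composite(foreground: np.ndarray, background: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Blend foreground over background with a per-pixel alpha matte.

    foreground, background: uint8 arrays of shape (H, W, 3).
    alpha: float array of shape (H, W) with values in [0, 1] (assumed range).
    """
    if foreground.shape != background.shape:
        raise ValueError("foreground and background must have the same shape")
    if alpha.shape != foreground.shape[:2]:
        raise ValueError("alpha must have shape (H, W)")
    # Add a trailing axis so the matte broadcasts over the color channels.
    a = np.clip(alpha.astype(np.float32), 0.0, 1.0)[..., None]
    blended = a * foreground.astype(np.float32) + (1.0 - a) * background.astype(np.float32)
    return np.rint(blended).astype(np.uint8)


# Toy 1x2 image: the left pixel is fully foreground, the right pixel is half transparent.
fg = np.full((1, 2, 3), 200, dtype=np.uint8)
bg = np.zeros((1, 2, 3), dtype=np.uint8)
matte = np.array([[1.0, 0.5]])
out = composite(fg, bg, matte)
```

With a black background, the fully opaque pixel keeps its foreground color and the half-transparent pixel is darkened to half intensity.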
 ## II. Installation
@@ -46,11 +46,11 @@
 - ```shell
 $ hub install modnet_resnet50vd_matting
 ```
 - In case of any problems during installation, please refer to: [Windows_Quickstart](../../../../docs/docs_en/get_start/windows_quickstart.md)
 | [Linux_Quickstart](../../../../docs/docs_en/get_start/linux_quickstart.md) | [Mac_Quickstart](../../../../docs/docs_en/get_start/mac_quickstart.md)
 ## III. Module API Prediction
 - ### 1、Command line Prediction
@@ -58,7 +58,7 @@
 - ```shell
 $ hub run modnet_resnet50vd_matting --input_path "/PATH/TO/IMAGE"
 ```
 - If you want to call the Hub module through the command line, please refer to: [PaddleHub Command Line Instruction](../../../../docs/docs_en/tutorial/cmd_usage.rst)
@@ -76,10 +76,10 @@
 - ### 3、API
 - ```python
 def predict(self,
             image_list,
             trimap_list,
             visualization,
             save_path):
 ```
@@ -96,7 +96,7 @@
 - result (list(numpy.ndarray)): The list of model results.
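One common way to use a returned matte is to attach it to the source image as a fourth channel, giving a cutout with transparency. The excerpt does not state the value range or dtype of the returned arrays, so the sketch below assumes each entry is a single-channel uint8 matte in [0, 255] with the same height and width as the input. The helper `to_rgba` is hypothetical and not part of the module.

```python
import numpy as np


def to_rgba(image_bgr: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Attach a matte as the fourth channel of a 3-channel image.

    Assumption: alpha is a uint8 array of shape (H, W) with values in [0, 255].
    """
    if alpha.ndim != 2 or alpha.shape != image_bgr.shape[:2]:
        raise ValueError("alpha must be (H, W) and match the image size")
    # np.dstack promotes the 2-D matte to (H, W, 1) before concatenating.
    return np.dstack([image_bgr, alpha.astype(np.uint8)])


image = np.zeros((2, 3, 3), dtype=np.uint8)
alpha = np.full((2, 3), 255, dtype=np.uint8)
cutout = to_rgba(image, alpha)
```

A four-channel array like `cutout` can be written to PNG (for example with `cv2.imwrite`) to keep the transparency.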
 ## IV. Server Deployment
 - PaddleHub Serving can deploy an online matting service.
......
@@ -11,17 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import random
 import base64
-from typing import Callable, Union, List, Tuple
+from typing import Callable
+from typing import List
+from typing import Tuple
+from typing import Union
 import cv2
 import numpy as np
 import paddle
 import paddle.nn.functional as F
 from paddleseg.transforms import functional
-from PIL import Image
 class Compose:
@@ -61,6 +61,7 @@ class LoadImages:
     Args:
         to_rgb (bool, optional): If converting image to RGB color space. Default: True.
     """
+
     def __init__(self, to_rgb: bool = True):
         self.to_rgb = to_rgb
@@ -95,7 +96,7 @@ class ResizeByShort:
         short_size (int): The target size of short side.
     """
-    def __init__(self, short_size: int =512):
+    def __init__(self, short_size: int = 512):
         self.short_size = short_size
     def __call__(self, data: dict) -> dict:
@@ -140,14 +141,13 @@ class Normalize:
         ValueError: When mean/std is not list or any value in std is 0.
     """
-    def __init__(self, mean: Union[List[float], Tuple[float]] = (0.5, 0.5, 0.5), std: Union[List[float], Tuple[float]] = (0.5, 0.5, 0.5)):
+    def __init__(self,
+                 mean: Union[List[float], Tuple[float]] = (0.5, 0.5, 0.5),
+                 std: Union[List[float], Tuple[float]] = (0.5, 0.5, 0.5)):
         self.mean = mean
         self.std = std
-        if not (isinstance(self.mean, (list, tuple))
-                and isinstance(self.std, (list, tuple))):
-            raise ValueError(
-                "{}: input type is invalid. It should be list or tuple".format(
-                    self))
+        if not (isinstance(self.mean, (list, tuple)) and isinstance(self.std, (list, tuple))):
+            raise ValueError("{}: input type is invalid. It should be list or tuple".format(self))
         from functools import reduce
         if reduce(lambda x, y: x * y, self.std) == 0:
             raise ValueError('{}: std is invalid!'.format(self))
@@ -177,6 +177,7 @@ def reverse_transform(alpha: paddle.Tensor, trans_info: List[str]):
             raise Exception("Unexpected info '{}' in im_info".format(item[0]))
     return alpha
+
 def save_alpha_pred(alpha: np.ndarray, trimap: np.ndarray = None):
     """
     The value of alpha is range [0, 1], shape should be [h,w]
@@ -204,4 +205,4 @@ def base64_to_cv2(b64str: str):
     data = base64.b64decode(b64str.encode('utf8'))
     data = np.fromstring(data, np.uint8)
     data = cv2.imdecode(data, cv2.IMREAD_COLOR)
-    return data
\ No newline at end of file
+    return data
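`base64_to_cv2` above decodes the payload with `np.fromstring`, whose binary mode is deprecated in NumPy in favor of `np.frombuffer`. The sketch below shows the same decode step with `np.frombuffer`. It is an illustration, not a change made by this commit, and the helper name `base64_to_uint8` is hypothetical.

```python
import base64

import numpy as np


def base64_to_uint8(b64str: str) -> np.ndarray:
    """Decode a base64 string into a 1-D uint8 array of the raw encoded bytes."""
    raw = base64.b64decode(b64str.encode('utf8'))
    # np.frombuffer creates a read-only view over the bytes without copying them.
    return np.frombuffer(raw, dtype=np.uint8)


# Round-trip four bytes (the start of a PNG signature) through base64.
payload = base64.b64encode(bytes([137, 80, 78, 71])).decode('utf8')
buf = base64_to_uint8(payload)
```

The resulting buffer can then be handed to `cv2.imdecode(buf, cv2.IMREAD_COLOR)`, as the function in the diff does.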
@@ -11,45 +11,40 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 from paddleseg.models import layers
-from paddleseg.utils import utils
 __all__ = ["ResNet50_vd"]
 class ConvBNLayer(nn.Layer):
     """Basic conv bn relu layer."""
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
         kernel_size: int,
         stride: int = 1,
         dilation: int = 1,
         groups: int = 1,
         is_vd_mode: bool = False,
         act: str = None,
     ):
         super(ConvBNLayer, self).__init__()
         self.is_vd_mode = is_vd_mode
-        self._pool2d_avg = nn.AvgPool2D(
-            kernel_size=2, stride=2, padding=0, ceil_mode=True)
-        self._conv = nn.Conv2D(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=(kernel_size - 1) // 2 if dilation == 1 else 0,
-            dilation=dilation,
-            groups=groups,
-            bias_attr=False)
+        self._pool2d_avg = nn.AvgPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
+        self._conv = nn.Conv2D(in_channels=in_channels,
+                               out_channels=out_channels,
+                               kernel_size=kernel_size,
+                               stride=stride,
+                               padding=(kernel_size - 1) // 2 if dilation == 1 else 0,
+                               dilation=dilation,
+                               groups=groups,
+                               bias_attr=False)
         self._batch_norm = layers.SyncBatchNorm(out_channels)
         self._act_op = layers.Activation(act=act)
@@ -66,7 +61,7 @@ class ConvBNLayer(nn.Layer):
 class BottleneckBlock(nn.Layer):
     """Residual bottleneck block"""
     def __init__(self,
                  in_channels: int,
                  out_channels: int,
@@ -76,34 +71,24 @@ class BottleneckBlock(nn.Layer):
                  dilation: int = 1):
         super(BottleneckBlock, self).__init__()
-        self.conv0 = ConvBNLayer(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=1,
-            act='relu')
+        self.conv0 = ConvBNLayer(in_channels=in_channels, out_channels=out_channels, kernel_size=1, act='relu')
         self.dilation = dilation
-        self.conv1 = ConvBNLayer(
-            in_channels=out_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            stride=stride,
-            act='relu',
-            dilation=dilation)
-        self.conv2 = ConvBNLayer(
-            in_channels=out_channels,
-            out_channels=out_channels * 4,
-            kernel_size=1,
-            act=None)
+        self.conv1 = ConvBNLayer(in_channels=out_channels,
+                                 out_channels=out_channels,
+                                 kernel_size=3,
+                                 stride=stride,
+                                 act='relu',
+                                 dilation=dilation)
+        self.conv2 = ConvBNLayer(in_channels=out_channels, out_channels=out_channels * 4, kernel_size=1, act=None)
         if not shortcut:
-            self.short = ConvBNLayer(
-                in_channels=in_channels,
-                out_channels=out_channels * 4,
-                kernel_size=1,
-                stride=1,
-                is_vd_mode=False if if_first or stride == 1 else True)
+            self.short = ConvBNLayer(in_channels=in_channels,
+                                     out_channels=out_channels * 4,
+                                     kernel_size=1,
+                                     stride=1,
+                                     is_vd_mode=False if if_first or stride == 1 else True)
         self.shortcut = shortcut
@@ -133,33 +118,23 @@ class BottleneckBlock(nn.Layer):
 class BasicBlock(nn.Layer):
     """Basic residual block"""
-    def __init__(self,
-                 in_channels: int,
-                 out_channels: int,
-                 stride: int,
-                 shortcut: bool = True,
-                 if_first: bool = False):
+
+    def __init__(self, in_channels: int, out_channels: int, stride: int, shortcut: bool = True, if_first: bool = False):
         super(BasicBlock, self).__init__()
         self.stride = stride
-        self.conv0 = ConvBNLayer(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            stride=stride,
-            act='relu')
-        self.conv1 = ConvBNLayer(
-            in_channels=out_channels,
-            out_channels=out_channels,
-            kernel_size=3,
-            act=None)
+        self.conv0 = ConvBNLayer(in_channels=in_channels,
+                                 out_channels=out_channels,
+                                 kernel_size=3,
+                                 stride=stride,
+                                 act='relu')
+        self.conv1 = ConvBNLayer(in_channels=out_channels, out_channels=out_channels, kernel_size=3, act=None)
         if not shortcut:
-            self.short = ConvBNLayer(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=1,
-                stride=1,
-                is_vd_mode=False if if_first else True)
+            self.short = ConvBNLayer(in_channels=in_channels,
+                                     out_channels=out_channels,
+                                     kernel_size=1,
+                                     stride=1,
+                                     is_vd_mode=False if if_first else True)
         self.shortcut = shortcut
@@ -212,13 +187,11 @@ class ResNet_vd(nn.Layer):
             depth = [3, 8, 36, 3]
         elif layers == 200:
             depth = [3, 12, 48, 3]
-        num_channels = [64, 256, 512, 1024
-                        ] if layers >= 50 else [64, 64, 128, 256]
+        num_channels = [64, 256, 512, 1024] if layers >= 50 else [64, 64, 128, 256]
         num_filters = [64, 128, 256, 512]
         # for channels of four returned stages
-        self.feat_channels = [c * 4 for c in num_filters
-                              ] if layers >= 50 else num_filters
+        self.feat_channels = [c * 4 for c in num_filters] if layers >= 50 else num_filters
         self.feat_channels = [64] + self.feat_channels
         dilation_dict = None
@@ -227,24 +200,9 @@ class ResNet_vd(nn.Layer):
         elif output_stride == 16:
             dilation_dict = {3: 2}
-        self.conv1_1 = ConvBNLayer(
-            in_channels=input_channels,
-            out_channels=32,
-            kernel_size=3,
-            stride=2,
-            act='relu')
-        self.conv1_2 = ConvBNLayer(
-            in_channels=32,
-            out_channels=32,
-            kernel_size=3,
-            stride=1,
-            act='relu')
-        self.conv1_3 = ConvBNLayer(
-            in_channels=32,
-            out_channels=64,
-            kernel_size=3,
-            stride=1,
-            act='relu')
+        self.conv1_1 = ConvBNLayer(in_channels=input_channels, out_channels=32, kernel_size=3, stride=2, act='relu')
+        self.conv1_2 = ConvBNLayer(in_channels=32, out_channels=32, kernel_size=3, stride=1, act='relu')
+        self.conv1_3 = ConvBNLayer(in_channels=32, out_channels=64, kernel_size=3, stride=1, act='relu')
         self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
         # self.block_list = []
@@ -264,8 +222,7 @@ class ResNet_vd(nn.Layer):
                     ###############################################################################
                     # Add dilation rate for some segmentation tasks, if dilation_dict is not None.
-                    dilation_rate = dilation_dict[
-                        block] if dilation_dict and block in dilation_dict else 1
+                    dilation_rate = dilation_dict[block] if dilation_dict and block in dilation_dict else 1
                     # Actually block here is 'stage', and i is 'block' in 'stage'
                     # At the stage 4, expand the the dilation_rate if given multi_grid
@@ -275,15 +232,12 @@ class ResNet_vd(nn.Layer):
                     bottleneck_block = self.add_sublayer(
                         'bb_%d_%d' % (block, i),
-                        BottleneckBlock(
-                            in_channels=num_channels[block]
-                            if i == 0 else num_filters[block] * 4,
-                            out_channels=num_filters[block],
-                            stride=2 if i == 0 and block != 0
-                            and dilation_rate == 1 else 1,
-                            shortcut=shortcut,
-                            if_first=block == i == 0,
-                            dilation=dilation_rate))
+                        BottleneckBlock(in_channels=num_channels[block] if i == 0 else num_filters[block] * 4,
+                                        out_channels=num_filters[block],
+                                        stride=2 if i == 0 and block != 0 and dilation_rate == 1 else 1,
+                                        shortcut=shortcut,
+                                        if_first=block == i == 0,
+                                        dilation=dilation_rate))
                     block_list.append(bottleneck_block)
                     shortcut = True
@@ -296,13 +250,11 @@ class ResNet_vd(nn.Layer):
                     conv_name = "res" + str(block + 2) + chr(97 + i)
                     basic_block = self.add_sublayer(
                         'bb_%d_%d' % (block, i),
-                        BasicBlock(
-                            in_channels=num_channels[block]
-                            if i == 0 else num_filters[block],
-                            out_channels=num_filters[block],
-                            stride=2 if i == 0 and block != 0 else 1,
-                            shortcut=shortcut,
-                            if_first=block == i == 0))
+                        BasicBlock(in_channels=num_channels[block] if i == 0 else num_filters[block],
+                                   out_channels=num_filters[block],
+                                   stride=2 if i == 0 and block != 0 else 1,
+                                   shortcut=shortcut,
+                                   if_first=block == i == 0))
                     block_list.append(basic_block)
                     shortcut = True
                 self.stage_list.append(block_list)
......
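The stride rule in the bottleneck hunk (`stride=2 if i == 0 and block != 0 and dilation_rate == 1 else 1`) determines how much the backbone downsamples its input. The sketch below replays that rule in plain Python for the first block of each stage. The helper names are hypothetical, the stem stride of 4 is taken from the stride-2 `conv1_1` followed by the stride-2 max pool, and the `multi_grid` adjustment mentioned in the code comments is ignored because it changes dilation, not stride.

```python
from typing import Dict, List, Optional


def stage_strides(num_stages: int = 4, dilation_dict: Optional[Dict[int, int]] = None) -> List[int]:
    """Stride of the first block in each stage, following the rule in the diff.

    A stage downsamples by 2 only if it is not the first stage and its
    dilation rate is 1; dilated stages keep the spatial resolution.
    """
    strides = []
    for block in range(num_stages):
        dilation_rate = dilation_dict[block] if dilation_dict and block in dilation_dict else 1
        strides.append(2 if block != 0 and dilation_rate == 1 else 1)
    return strides


def total_downsampling(strides: List[int], stem_stride: int = 4) -> int:
    """Multiply the stem stride by every stage stride."""
    total = stem_stride
    for s in strides:
        total *= s
    return total


plain = stage_strides()                        # no dilation at any stage
dilated = stage_strides(dilation_dict={3: 2})  # the output_stride == 16 setting from the diff
```

Under these assumptions the undilated network reduces resolution by 32, and setting `dilation_dict = {3: 2}` removes the last stride-2 step, which agrees with the `output_stride == 16` branch in the diff.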