提交 ee416505 编写于 作者: F FlyingQianMM

add deeplabv3p_mobilenetv3_large_ssld

上级 f6682135
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
## paddlex.seg.DeepLabv3p ## paddlex.seg.DeepLabv3p
```python ```python
paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride=16, aspp_with_sep_conv=True, decoder_use_sep_conv=True, encoder_with_aspp=True, enable_decoder=True, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255) paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride=16, aspp_with_sep_conv=True, decoder_use_sep_conv=True, encoder_with_aspp=True, enable_decoder=True, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255, pooling_crop_size=None)
``` ```
...@@ -12,7 +12,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride ...@@ -12,7 +12,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
> **参数** > **参数**
> > - **num_classes** (int): 类别数。 > > - **num_classes** (int): 类别数。
> > - **backbone** (str): DeepLabv3+的backbone网络,实现特征图的计算,取值范围为['Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0'],默认值为'MobileNetV2_x1.0'。 > > - **backbone** (str): DeepLabv3+的backbone网络,实现特征图的计算,取值范围为['Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld'],默认值为'MobileNetV2_x1.0'。
> > - **output_stride** (int): backbone 输出特征图相对于输入的下采样倍数,一般取值为8或16。默认16。 > > - **output_stride** (int): backbone 输出特征图相对于输入的下采样倍数,一般取值为8或16。默认16。
> > - **aspp_with_sep_conv** (bool): decoder模块是否采用separable convolutions。默认True。 > > - **aspp_with_sep_conv** (bool): decoder模块是否采用separable convolutions。默认True。
> > - **decoder_use_sep_conv** (bool): decoder模块是否采用separable convolutions。默认True。 > > - **decoder_use_sep_conv** (bool): decoder模块是否采用separable convolutions。默认True。
...@@ -22,6 +22,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride ...@@ -22,6 +22,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
> > - **use_dice_loss** (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用,当`use_bce_loss`和`use_dice_loss`都为False时,使用交叉熵损失函数。默认False。 > > - **use_dice_loss** (bool): 是否使用dice loss作为网络的损失函数,只能用于两类分割,可与bce loss同时使用,当`use_bce_loss`和`use_dice_loss`都为False时,使用交叉熵损失函数。默认False。
> > - **class_weight** (list/str): 交叉熵损失函数各类损失的权重。当`class_weight`为list的时候,长度应为`num_classes`。当`class_weight`为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,即平时使用的交叉熵损失函数。 > > - **class_weight** (list/str): 交叉熵损失函数各类损失的权重。当`class_weight`为list的时候,长度应为`num_classes`。当`class_weight`为str时, weight.lower()应为'dynamic',这时会根据每一轮各类像素的比重自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None是,各类的权重1,即平时使用的交叉熵损失函数。
> > - **ignore_index** (int): label上忽略的值,label为`ignore_index`的像素不参与损失函数的计算。默认255。 > > - **ignore_index** (int): label上忽略的值,label为`ignore_index`的像素不参与损失函数的计算。默认255。
> > - **pooling_crop_size** (int):当backbone为`MobileNetV3_large_x1_0_ssld`时,需设置为训练过程中模型输入大小,格式为[W, H]。例如模型输入大小为[512, 512], 则`pooling_crop_size`应该设置为[512, 512]。在encoder模块中获取图像平均值时被用到,若为None,则直接求平均值;若为模型输入大小,则使用`avg_pool`算子得到平均值。默认值None。
### train ### train
......
...@@ -81,6 +81,7 @@ ...@@ -81,6 +81,7 @@
| 模型 | 模型大小 | 预测时间(毫秒) | mIoU(%) | | 模型 | 模型大小 | 预测时间(毫秒) | mIoU(%) |
|:-------|:-----------|:-------------|:----------| |:-------|:-----------|:-------------|:----------|
| [DeepLabv3_MobileNetV3_large_x1_0_ssld](https://paddleseg.bj.bcebos.com/models/deeplabv3p_mobilenetv3_large_cityscapes.tar.gz) | 9.3MB | - | 73.28 |
| [DeepLabv3_MobileNetv2_x1.0](https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz) | 14.7MB | - | 69.8 | | [DeepLabv3_MobileNetv2_x1.0](https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz) | 14.7MB | - | 69.8 |
| [DeepLabv3_Xception65](https://paddleseg.bj.bcebos.com/models/xception65_bn_cityscapes.tgz) | 329.3MB | - | 79.3 | | [DeepLabv3_Xception65](https://paddleseg.bj.bcebos.com/models/xception65_bn_cityscapes.tgz) | 329.3MB | - | 79.3 |
| [HRNet_W18](https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz) | 77.3MB | | 79.36 | | [HRNet_W18](https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz) | 77.3MB | | 79.36 |
......
...@@ -80,6 +80,7 @@ PaddleX目前提供了实例分割MaskRCNN模型,支持5种不同的backbone ...@@ -80,6 +80,7 @@ PaddleX目前提供了实例分割MaskRCNN模型,支持5种不同的backbone
| 模型 | 模型特点 | 存储体积 | GPU预测速度 | CPU(x86)预测速度(毫秒) | 骁龙855(ARM)预测速度 (毫秒)| mIoU | | 模型 | 模型特点 | 存储体积 | GPU预测速度 | CPU(x86)预测速度(毫秒) | 骁龙855(ARM)预测速度 (毫秒)| mIoU |
| :---- | :------- | :---------- | :---------- | :----- | :----- |:--- | | :---- | :------- | :---------- | :---------- | :----- | :----- |:--- |
| DeepLabv3p-MobileNetV2_x1.0 | 轻量级模型,适用于移动端场景| - | - | - | 69.8% | | DeepLabv3p-MobileNetV2_x1.0 | 轻量级模型,适用于移动端场景| - | - | - | 69.8% |
| DeepLabv3-MobileNetV3_large_x1_0_ssld | 轻量级模型,适用于移动端场景| - | - | - | 73.28% |
| HRNet_W18_Small_v1 | 轻量高速,适用于移动端场景 | - | - | - | - | | HRNet_W18_Small_v1 | 轻量高速,适用于移动端场景 | - | - | - | - |
| FastSCNN | 轻量高速,适用于追求高速预测的移动端或服务器端场景 | - | - | - | 69.64 | | FastSCNN | 轻量高速,适用于追求高速预测的移动端或服务器端场景 | - | - | - | 69.64 |
| HRNet_W18 | 高精度模型,适用于对图像分辨率较为敏感、对目标细节预测要求更高的服务器端场景| - | - | - | 79.36 | | HRNet_W18 | 高精度模型,适用于对图像分辨率较为敏感、对目标细节预测要求更高的服务器端场景| - | - | - | 79.36 |
......
...@@ -12,6 +12,7 @@ PaddleX目前提供了DeepLabv3p、UNet、HRNet和FastSCNN四种语义分割结 ...@@ -12,6 +12,7 @@ PaddleX目前提供了DeepLabv3p、UNet、HRNet和FastSCNN四种语义分割结
| :---------------- | :------- | :------- | :--------- | :--------- | :----- | | :---------------- | :------- | :------- | :--------- | :--------- | :----- |
| [DeepLabv3p-MobileNetV2-x0.25](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2_x0.25.py) | - | 2.9MB | - | - | 模型小,预测速度快,适用于低性能或移动端设备 | | [DeepLabv3p-MobileNetV2-x0.25](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2_x0.25.py) | - | 2.9MB | - | - | 模型小,预测速度快,适用于低性能或移动端设备 |
| [DeepLabv3p-MobileNetV2-x1.0](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2.py) | 69.8% | 11MB | - | - | 模型小,预测速度快,适用于低性能或移动端设备 | | [DeepLabv3p-MobileNetV2-x1.0](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2.py) | 69.8% | 11MB | - | - | 模型小,预测速度快,适用于低性能或移动端设备 |
| [DeepLabv3_MobileNetV3_large_x1_0_ssld](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv3_large_ssld.py) | 73.28% | 9.3MB | - | - | 模型小,预测速度快,精度较高,适用于低性能或移动端设备 |
| [DeepLabv3p-Xception65](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_xception65.py) | 79.3% | 158MB | - | - | 模型大,精度高,适用于服务端 | | [DeepLabv3p-Xception65](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_xception65.py) | 79.3% | 158MB | - | - | 模型大,精度高,适用于服务端 |
| [UNet](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/unet.py) | - | 52MB | - | - | 模型较大,精度高,适用于服务端 | | [UNet](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/unet.py) | - | 52MB | - | - | 模型较大,精度高,适用于服务端 |
| [HRNet](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/hrnet.py) | 79.4% | 37MB | - | - | 模型较小,模型精度高,适用于服务端部署 | | [HRNet](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/hrnet.py) | 79.4% | 37MB | - | - | 模型较小,模型精度高,适用于服务端部署 |
......
...@@ -37,7 +37,7 @@ class DeepLabv3p(BaseAPI): ...@@ -37,7 +37,7 @@ class DeepLabv3p(BaseAPI):
num_classes (int): 类别数。 num_classes (int): 类别数。
backbone (str): DeepLabv3+的backbone网络,实现特征图的计算,取值范围为['Xception65', 'Xception41', backbone (str): DeepLabv3+的backbone网络,实现特征图的计算,取值范围为['Xception65', 'Xception41',
'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5',
'MobileNetV2_x2.0']。默认'MobileNetV2_x1.0'。 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld']。默认'MobileNetV2_x1.0'。
output_stride (int): backbone 输出特征图相对于输入的下采样倍数,一般取值为8或16。默认16。 output_stride (int): backbone 输出特征图相对于输入的下采样倍数,一般取值为8或16。默认16。
aspp_with_sep_conv (bool): 在asspp模块是否采用separable convolutions。默认True。 aspp_with_sep_conv (bool): 在asspp模块是否采用separable convolutions。默认True。
decoder_use_sep_conv (bool): decoder模块是否采用separable convolutions。默认True。 decoder_use_sep_conv (bool): decoder模块是否采用separable convolutions。默认True。
...@@ -51,10 +51,13 @@ class DeepLabv3p(BaseAPI): ...@@ -51,10 +51,13 @@ class DeepLabv3p(BaseAPI):
自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None时,各类的权重1, 自行计算相应的权重,每一类的权重为:每类的比例 * num_classes。class_weight取默认值None时,各类的权重1,
即平时使用的交叉熵损失函数。 即平时使用的交叉熵损失函数。
ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。 ignore_index (int): label上忽略的值,label为ignore_index的像素不参与损失函数的计算。默认255。
pooling_crop_size (list): 当backbone为MobileNetV3_large_x1_0_ssld时,需设置为训练过程中模型输入大小, 格式为[W, H]。
在encoder模块中获取图像平均值时被用到,若为None,则直接求平均值;若为模型输入大小,则使用'pool'算子得到平均值。
默认值为None。
Raises: Raises:
ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。 ValueError: use_bce_loss或use_dice_loss为真且num_calsses > 2。
ValueError: backbone取值不在['Xception65', 'Xception41', 'MobileNetV2_x0.25', ValueError: backbone取值不在['Xception65', 'Xception41', 'MobileNetV2_x0.25',
'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0']之内。 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld']之内。
ValueError: class_weight为list, 但长度不等于num_class。 ValueError: class_weight为list, 但长度不等于num_class。
class_weight为str, 但class_weight.low()不等于dynamic。 class_weight为str, 但class_weight.low()不等于dynamic。
TypeError: class_weight不为None时,其类型不是list或str。 TypeError: class_weight不为None时,其类型不是list或str。
...@@ -71,7 +74,8 @@ class DeepLabv3p(BaseAPI): ...@@ -71,7 +74,8 @@ class DeepLabv3p(BaseAPI):
use_bce_loss=False, use_bce_loss=False,
use_dice_loss=False, use_dice_loss=False,
class_weight=None, class_weight=None,
ignore_index=255): ignore_index=255,
pooling_crop_size=None):
self.init_params = locals() self.init_params = locals()
super(DeepLabv3p, self).__init__('segmenter') super(DeepLabv3p, self).__init__('segmenter')
# dice_loss或bce_loss只适用两类分割中 # dice_loss或bce_loss只适用两类分割中
...@@ -85,12 +89,12 @@ class DeepLabv3p(BaseAPI): ...@@ -85,12 +89,12 @@ class DeepLabv3p(BaseAPI):
if backbone not in [ if backbone not in [
'Xception65', 'Xception41', 'MobileNetV2_x0.25', 'Xception65', 'Xception41', 'MobileNetV2_x0.25',
'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5',
'MobileNetV2_x2.0' 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld'
]: ]:
raise ValueError( raise ValueError(
"backbone: {} is set wrong. it should be one of " "backbone: {} is set wrong. it should be one of "
"('Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5'," "('Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',"
" 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0')". " 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld')".
format(backbone)) format(backbone))
if class_weight is not None: if class_weight is not None:
...@@ -121,6 +125,30 @@ class DeepLabv3p(BaseAPI): ...@@ -121,6 +125,30 @@ class DeepLabv3p(BaseAPI):
self.labels = None self.labels = None
self.sync_bn = True self.sync_bn = True
self.fixed_input_shape = None self.fixed_input_shape = None
self.pooling_stride = [1, 1]
self.pooling_crop_size = pooling_crop_size
self.aspp_with_se = False
self.se_use_qsigmoid = False
self.aspp_convs_filters = 256
self.aspp_with_concat_projection = True
self.add_image_level_feature = True
self.use_sum_merge = False
self.conv_filters = 256
self.output_is_logits = False
self.backbone_lr_mult_list = None
if 'MobileNetV3' in backbone:
self.output_stride = 32
self.pooling_stride = (4, 5)
self.aspp_with_se = True
self.se_use_qsigmoid = True
self.aspp_convs_filters = 128
self.aspp_with_concat_projection = False
self.add_image_level_feature = False
self.use_sum_merge = True
self.output_is_logits = True
if self.output_is_logits:
self.conv_filters = self.num_classes
self.backbone_lr_mult_list = [0.15, 0.35, 0.65, 0.85, 1]
def _get_backbone(self, backbone): def _get_backbone(self, backbone):
def mobilenetv2(backbone): def mobilenetv2(backbone):
...@@ -167,10 +195,22 @@ class DeepLabv3p(BaseAPI): ...@@ -167,10 +195,22 @@ class DeepLabv3p(BaseAPI):
end_points=end_points, end_points=end_points,
decode_points=decode_points) decode_points=decode_points)
def mobilenetv3(backbone):
scale = 1.0
lr_mult_list = self.backbone_lr_mult_list
return paddlex.cv.nets.MobileNetV3(
scale=scale,
model_name='large',
output_stride=self.output_stride,
lr_mult_list=lr_mult_list,
for_seg=True)
if 'Xception' in backbone: if 'Xception' in backbone:
return xception(backbone) return xception(backbone)
elif 'MobileNetV2' in backbone: elif 'MobileNetV2' in backbone:
return mobilenetv2(backbone) return mobilenetv2(backbone)
elif 'MobileNetV3' in backbone:
return mobilenetv3(backbone)
def build_net(self, mode='train'): def build_net(self, mode='train'):
model = paddlex.cv.nets.segmentation.DeepLabv3p( model = paddlex.cv.nets.segmentation.DeepLabv3p(
...@@ -186,7 +226,17 @@ class DeepLabv3p(BaseAPI): ...@@ -186,7 +226,17 @@ class DeepLabv3p(BaseAPI):
use_dice_loss=self.use_dice_loss, use_dice_loss=self.use_dice_loss,
class_weight=self.class_weight, class_weight=self.class_weight,
ignore_index=self.ignore_index, ignore_index=self.ignore_index,
fixed_input_shape=self.fixed_input_shape) fixed_input_shape=self.fixed_input_shape,
pooling_stride=self.pooling_stride,
pooling_crop_size=self.pooling_crop_size,
aspp_with_se=self.aspp_with_se,
se_use_qsigmoid=self.se_use_qsigmoid,
aspp_convs_filters=self.aspp_convs_filters,
aspp_with_concat_projection=self.aspp_with_concat_projection,
add_image_level_feature=self.add_image_level_feature,
use_sum_merge=self.use_sum_merge,
conv_filters=self.conv_filters,
output_is_logits=self.output_is_logits)
inputs = model.generate_inputs() inputs = model.generate_inputs()
model_out = model.build_net(inputs) model_out = model.build_net(inputs)
outputs = OrderedDict() outputs = OrderedDict()
......
...@@ -122,6 +122,8 @@ coco_pretrain = { ...@@ -122,6 +122,8 @@ coco_pretrain = {
} }
cityscapes_pretrain = { cityscapes_pretrain = {
'DeepLabv3p_MobileNetV3_large_x1_0_ssld_CITYSCAPES':
'https://paddleseg.bj.bcebos.com/models/deeplabv3p_mobilenetv3_large_cityscapes.tar.gz',
'DeepLabv3p_MobileNetV2_x1.0_CITYSCAPES': 'DeepLabv3p_MobileNetV2_x1.0_CITYSCAPES':
'https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz', 'https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz',
'DeepLabv3p_Xception65_CITYSCAPES': 'DeepLabv3p_Xception65_CITYSCAPES':
...@@ -167,7 +169,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir): ...@@ -167,7 +169,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
flag = 'IMAGENET' flag = 'IMAGENET'
if class_name == 'DeepLabv3p' and backbone in [ if class_name == 'DeepLabv3p' and backbone in [
'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',
'MobileNetV2_x1.5', 'MobileNetV2_x2.0' 'MobileNetV2_x1.5', 'MobileNetV2_x2.0',
'MobileNetV3_large_x1_0_ssld'
]: ]:
model_name = '{}_{}'.format(class_name, backbone) model_name = '{}_{}'.format(class_name, backbone)
logging.warning(warning_info.format(model_name, flag, 'IMAGENET')) logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
......
...@@ -42,7 +42,9 @@ class MobileNetV3(): ...@@ -42,7 +42,9 @@ class MobileNetV3():
extra_block_filters=[[256, 512], [128, 256], [128, 256], extra_block_filters=[[256, 512], [128, 256], [128, 256],
[64, 128]], [64, 128]],
num_classes=None, num_classes=None,
lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]): lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
for_seg=False,
output_stride=None):
assert len(lr_mult_list) == 5, \ assert len(lr_mult_list) == 5, \
"lr_mult_list length in MobileNetV3 must be 5 but got {}!!".format( "lr_mult_list length in MobileNetV3 must be 5 but got {}!!".format(
len(lr_mult_list)) len(lr_mult_list))
...@@ -57,6 +59,54 @@ class MobileNetV3(): ...@@ -57,6 +59,54 @@ class MobileNetV3():
self.num_classes = num_classes self.num_classes = num_classes
self.lr_mult_list = lr_mult_list self.lr_mult_list = lr_mult_list
self.curr_stage = 0 self.curr_stage = 0
self.for_seg = for_seg
self.decode_point = None
if self.for_seg:
if model_name == "large":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, False, 'relu', 1],
[3, 64, 24, False, 'relu', 2],
[3, 72, 24, False, 'relu', 1],
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
[3, 240, 80, False, 'hard_swish', 2],
[3, 200, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 184, 80, False, 'hard_swish', 1],
[3, 480, 112, True, 'hard_swish', 1],
[3, 672, 112, True, 'hard_swish', 1],
# The number of channels in the last 4 stages is reduced by a
# factor of 2 compared to the standard implementation.
[5, 336, 80, True, 'hard_swish', 2],
[5, 480, 80, True, 'hard_swish', 1],
[5, 480, 80, True, 'hard_swish', 1],
]
self.cls_ch_squeeze = 480
self.cls_ch_expand = 1280
self.lr_interval = 3
elif model_name == "small":
self.cfg = [
# k, exp, c, se, nl, s,
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
[5, 96, 40, True, 'hard_swish', 2],
[5, 240, 40, True, 'hard_swish', 1],
[5, 240, 40, True, 'hard_swish', 1],
[5, 120, 48, True, 'hard_swish', 1],
[5, 144, 48, True, 'hard_swish', 1],
# The number of channels in the last 4 stages is reduced by a
# factor of 2 compared to the standard implementation.
[5, 144, 48, True, 'hard_swish', 2],
[5, 288, 48, True, 'hard_swish', 1],
[5, 288, 48, True, 'hard_swish', 1],
]
else:
raise NotImplementedError
else:
if model_name == "large": if model_name == "large":
self.cfg = [ self.cfg = [
# kernel_size, expand, channel, se_block, act_mode, stride # kernel_size, expand, channel, se_block, act_mode, stride
...@@ -100,6 +150,21 @@ class MobileNetV3(): ...@@ -100,6 +150,21 @@ class MobileNetV3():
else: else:
raise NotImplementedError raise NotImplementedError
self.modify_bottle_params(output_stride)
def modify_bottle_params(self, output_stride=None):
if output_stride is not None and output_stride % 2 != 0:
raise Exception("output stride must to be even number")
if output_stride is None:
return
else:
stride = 2
for i, _cfg in enumerate(self.cfg):
stride = stride * _cfg[-1]
if stride > output_stride:
s = 1
self.cfg[i][-1] = s
def _conv_bn_layer(self, def _conv_bn_layer(self,
input, input,
filter_size, filter_size,
...@@ -153,6 +218,14 @@ class MobileNetV3(): ...@@ -153,6 +218,14 @@ class MobileNetV3():
bn = fluid.layers.relu6(bn) bn = fluid.layers.relu6(bn)
return bn return bn
def make_divisible(self, v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
if new_v < 0.9 * v:
new_v += divisor
return new_v
def _hard_swish(self, x): def _hard_swish(self, x):
return x * fluid.layers.relu6(x + 3) / 6. return x * fluid.layers.relu6(x + 3) / 6.
...@@ -220,6 +293,9 @@ class MobileNetV3(): ...@@ -220,6 +293,9 @@ class MobileNetV3():
use_cudnn=False, use_cudnn=False,
name=name + '_depthwise') name=name + '_depthwise')
if self.curr_stage == 5:
self.decode_point = conv1
if use_se: if use_se:
conv1 = self._se_block( conv1 = self._se_block(
input=conv1, num_out_filter=num_mid_filter, name=name + '_se') input=conv1, num_out_filter=num_mid_filter, name=name + '_se')
...@@ -282,7 +358,7 @@ class MobileNetV3(): ...@@ -282,7 +358,7 @@ class MobileNetV3():
conv = self._conv_bn_layer( conv = self._conv_bn_layer(
input, input,
filter_size=3, filter_size=3,
num_filters=inplanes if scale <= 1.0 else int(inplanes * scale), num_filters=self.make_divisible(inplanes * scale),
stride=2, stride=2,
padding=1, padding=1,
num_groups=1, num_groups=1,
...@@ -290,6 +366,7 @@ class MobileNetV3(): ...@@ -290,6 +366,7 @@ class MobileNetV3():
act='hard_swish', act='hard_swish',
name='conv1') name='conv1')
i = 0 i = 0
inplanes = self.make_divisible(inplanes * scale)
for layer_cfg in cfg: for layer_cfg in cfg:
self.block_stride *= layer_cfg[5] self.block_stride *= layer_cfg[5]
if layer_cfg[5] == 2: if layer_cfg[5] == 2:
...@@ -297,19 +374,32 @@ class MobileNetV3(): ...@@ -297,19 +374,32 @@ class MobileNetV3():
conv = self._residual_unit( conv = self._residual_unit(
input=conv, input=conv,
num_in_filter=inplanes, num_in_filter=inplanes,
num_mid_filter=int(scale * layer_cfg[1]), num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
num_out_filter=int(scale * layer_cfg[2]), num_out_filter=self.make_divisible(scale * layer_cfg[2]),
act=layer_cfg[4], act=layer_cfg[4],
stride=layer_cfg[5], stride=layer_cfg[5],
filter_size=layer_cfg[0], filter_size=layer_cfg[0],
use_se=layer_cfg[3], use_se=layer_cfg[3],
name='conv' + str(i + 2)) name='conv' + str(i + 2))
inplanes = self.make_divisible(scale * layer_cfg[2])
inplanes = int(scale * layer_cfg[2])
i += 1 i += 1
self.curr_stage = i self.curr_stage = i
blocks.append(conv) blocks.append(conv)
if self.for_seg:
conv = self._conv_bn_layer(
input=conv,
filter_size=1,
num_filters=self.make_divisible(scale * self.cls_ch_squeeze),
stride=1,
padding=0,
num_groups=1,
if_act=True,
act='hard_swish',
name='conv_last')
return conv, self.decode_point
if self.num_classes: if self.num_classes:
conv = self._conv_bn_layer( conv = self._conv_bn_layer(
input=conv, input=conv,
......
...@@ -21,7 +21,7 @@ from collections import OrderedDict ...@@ -21,7 +21,7 @@ from collections import OrderedDict
import paddle.fluid as fluid import paddle.fluid as fluid
from .model_utils.libs import scope, name_scope from .model_utils.libs import scope, name_scope
from .model_utils.libs import bn, bn_relu, relu from .model_utils.libs import bn, bn_relu, relu, qsigmoid
from .model_utils.libs import conv, max_pool, deconv from .model_utils.libs import conv, max_pool, deconv
from .model_utils.libs import separate_conv from .model_utils.libs import separate_conv
from .model_utils.libs import sigmoid_to_softmax from .model_utils.libs import sigmoid_to_softmax
...@@ -82,7 +82,17 @@ class DeepLabv3p(object): ...@@ -82,7 +82,17 @@ class DeepLabv3p(object):
use_dice_loss=False, use_dice_loss=False,
class_weight=None, class_weight=None,
ignore_index=255, ignore_index=255,
fixed_input_shape=None): fixed_input_shape=None,
pooling_stride=[1, 1],
pooling_crop_size=None,
aspp_with_se=False,
se_use_qsigmoid=False,
aspp_convs_filters=256,
aspp_with_concat_projection=True,
add_image_level_feature=True,
use_sum_merge=False,
conv_filters=256,
output_is_logits=False):
# dice_loss或bce_loss只适用两类分割中 # dice_loss或bce_loss只适用两类分割中
if num_classes > 2 and (use_bce_loss or use_dice_loss): if num_classes > 2 and (use_bce_loss or use_dice_loss):
raise ValueError( raise ValueError(
...@@ -117,6 +127,17 @@ class DeepLabv3p(object): ...@@ -117,6 +127,17 @@ class DeepLabv3p(object):
self.encoder_with_aspp = encoder_with_aspp self.encoder_with_aspp = encoder_with_aspp
self.enable_decoder = enable_decoder self.enable_decoder = enable_decoder
self.fixed_input_shape = fixed_input_shape self.fixed_input_shape = fixed_input_shape
self.output_is_logits = output_is_logits
self.aspp_convs_filters = aspp_convs_filters
self.output_stride = output_stride
self.pooling_crop_size = pooling_crop_size
self.pooling_stride = pooling_stride
self.se_use_qsigmoid = se_use_qsigmoid
self.aspp_with_concat_projection = aspp_with_concat_projection
self.add_image_level_feature = add_image_level_feature
self.aspp_with_se = aspp_with_se
self.use_sum_merge = use_sum_merge
self.conv_filters = conv_filters
def _encoder(self, input): def _encoder(self, input):
# 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv # 编码器配置,采用ASPP架构,pooling + 1x1_conv + 三个不同尺度的空洞卷积并行, concat后1x1conv
...@@ -129,19 +150,36 @@ class DeepLabv3p(object): ...@@ -129,19 +150,36 @@ class DeepLabv3p(object):
elif self.output_stride == 8: elif self.output_stride == 8:
aspp_ratios = [12, 24, 36] aspp_ratios = [12, 24, 36]
else: else:
raise Exception("DeepLabv3p only support stride 8 or 16") aspp_ratios = []
param_attr = fluid.ParamAttr( param_attr = fluid.ParamAttr(
name=name_scope + 'weights', name=name_scope + 'weights',
regularizer=None, regularizer=None,
initializer=fluid.initializer.TruncatedNormal( initializer=fluid.initializer.TruncatedNormal(
loc=0.0, scale=0.06)) loc=0.0, scale=0.06))
concat_logits = []
with scope('encoder'): with scope('encoder'):
channel = 256 channel = self.aspp_convs_filters
with scope("image_pool"): with scope("image_pool"):
if self.pooling_crop_size is None:
image_avg = fluid.layers.reduce_mean( image_avg = fluid.layers.reduce_mean(
input, [2, 3], keep_dim=True) input, [2, 3], keep_dim=True)
image_avg = bn_relu( else:
pool_w = int((self.pooling_crop_size[0] - 1.0) /
self.output_stride + 1.0)
pool_h = int((self.pooling_crop_size[1] - 1.0) /
self.output_stride + 1.0)
image_avg = fluid.layers.pool2d(
input,
pool_size=(pool_h, pool_w),
pool_stride=self.pooling_stride,
pool_type='avg',
pool_padding='VALID')
act = qsigmoid if self.se_use_qsigmoid else bn_relu
image_avg = act(
conv( conv(
image_avg, image_avg,
channel, channel,
...@@ -153,6 +191,8 @@ class DeepLabv3p(object): ...@@ -153,6 +191,8 @@ class DeepLabv3p(object):
input_shape = fluid.layers.shape(input) input_shape = fluid.layers.shape(input)
image_avg = fluid.layers.resize_bilinear(image_avg, image_avg = fluid.layers.resize_bilinear(image_avg,
input_shape[2:]) input_shape[2:])
if self.add_image_level_feature:
concat_logits.append(image_avg)
with scope("aspp0"): with scope("aspp0"):
aspp0 = bn_relu( aspp0 = bn_relu(
...@@ -164,6 +204,9 @@ class DeepLabv3p(object): ...@@ -164,6 +204,9 @@ class DeepLabv3p(object):
groups=1, groups=1,
padding=0, padding=0,
param_attr=param_attr)) param_attr=param_attr))
concat_logits.append(aspp0)
if aspp_ratios:
with scope("aspp1"): with scope("aspp1"):
if self.aspp_with_sep_conv: if self.aspp_with_sep_conv:
aspp1 = separate_conv( aspp1 = separate_conv(
...@@ -183,6 +226,7 @@ class DeepLabv3p(object): ...@@ -183,6 +226,7 @@ class DeepLabv3p(object):
dilation=aspp_ratios[0], dilation=aspp_ratios[0],
padding=aspp_ratios[0], padding=aspp_ratios[0],
param_attr=param_attr)) param_attr=param_attr))
concat_logits.append(aspp1)
with scope("aspp2"): with scope("aspp2"):
if self.aspp_with_sep_conv: if self.aspp_with_sep_conv:
aspp2 = separate_conv( aspp2 = separate_conv(
...@@ -202,6 +246,7 @@ class DeepLabv3p(object): ...@@ -202,6 +246,7 @@ class DeepLabv3p(object):
dilation=aspp_ratios[1], dilation=aspp_ratios[1],
padding=aspp_ratios[1], padding=aspp_ratios[1],
param_attr=param_attr)) param_attr=param_attr))
concat_logits.append(aspp2)
with scope("aspp3"): with scope("aspp3"):
if self.aspp_with_sep_conv: if self.aspp_with_sep_conv:
aspp3 = separate_conv( aspp3 = separate_conv(
...@@ -221,9 +266,11 @@ class DeepLabv3p(object): ...@@ -221,9 +266,11 @@ class DeepLabv3p(object):
dilation=aspp_ratios[2], dilation=aspp_ratios[2],
padding=aspp_ratios[2], padding=aspp_ratios[2],
param_attr=param_attr)) param_attr=param_attr))
concat_logits.append(aspp3)
with scope("concat"): with scope("concat"):
data = fluid.layers.concat( data = fluid.layers.concat(concat_logits, axis=1)
[image_avg, aspp0, aspp1, aspp2, aspp3], axis=1) if self.aspp_with_concat_projection:
data = bn_relu( data = bn_relu(
conv( conv(
data, data,
...@@ -234,19 +281,38 @@ class DeepLabv3p(object): ...@@ -234,19 +281,38 @@ class DeepLabv3p(object):
padding=0, padding=0,
param_attr=param_attr)) param_attr=param_attr))
data = fluid.layers.dropout(data, 0.9) data = fluid.layers.dropout(data, 0.9)
if self.aspp_with_se:
data = data * image_avg
return data return data
def _decoder(self, encode_data, decode_shortcut): def _decoder_with_sum_merge(self, encode_data, decode_shortcut,
# 解码器配置 param_attr):
# encode_data:编码器输出 decode_shortcut_shape = fluid.layers.shape(decode_shortcut)
# decode_shortcut: 从backbone引出的分支, resize后与encode_data concat encode_data = fluid.layers.resize_bilinear(encode_data,
# decoder_use_sep_conv: 默认为真,则concat后连接两个可分离卷积,否则为普通卷积 decode_shortcut_shape[2:])
param_attr = fluid.ParamAttr(
name=name_scope + 'weights', encode_data = conv(
regularizer=None, encode_data,
initializer=fluid.initializer.TruncatedNormal( self.conv_filters,
loc=0.0, scale=0.06)) 1,
with scope('decoder'): 1,
groups=1,
padding=0,
param_attr=param_attr)
with scope('merge'):
decode_shortcut = conv(
decode_shortcut,
self.conv_filters,
1,
1,
groups=1,
padding=0,
param_attr=param_attr)
return encode_data + decode_shortcut
def _decoder_with_concat(self, encode_data, decode_shortcut, param_attr):
with scope('concat'): with scope('concat'):
decode_shortcut = bn_relu( decode_shortcut = bn_relu(
conv( conv(
...@@ -266,16 +332,16 @@ class DeepLabv3p(object): ...@@ -266,16 +332,16 @@ class DeepLabv3p(object):
if self.decoder_use_sep_conv: if self.decoder_use_sep_conv:
with scope("separable_conv1"): with scope("separable_conv1"):
encode_data = separate_conv( encode_data = separate_conv(
encode_data, 256, 1, 3, dilation=1, act=relu) encode_data, self.conv_filters, 1, 3, dilation=1, act=relu)
with scope("separable_conv2"): with scope("separable_conv2"):
encode_data = separate_conv( encode_data = separate_conv(
encode_data, 256, 1, 3, dilation=1, act=relu) encode_data, self.conv_filters, 1, 3, dilation=1, act=relu)
else: else:
with scope("decoder_conv1"): with scope("decoder_conv1"):
encode_data = bn_relu( encode_data = bn_relu(
conv( conv(
encode_data, encode_data,
256, self.conv_filters,
stride=1, stride=1,
filter_size=3, filter_size=3,
dilation=1, dilation=1,
...@@ -285,7 +351,7 @@ class DeepLabv3p(object): ...@@ -285,7 +351,7 @@ class DeepLabv3p(object):
encode_data = bn_relu( encode_data = bn_relu(
conv( conv(
encode_data, encode_data,
256, self.conv_filters,
stride=1, stride=1,
filter_size=3, filter_size=3,
dilation=1, dilation=1,
...@@ -293,6 +359,25 @@ class DeepLabv3p(object): ...@@ -293,6 +359,25 @@ class DeepLabv3p(object):
param_attr=param_attr)) param_attr=param_attr))
return encode_data return encode_data
def _decoder(self, encode_data, decode_shortcut):
# 解码器配置
# encode_data:编码器输出
# decode_shortcut: 从backbone引出的分支, resize后与encode_data concat
# decoder_use_sep_conv: 默认为真,则concat后连接两个可分离卷积,否则为普通卷积
param_attr = fluid.ParamAttr(
name=name_scope + 'weights',
regularizer=None,
initializer=fluid.initializer.TruncatedNormal(
loc=0.0, scale=0.06))
with scope('decoder'):
if self.use_sum_merge:
return self._decoder_with_sum_merge(
encode_data, decode_shortcut, param_attr)
return self._decoder_with_concat(encode_data, decode_shortcut,
param_attr)
def _get_loss(self, logit, label, mask): def _get_loss(self, logit, label, mask):
avg_loss = 0 avg_loss = 0
if not (self.use_dice_loss or self.use_bce_loss): if not (self.use_dice_loss or self.use_bce_loss):
...@@ -335,6 +420,9 @@ class DeepLabv3p(object): ...@@ -335,6 +420,9 @@ class DeepLabv3p(object):
self.num_classes = 1 self.num_classes = 1
image = inputs['image'] image = inputs['image']
if 'MobileNetV3' in self.backbone.__class__.__name__:
data, decode_shortcut = self.backbone(image)
else:
data, decode_shortcuts = self.backbone(image) data, decode_shortcuts = self.backbone(image)
decode_shortcut = decode_shortcuts[self.backbone.decode_points] decode_shortcut = decode_shortcuts[self.backbone.decode_points]
...@@ -351,6 +439,7 @@ class DeepLabv3p(object): ...@@ -351,6 +439,7 @@ class DeepLabv3p(object):
regularization_coeff=0.0), regularization_coeff=0.0),
initializer=fluid.initializer.TruncatedNormal( initializer=fluid.initializer.TruncatedNormal(
loc=0.0, scale=0.01)) loc=0.0, scale=0.01))
if not self.output_is_logits:
with scope('logit'): with scope('logit'):
with fluid.name_scope('last_conv'): with fluid.name_scope('last_conv'):
logit = conv( logit = conv(
...@@ -361,6 +450,9 @@ class DeepLabv3p(object): ...@@ -361,6 +450,9 @@ class DeepLabv3p(object):
padding=0, padding=0,
bias_attr=True, bias_attr=True,
param_attr=param_attr) param_attr=param_attr)
else:
logit = data
image_shape = fluid.layers.shape(image) image_shape = fluid.layers.shape(image)
logit = fluid.layers.resize_bilinear(logit, image_shape[2:]) logit = fluid.layers.resize_bilinear(logit, image_shape[2:])
......
...@@ -112,6 +112,10 @@ def bn_relu(data, norm_type='bn', eps=1e-5): ...@@ -112,6 +112,10 @@ def bn_relu(data, norm_type='bn', eps=1e-5):
return fluid.layers.relu(bn(data, norm_type=norm_type, eps=eps)) return fluid.layers.relu(bn(data, norm_type=norm_type, eps=eps))
def qsigmoid(data):
return fluid.layers.relu6(data + 3) * 0.16667
def relu(data): def relu(data):
return fluid.layers.relu(data) return fluid.layers.relu(data)
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
|instance_segmentation/mask_rcnn_r18_fpn.py | 实例分割MaskRCNN | 小度熊分拣 | |instance_segmentation/mask_rcnn_r18_fpn.py | 实例分割MaskRCNN | 小度熊分拣 |
|instance_segmentation/mask_rcnn_f50_fpn.py | 实例分割MaskRCNN | 小度熊分拣 | |instance_segmentation/mask_rcnn_f50_fpn.py | 实例分割MaskRCNN | 小度熊分拣 |
|semantic_segmentation/deeplabv3p_mobilenetv2.py | 语义分割DeepLabV3 | 视盘分割 | |semantic_segmentation/deeplabv3p_mobilenetv2.py | 语义分割DeepLabV3 | 视盘分割 |
|semantic_segmentation/deeplabv3p_mobilenetv2.py | 语义分割DeepLabV3 | 视盘分割 |
|semantic_segmentation/deeplabv3p_mobilenetv2_x0.25.py | 语义分割DeepLabV3 | 视盘分割 | |semantic_segmentation/deeplabv3p_mobilenetv2_x0.25.py | 语义分割DeepLabV3 | 视盘分割 |
|semantic_segmentation/deeplabv3p_xception65.py | 语义分割DeepLabV3 | 视盘分割 | |semantic_segmentation/deeplabv3p_xception65.py | 语义分割DeepLabV3 | 视盘分割 |
|semantic_segmentation/fast_scnn.py | 语义分割FastSCNN | 视盘分割 | |semantic_segmentation/fast_scnn.py | 语义分割FastSCNN | 视盘分割 |
......
# 环境变量配置,用于控制是否使用GPU
# 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import paddlex as pdx
from paddlex.seg import transforms
# 下载和解压视盘分割数据集
optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
pdx.utils.download_and_decompress(optic_dataset, path='./')
# 定义训练和验证时的transforms
# API说明 https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/seg_transforms.html
train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(), transforms.ResizeRangeScaling(),
transforms.RandomPaddingCrop(crop_size=512), transforms.Normalize()
])
eval_transforms = transforms.Compose([
transforms.ResizeByLong(long_size=512),
transforms.Padding(target_size=512), transforms.Normalize()
])
# 定义训练和验证所用的数据集
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-segdataset
train_dataset = pdx.datasets.SegDataset(
data_dir='optic_disc_seg',
file_list='optic_disc_seg/train_list.txt',
label_list='optic_disc_seg/labels.txt',
transforms=train_transforms,
shuffle=True)
eval_dataset = pdx.datasets.SegDataset(
data_dir='optic_disc_seg',
file_list='optic_disc_seg/val_list.txt',
label_list='optic_disc_seg/labels.txt',
transforms=eval_transforms)
# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标,参考https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
num_classes = len(train_dataset.labels)
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#paddlex-seg-deeplabv3p
model = pdx.seg.DeepLabv3p(
num_classes=num_classes,
backbone='MobileNetV3_large_x1_0_ssld',
pooling_crop_size=(512, 512))
# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#train
# 各参数介绍与调整说明:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
model.train(
num_epochs=40,
train_dataset=train_dataset,
train_batch_size=4,
eval_dataset=eval_dataset,
learning_rate=0.01,
save_dir='output/deeplabv3p_mobilenetv3_large_ssld',
use_vdl=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册