diff --git a/docs/apis/models/semantic_segmentation.md b/docs/apis/models/semantic_segmentation.md
index a03aa0c2da3c7befe98ddfc5f356c7bd90ce7026..82b758d98f243e6f653c5e8d39d181b45e150587 100755
--- a/docs/apis/models/semantic_segmentation.md
+++ b/docs/apis/models/semantic_segmentation.md
@@ -3,7 +3,7 @@
 ## paddlex.seg.DeepLabv3p
 ```python
-paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride=16, aspp_with_sep_conv=True, decoder_use_sep_conv=True, encoder_with_aspp=True, enable_decoder=True, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255)
+paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride=16, aspp_with_sep_conv=True, decoder_use_sep_conv=True, encoder_with_aspp=True, enable_decoder=True, use_bce_loss=False, use_dice_loss=False, class_weight=None, ignore_index=255, pooling_crop_size=None)
 ```
@@ -12,7 +12,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
 > **Parameters**
 > > - **num_classes** (int): Number of classes.
-> > - **backbone** (str): Backbone network of DeepLabv3+, used to compute feature maps. Valid values are ['Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0']. Defaults to 'MobileNetV2_x1.0'.
+> > - **backbone** (str): Backbone network of DeepLabv3+, used to compute feature maps. Valid values are ['Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld']. Defaults to 'MobileNetV2_x1.0'.
 > > - **output_stride** (int): Downsampling factor of the backbone output feature map relative to the input, usually 8 or 16. Defaults to 16.
 > > - **aspp_with_sep_conv** (bool): Whether the ASPP module uses separable convolutions. Defaults to True.
 > > - **decoder_use_sep_conv** (bool): Whether the decoder module uses separable convolutions. Defaults to True.
@@ -22,6 +22,7 @@ paddlex.seg.DeepLabv3p(num_classes=2, backbone='MobileNetV2_x1.0', output_stride
 > > - **use_dice_loss** (bool): Whether to use dice loss as the network's loss function. It only applies to two-class segmentation and can be combined with bce loss. When both `use_bce_loss` and `use_dice_loss` are False, cross-entropy loss is used. Defaults to False.
 > > - **class_weight** (list/str): Per-class weights for the cross-entropy loss. When `class_weight` is a list, its length must equal `num_classes`. When `class_weight` is a str, `class_weight.lower()` must be 'dynamic'; the weights are then computed automatically from the per-class pixel proportions in each round, with each class weighted as: class proportion * num_classes. When `class_weight` is left at its default of None, every class has weight 1, which is the ordinary cross-entropy loss.
 > > - **ignore_index** (int): Value to ignore in the label; pixels whose label equals `ignore_index` do not take part in the loss computation. Defaults to 255.
+> > - **pooling_crop_size** (list): When the backbone is `MobileNetV3_large_x1_0_ssld`, this must be set to the model input size used during training, in the format [W, H]. For example, if the model input size is [512, 512], `pooling_crop_size` should be set to [512, 512]. It is used when computing the image-level average in the encoder module: if None, the mean is computed directly; if set to the model input size, the `avg_pool` operator is used to obtain the average. Defaults to None.
 
 ### train
diff --git a/docs/appendix/model_zoo.md b/docs/appendix/model_zoo.md
index c2314bfe64519cd14fe34eea0adbe74dbd4758ee..4c2f3911f2d8afa95bfe0009cf37212d24d43065 100644
--- a/docs/appendix/model_zoo.md
+++ b/docs/appendix/model_zoo.md
@@ -81,6 +81,7 @@
 | Model | Model size | Inference time (ms) | mIoU (%) |
 |:-------|:-----------|:-------------|:----------|
+| [DeepLabv3_MobileNetV3_large_x1_0_ssld](https://paddleseg.bj.bcebos.com/models/deeplabv3p_mobilenetv3_large_cityscapes.tar.gz) | 9.3MB | - | 73.28 |
 | [DeepLabv3_MobileNetv2_x1.0](https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz) | 14.7MB | - | 69.8 |
 | [DeepLabv3_Xception65](https://paddleseg.bj.bcebos.com/models/xception65_bn_cityscapes.tgz) | 329.3MB | - | 79.3 |
 | [HRNet_W18](https://paddleseg.bj.bcebos.com/models/hrnet_w18_bn_cityscapes.tgz) | 77.3MB | | 79.36 |
diff --git a/docs/examples/solutions.md b/docs/examples/solutions.md
index 6d11d243dc4e19923025241dd48fb7d7fa60386b..ed1304c5e2067414790ced1bd01103110f87f619 100644
--- a/docs/examples/solutions.md
+++ b/docs/examples/solutions.md
@@ -80,6 +80,7 @@ PaddleX currently provides the MaskRCNN instance segmentation model, supporting 5 different backbone
 | Model | Characteristics | Storage size | GPU inference speed | CPU (x86) inference speed (ms) | Snapdragon 855 (ARM) inference speed (ms) | mIoU |
 | :---- | :------- | :---------- | :---------- | :----- | :----- |:--- |
 | DeepLabv3p-MobileNetV2_x1.0 | Lightweight model, suited to mobile scenarios | - | - | - | 69.8% |
+| DeepLabv3-MobileNetV3_large_x1_0_ssld | Lightweight model, suited to mobile scenarios | - | - | - | 73.28% |
 | HRNet_W18_Small_v1 | Lightweight and fast, suited to mobile scenarios | - | - | - | - |
 | FastSCNN | Lightweight and fast, suited to mobile or server-side scenarios that demand high-speed inference | - | - | - | 69.64 |
 | HRNet_W18 | High-accuracy model, suited to server-side scenarios that are sensitive to image resolution and demand finer prediction of object details | - | - | - | 79.36 |
diff --git a/docs/train/semantic_segmentation.md b/docs/train/semantic_segmentation.md
index eed540a8051ef52df0b0e695176c217270270a26..2224db4e7d8779e37821574672f91e92b93ab87e 100644
--- a/docs/train/semantic_segmentation.md
+++ b/docs/train/semantic_segmentation.md
@@ -12,6 +12,7 @@ PaddleX currently provides four semantic segmentation architectures: DeepLabv3p, UNet, HRNet and FastSCNN
 | :---------------- | :------- | :------- | :--------- | :--------- | :----- |
 | [DeepLabv3p-MobileNetV2-x0.25](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2_x0.25.py) | - | 2.9MB | - | - | Small model with fast inference, suited to low-performance or mobile devices |
 | [DeepLabv3p-MobileNetV2-x1.0](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv2.py) | 69.8% | 11MB | - | - | Small model with fast inference, suited to low-performance or mobile devices |
+| [DeepLabv3_MobileNetV3_large_x1_0_ssld](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv3_large_ssld.py) | 73.28% | 9.3MB | - | - | Small model with fast inference and relatively high accuracy, suited to low-performance or mobile devices |
 | [DeepLabv3p-Xception65](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/deeplabv3p_xception65.py) | 79.3% | 158MB | - | - | Large model with high accuracy, suited to server-side use |
 | [UNet](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/unet.py) | - | 52MB | - | - | Fairly large model with high accuracy, suited to server-side use |
 | [HRNet](https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/semantic_segmentation/hrnet.py) | 79.4% | 37MB | - | - | Fairly small model with high accuracy, suited to server-side deployment |
diff --git a/paddlex/cv/models/deeplabv3p.py b/paddlex/cv/models/deeplabv3p.py
index a6395f6617c16fb3a51b2fc7d73ad50c326e4858..2c9170c5db5b7213d5149100e53db9afb0368237 100644
--- a/paddlex/cv/models/deeplabv3p.py
+++ b/paddlex/cv/models/deeplabv3p.py
@@ -37,7 +37,7 @@ class DeepLabv3p(BaseAPI):
         num_classes
(int): Number of classes.
         backbone (str): Backbone network of DeepLabv3+, used to compute feature maps. Valid values are ['Xception65', 'Xception41',
             'MobileNetV2_x0.25', 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5',
-            'MobileNetV2_x2.0']. Defaults to 'MobileNetV2_x1.0'.
+            'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld']. Defaults to 'MobileNetV2_x1.0'.
         output_stride (int): Downsampling factor of the backbone output feature map relative to the input, usually 8 or 16. Defaults to 16.
         aspp_with_sep_conv (bool): Whether the ASPP module uses separable convolutions. Defaults to True.
         decoder_use_sep_conv (bool): Whether the decoder module uses separable convolutions. Defaults to True.
@@ -51,10 +51,13 @@ class DeepLabv3p(BaseAPI):
             computed automatically, with each class weighted as: class proportion * num_classes. When class_weight is left at its default of None, every class has weight 1,
             which is the ordinary cross-entropy loss.
         ignore_index (int): Value to ignore in the label; pixels whose label equals ignore_index do not take part in the loss computation. Defaults to 255.
+        pooling_crop_size (list): When the backbone is MobileNetV3_large_x1_0_ssld, this must be set to the model input size used during training, in the format [W, H].
+            It is used when computing the image-level average in the encoder module: if None, the mean is computed directly; if set to the model input size, the 'pool' operator is used to obtain the average.
+            Defaults to None.
     Raises:
         ValueError: use_bce_loss or use_dice_loss is True and num_classes > 2.
         ValueError: backbone is not one of ['Xception65', 'Xception41', 'MobileNetV2_x0.25',
-            'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0'].
+            'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld'].
         ValueError: class_weight is a list but its length does not equal num_classes.
             class_weight is a str but class_weight.lower() is not 'dynamic'.
         TypeError: class_weight is not None and its type is neither list nor str.
@@ -71,7 +74,8 @@ class DeepLabv3p(BaseAPI):
                  use_bce_loss=False,
                  use_dice_loss=False,
                  class_weight=None,
-                 ignore_index=255):
+                 ignore_index=255,
+                 pooling_crop_size=None):
         self.init_params = locals()
         super(DeepLabv3p, self).__init__('segmenter')
         # dice_loss and bce_loss only apply to two-class segmentation
@@ -85,12 +89,12 @@ class DeepLabv3p(BaseAPI):
         if backbone not in [
                 'Xception65', 'Xception41', 'MobileNetV2_x0.25',
                 'MobileNetV2_x0.5', 'MobileNetV2_x1.0', 'MobileNetV2_x1.5',
-                'MobileNetV2_x2.0'
+                'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld'
         ]:
             raise ValueError(
                 "backbone: {} is set wrong. it should be one of "
                 "('Xception65', 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',"
-                " 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0')".
+                " 'MobileNetV2_x1.0', 'MobileNetV2_x1.5', 'MobileNetV2_x2.0', 'MobileNetV3_large_x1_0_ssld')".
                 format(backbone))
 
         if class_weight is not None:
@@ -121,6 +125,30 @@ class DeepLabv3p(BaseAPI):
         self.labels = None
         self.sync_bn = True
         self.fixed_input_shape = None
+        self.pooling_stride = [1, 1]
+        self.pooling_crop_size = pooling_crop_size
+        self.aspp_with_se = False
+        self.se_use_qsigmoid = False
+        self.aspp_convs_filters = 256
+        self.aspp_with_concat_projection = True
+        self.add_image_level_feature = True
+        self.use_sum_merge = False
+        self.conv_filters = 256
+        self.output_is_logits = False
+        self.backbone_lr_mult_list = None
+        if 'MobileNetV3' in backbone:
+            self.output_stride = 32
+            self.pooling_stride = (4, 5)
+            self.aspp_with_se = True
+            self.se_use_qsigmoid = True
+            self.aspp_convs_filters = 128
+            self.aspp_with_concat_projection = False
+            self.add_image_level_feature = False
+            self.use_sum_merge = True
+            self.output_is_logits = True
+            if self.output_is_logits:
+                self.conv_filters = self.num_classes
+            self.backbone_lr_mult_list = [0.15, 0.35, 0.65, 0.85, 1]
 
     def _get_backbone(self, backbone):
         def mobilenetv2(backbone):
@@ -167,10 +195,22 @@ class DeepLabv3p(BaseAPI):
                 end_points=end_points,
                 decode_points=decode_points)
 
+        def mobilenetv3(backbone):
+            scale = 1.0
+            lr_mult_list = self.backbone_lr_mult_list
+            return paddlex.cv.nets.MobileNetV3(
+                scale=scale,
+                model_name='large',
+                output_stride=self.output_stride,
+                lr_mult_list=lr_mult_list,
+                for_seg=True)
+
         if 'Xception' in backbone:
             return xception(backbone)
         elif 'MobileNetV2' in backbone:
             return mobilenetv2(backbone)
+        elif 'MobileNetV3' in backbone:
+            return mobilenetv3(backbone)
 
     def build_net(self, mode='train'):
         model = paddlex.cv.nets.segmentation.DeepLabv3p(
@@ -186,7 +226,17 @@ class DeepLabv3p(BaseAPI):
             use_dice_loss=self.use_dice_loss,
             class_weight=self.class_weight,
             ignore_index=self.ignore_index,
-            fixed_input_shape=self.fixed_input_shape)
+            fixed_input_shape=self.fixed_input_shape,
+            pooling_stride=self.pooling_stride,
+            pooling_crop_size=self.pooling_crop_size,
+            aspp_with_se=self.aspp_with_se,
+            se_use_qsigmoid=self.se_use_qsigmoid,
+            aspp_convs_filters=self.aspp_convs_filters,
+            aspp_with_concat_projection=self.aspp_with_concat_projection,
+            add_image_level_feature=self.add_image_level_feature,
+            use_sum_merge=self.use_sum_merge,
+            conv_filters=self.conv_filters,
+            output_is_logits=self.output_is_logits)
         inputs = model.generate_inputs()
         model_out = model.build_net(inputs)
         outputs = OrderedDict()
diff --git a/paddlex/cv/models/utils/pretrain_weights.py b/paddlex/cv/models/utils/pretrain_weights.py
index adaabc223e39fd5c835d0c6fb75dae263a8801e2..b865cc7fdba1931b51559b1d5d42e9f56b0759a0 100644
--- a/paddlex/cv/models/utils/pretrain_weights.py
+++ b/paddlex/cv/models/utils/pretrain_weights.py
@@ -122,6 +122,8 @@ coco_pretrain = {
 }
 
 cityscapes_pretrain = {
+    'DeepLabv3p_MobileNetV3_large_x1_0_ssld_CITYSCAPES':
+    'https://paddleseg.bj.bcebos.com/models/deeplabv3p_mobilenetv3_large_cityscapes.tar.gz',
     'DeepLabv3p_MobileNetV2_x1.0_CITYSCAPES':
     'https://paddleseg.bj.bcebos.com/models/mobilenet_cityscapes.tgz',
     'DeepLabv3p_Xception65_CITYSCAPES':
@@ -167,7 +169,8 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
             flag = 'IMAGENET'
         if class_name == 'DeepLabv3p' and backbone in [
                 'Xception41', 'MobileNetV2_x0.25', 'MobileNetV2_x0.5',
-                'MobileNetV2_x1.5', 'MobileNetV2_x2.0'
+                'MobileNetV2_x1.5', 'MobileNetV2_x2.0',
+                'MobileNetV3_large_x1_0_ssld'
         ]:
             model_name = '{}_{}'.format(class_name, backbone)
             logging.warning(warning_info.format(model_name, flag, 'IMAGENET'))
diff --git a/paddlex/cv/nets/mobilenet_v3.py b/paddlex/cv/nets/mobilenet_v3.py
index 6adcee03d7bb9c5ffab0ceb7198083e3534e7ab9..750692b6a5fb92537a44a485d56533e001eb2ca8 100644
--- a/paddlex/cv/nets/mobilenet_v3.py
+++
b/paddlex/cv/nets/mobilenet_v3.py
@@ -42,7 +42,9 @@ class MobileNetV3():
                  extra_block_filters=[[256, 512], [128, 256],
                                       [128, 256], [64, 128]],
                  num_classes=None,
-                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0]):
+                 lr_mult_list=[1.0, 1.0, 1.0, 1.0, 1.0],
+                 for_seg=False,
+                 output_stride=None):
         assert len(lr_mult_list) == 5, \
             "lr_mult_list length in MobileNetV3 must be 5 but got {}!!".format(
             len(lr_mult_list))
@@ -57,48 +59,111 @@ class MobileNetV3():
         self.num_classes = num_classes
         self.lr_mult_list = lr_mult_list
         self.curr_stage = 0
-        if model_name == "large":
-            self.cfg = [
-                # kernel_size, expand, channel, se_block, act_mode, stride
-                [3, 16, 16, False, 'relu', 1],
-                [3, 64, 24, False, 'relu', 2],
-                [3, 72, 24, False, 'relu', 1],
-                [5, 72, 40, True, 'relu', 2],
-                [5, 120, 40, True, 'relu', 1],
-                [5, 120, 40, True, 'relu', 1],
-                [3, 240, 80, False, 'hard_swish', 2],
-                [3, 200, 80, False, 'hard_swish', 1],
-                [3, 184, 80, False, 'hard_swish', 1],
-                [3, 184, 80, False, 'hard_swish', 1],
-                [3, 480, 112, True, 'hard_swish', 1],
-                [3, 672, 112, True, 'hard_swish', 1],
-                [5, 672, 160, True, 'hard_swish', 2],
-                [5, 960, 160, True, 'hard_swish', 1],
-                [5, 960, 160, True, 'hard_swish', 1],
-            ]
-            self.cls_ch_squeeze = 960
-            self.cls_ch_expand = 1280
-            self.lr_interval = 3
-        elif model_name == "small":
-            self.cfg = [
-                # kernel_size, expand, channel, se_block, act_mode, stride
-                [3, 16, 16, True, 'relu', 2],
-                [3, 72, 24, False, 'relu', 2],
-                [3, 88, 24, False, 'relu', 1],
-                [5, 96, 40, True, 'hard_swish', 2],
-                [5, 240, 40, True, 'hard_swish', 1],
-                [5, 240, 40, True, 'hard_swish', 1],
-                [5, 120, 48, True, 'hard_swish', 1],
-                [5, 144, 48, True, 'hard_swish', 1],
-                [5, 288, 96, True, 'hard_swish', 2],
-                [5, 576, 96, True, 'hard_swish', 1],
-                [5, 576, 96, True, 'hard_swish', 1],
-            ]
-            self.cls_ch_squeeze = 576
-            self.cls_ch_expand = 1280
-            self.lr_interval = 2
+        self.for_seg = for_seg
+        self.decode_point = None
+
+        if self.for_seg:
+            if model_name == "large":
+                self.cfg = [
+                    # k, exp, c, se, nl, s,
+                    [3, 16, 16, False, 'relu', 1],
+                    [3, 64, 24, False, 'relu', 2],
+                    [3, 72, 24, False, 'relu', 1],
+                    [5, 72, 40, True, 'relu', 2],
+                    [5, 120, 40, True, 'relu', 1],
+                    [5, 120, 40, True, 'relu', 1],
+                    [3, 240, 80, False, 'hard_swish', 2],
+                    [3, 200, 80, False, 'hard_swish', 1],
+                    [3, 184, 80, False, 'hard_swish', 1],
+                    [3, 184, 80, False, 'hard_swish', 1],
+                    [3, 480, 112, True, 'hard_swish', 1],
+                    [3, 672, 112, True, 'hard_swish', 1],
+                    # The number of channels in the last 4 stages is reduced by a
+                    # factor of 2 compared to the standard implementation.
+                    [5, 336, 80, True, 'hard_swish', 2],
+                    [5, 480, 80, True, 'hard_swish', 1],
+                    [5, 480, 80, True, 'hard_swish', 1],
+                ]
+                self.cls_ch_squeeze = 480
+                self.cls_ch_expand = 1280
+                self.lr_interval = 3
+            elif model_name == "small":
+                self.cfg = [
+                    # k, exp, c, se, nl, s,
+                    [3, 16, 16, True, 'relu', 2],
+                    [3, 72, 24, False, 'relu', 2],
+                    [3, 88, 24, False, 'relu', 1],
+                    [5, 96, 40, True, 'hard_swish', 2],
+                    [5, 240, 40, True, 'hard_swish', 1],
+                    [5, 240, 40, True, 'hard_swish', 1],
+                    [5, 120, 48, True, 'hard_swish', 1],
+                    [5, 144, 48, True, 'hard_swish', 1],
+                    # The number of channels in the last 4 stages is reduced by a
+                    # factor of 2 compared to the standard implementation.
+                    [5, 144, 48, True, 'hard_swish', 2],
+                    [5, 288, 48, True, 'hard_swish', 1],
+                    [5, 288, 48, True, 'hard_swish', 1],
+                ]
+            else:
+                raise NotImplementedError
         else:
-            raise NotImplementedError
+            if model_name == "large":
+                self.cfg = [
+                    # kernel_size, expand, channel, se_block, act_mode, stride
+                    [3, 16, 16, False, 'relu', 1],
+                    [3, 64, 24, False, 'relu', 2],
+                    [3, 72, 24, False, 'relu', 1],
+                    [5, 72, 40, True, 'relu', 2],
+                    [5, 120, 40, True, 'relu', 1],
+                    [5, 120, 40, True, 'relu', 1],
+                    [3, 240, 80, False, 'hard_swish', 2],
+                    [3, 200, 80, False, 'hard_swish', 1],
+                    [3, 184, 80, False, 'hard_swish', 1],
+                    [3, 184, 80, False, 'hard_swish', 1],
+                    [3, 480, 112, True, 'hard_swish', 1],
+                    [3, 672, 112, True, 'hard_swish', 1],
+                    [5, 672, 160, True, 'hard_swish', 2],
+                    [5, 960, 160, True, 'hard_swish', 1],
+                    [5, 960, 160, True, 'hard_swish', 1],
+                ]
+                self.cls_ch_squeeze = 960
+                self.cls_ch_expand = 1280
+                self.lr_interval = 3
+            elif model_name == "small":
+                self.cfg = [
+                    # kernel_size, expand, channel, se_block, act_mode, stride
+                    [3, 16, 16, True, 'relu', 2],
+                    [3, 72, 24, False, 'relu', 2],
+                    [3, 88, 24, False, 'relu', 1],
+                    [5, 96, 40, True, 'hard_swish', 2],
+                    [5, 240, 40, True, 'hard_swish', 1],
+                    [5, 240, 40, True, 'hard_swish', 1],
+                    [5, 120, 48, True, 'hard_swish', 1],
+                    [5, 144, 48, True, 'hard_swish', 1],
+                    [5, 288, 96, True, 'hard_swish', 2],
+                    [5, 576, 96, True, 'hard_swish', 1],
+                    [5, 576, 96, True, 'hard_swish', 1],
+                ]
+                self.cls_ch_squeeze = 576
+                self.cls_ch_expand = 1280
+                self.lr_interval = 2
+            else:
+                raise NotImplementedError
+
+        self.modify_bottle_params(output_stride)
+
+    def modify_bottle_params(self, output_stride=None):
+        if output_stride is not None and output_stride % 2 != 0:
+            raise Exception("output stride must be an even number")
+        if output_stride is None:
+            return
+        else:
+            stride = 2
+            for i, _cfg in enumerate(self.cfg):
+                stride = stride * _cfg[-1]
+                if stride > output_stride:
+                    s = 1
+                    self.cfg[i][-1] = s
 
     def _conv_bn_layer(self,
                        input,
@@ -153,6 +218,14 @@ class MobileNetV3():
             bn = fluid.layers.relu6(bn)
         return bn
 
+    def make_divisible(self, v, divisor=8, min_value=None):
+        if min_value is None:
+            min_value = divisor
+        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+        if new_v < 0.9 * v:
+            new_v += divisor
+        return new_v
+
     def _hard_swish(self, x):
         return x * fluid.layers.relu6(x + 3) / 6.
 
@@ -220,6 +293,9 @@ class MobileNetV3():
             use_cudnn=False,
             name=name + '_depthwise')
 
+        if self.curr_stage == 5:
+            self.decode_point = conv1
+
         if use_se:
             conv1 = self._se_block(
                 input=conv1, num_out_filter=num_mid_filter, name=name + '_se')
@@ -282,7 +358,7 @@ class MobileNetV3():
         conv = self._conv_bn_layer(
             input,
             filter_size=3,
-            num_filters=inplanes if scale <= 1.0 else int(inplanes * scale),
+            num_filters=self.make_divisible(inplanes * scale),
             stride=2,
             padding=1,
             num_groups=1,
@@ -290,6 +366,7 @@ class MobileNetV3():
             act='hard_swish',
             name='conv1')
         i = 0
+        inplanes = self.make_divisible(inplanes * scale)
         for layer_cfg in cfg:
             self.block_stride *= layer_cfg[5]
             if layer_cfg[5] == 2:
@@ -297,19 +374,32 @@ class MobileNetV3():
             conv = self._residual_unit(
                 input=conv,
                 num_in_filter=inplanes,
-                num_mid_filter=int(scale * layer_cfg[1]),
-                num_out_filter=int(scale * layer_cfg[2]),
+                num_mid_filter=self.make_divisible(scale * layer_cfg[1]),
+                num_out_filter=self.make_divisible(scale * layer_cfg[2]),
                 act=layer_cfg[4],
                 stride=layer_cfg[5],
                 filter_size=layer_cfg[0],
                 use_se=layer_cfg[3],
                 name='conv' + str(i + 2))
-
-            inplanes = int(scale * layer_cfg[2])
+            inplanes = self.make_divisible(scale * layer_cfg[2])
             i += 1
             self.curr_stage = i
         blocks.append(conv)
 
+        if self.for_seg:
+            conv = self._conv_bn_layer(
+                input=conv,
+                filter_size=1,
+                num_filters=self.make_divisible(scale * self.cls_ch_squeeze),
+                stride=1,
+                padding=0,
+                num_groups=1,
+                if_act=True,
+                act='hard_swish',
+                name='conv_last')
+
+            return conv, self.decode_point
+
         if self.num_classes:
             conv = self._conv_bn_layer(
                 input=conv,
diff --git
a/paddlex/cv/nets/segmentation/deeplabv3p.py b/paddlex/cv/nets/segmentation/deeplabv3p.py
index c568a8cd9c44985f1d9defbfddd7db39f298ec68..7d597a606a88a78513452c37357b806c4dfa156f 100644
--- a/paddlex/cv/nets/segmentation/deeplabv3p.py
+++ b/paddlex/cv/nets/segmentation/deeplabv3p.py
@@ -21,7 +21,7 @@ from collections import OrderedDict
 
 import paddle.fluid as fluid
 from .model_utils.libs import scope, name_scope
-from .model_utils.libs import bn, bn_relu, relu
+from .model_utils.libs import bn, bn_relu, relu, qsigmoid
 from .model_utils.libs import conv, max_pool, deconv
 from .model_utils.libs import separate_conv
 from .model_utils.libs import sigmoid_to_softmax
@@ -82,7 +82,17 @@ class DeepLabv3p(object):
                  use_dice_loss=False,
                  class_weight=None,
                  ignore_index=255,
-                 fixed_input_shape=None):
+                 fixed_input_shape=None,
+                 pooling_stride=[1, 1],
+                 pooling_crop_size=None,
+                 aspp_with_se=False,
+                 se_use_qsigmoid=False,
+                 aspp_convs_filters=256,
+                 aspp_with_concat_projection=True,
+                 add_image_level_feature=True,
+                 use_sum_merge=False,
+                 conv_filters=256,
+                 output_is_logits=False):
         # dice_loss and bce_loss only apply to two-class segmentation
         if num_classes > 2 and (use_bce_loss or use_dice_loss):
             raise ValueError(
@@ -117,6 +127,17 @@ class DeepLabv3p(object):
         self.encoder_with_aspp = encoder_with_aspp
         self.enable_decoder = enable_decoder
         self.fixed_input_shape = fixed_input_shape
+        self.output_is_logits = output_is_logits
+        self.aspp_convs_filters = aspp_convs_filters
+        self.output_stride = output_stride
+        self.pooling_crop_size = pooling_crop_size
+        self.pooling_stride = pooling_stride
+        self.se_use_qsigmoid = se_use_qsigmoid
+        self.aspp_with_concat_projection = aspp_with_concat_projection
+        self.add_image_level_feature = add_image_level_feature
+        self.aspp_with_se = aspp_with_se
+        self.use_sum_merge = use_sum_merge
+        self.conv_filters = conv_filters
 
     def _encoder(self, input):
         # Encoder configuration: ASPP architecture with pooling + 1x1_conv + three parallel atrous convolutions at different scales, then a 1x1 conv after concat
@@ -129,19 +150,36 @@ class DeepLabv3p(object):
         elif self.output_stride == 8:
             aspp_ratios = [12, 24, 36]
         else:
-            raise Exception("DeepLabv3p only support stride 8 or 16")
+            aspp_ratios = []
 
         param_attr = fluid.ParamAttr(
             name=name_scope + 'weights',
             regularizer=None,
             initializer=fluid.initializer.TruncatedNormal(
                 loc=0.0, scale=0.06))
+
+        concat_logits = []
         with scope('encoder'):
-            channel = 256
+            channel = self.aspp_convs_filters
             with scope("image_pool"):
-                image_avg = fluid.layers.reduce_mean(
-                    input, [2, 3], keep_dim=True)
-                image_avg = bn_relu(
+                if self.pooling_crop_size is None:
+                    image_avg = fluid.layers.reduce_mean(
+                        input, [2, 3], keep_dim=True)
+                else:
+                    pool_w = int((self.pooling_crop_size[0] - 1.0) /
+                                 self.output_stride + 1.0)
+                    pool_h = int((self.pooling_crop_size[1] - 1.0) /
+                                 self.output_stride + 1.0)
+                    image_avg = fluid.layers.pool2d(
+                        input,
+                        pool_size=(pool_h, pool_w),
+                        pool_stride=self.pooling_stride,
+                        pool_type='avg',
+                        pool_padding='VALID')
+
+                act = qsigmoid if self.se_use_qsigmoid else bn_relu
+
+                image_avg = act(
                     conv(
                         image_avg,
                         channel,
@@ -153,6 +191,8 @@ class DeepLabv3p(object):
                 input_shape = fluid.layers.shape(input)
                 image_avg = fluid.layers.resize_bilinear(image_avg,
                                                          input_shape[2:])
+                if self.add_image_level_feature:
+                    concat_logits.append(image_avg)
 
             with scope("aspp0"):
                 aspp0 = bn_relu(
@@ -164,77 +204,160 @@ class DeepLabv3p(object):
                         groups=1,
                         padding=0,
                         param_attr=param_attr))
-            with scope("aspp1"):
-                if self.aspp_with_sep_conv:
-                    aspp1 = separate_conv(
-                        input,
-                        channel,
-                        1,
-                        3,
-                        dilation=aspp_ratios[0],
-                        act=relu)
-                else:
-                    aspp1 = bn_relu(
-                        conv(
+                concat_logits.append(aspp0)
+
+            if aspp_ratios:
+                with scope("aspp1"):
+                    if self.aspp_with_sep_conv:
+                        aspp1 = separate_conv(
                             input,
                             channel,
-                            stride=1,
-                            filter_size=3,
+                            1,
+                            3,
                             dilation=aspp_ratios[0],
-                            padding=aspp_ratios[0],
-                            param_attr=param_attr))
-            with scope("aspp2"):
-                if self.aspp_with_sep_conv:
-                    aspp2 = separate_conv(
-                        input,
-                        channel,
-                        1,
-                        3,
-                        dilation=aspp_ratios[1],
-                        act=relu)
-                else:
-                    aspp2 = bn_relu(
-                        conv(
+                            act=relu)
+                    else:
+                        aspp1 = bn_relu(
+                            conv(
+                                input,
+                                channel,
+                                stride=1,
+                                filter_size=3,
+                                dilation=aspp_ratios[0],
+                                padding=aspp_ratios[0],
+                                param_attr=param_attr))
+                    concat_logits.append(aspp1)
+                with scope("aspp2"):
+                    if self.aspp_with_sep_conv:
+                        aspp2 = separate_conv(
                             input,
                             channel,
-                            stride=1,
-                            filter_size=3,
+                            1,
+                            3,
                             dilation=aspp_ratios[1],
-                            padding=aspp_ratios[1],
-                            param_attr=param_attr))
-            with scope("aspp3"):
-                if self.aspp_with_sep_conv:
-                    aspp3 = separate_conv(
-                        input,
-                        channel,
-                        1,
-                        3,
-                        dilation=aspp_ratios[2],
-                        act=relu)
-                else:
-                    aspp3 = bn_relu(
-                        conv(
+                            act=relu)
+                    else:
+                        aspp2 = bn_relu(
+                            conv(
+                                input,
+                                channel,
+                                stride=1,
+                                filter_size=3,
+                                dilation=aspp_ratios[1],
+                                padding=aspp_ratios[1],
+                                param_attr=param_attr))
+                    concat_logits.append(aspp2)
+                with scope("aspp3"):
+                    if self.aspp_with_sep_conv:
+                        aspp3 = separate_conv(
                             input,
                             channel,
-                            stride=1,
-                            filter_size=3,
+                            1,
+                            3,
                             dilation=aspp_ratios[2],
-                            padding=aspp_ratios[2],
-                            param_attr=param_attr))
+                            act=relu)
+                    else:
+                        aspp3 = bn_relu(
+                            conv(
+                                input,
+                                channel,
+                                stride=1,
+                                filter_size=3,
+                                dilation=aspp_ratios[2],
+                                padding=aspp_ratios[2],
+                                param_attr=param_attr))
+                    concat_logits.append(aspp3)
+
             with scope("concat"):
-                data = fluid.layers.concat(
-                    [image_avg, aspp0, aspp1, aspp2, aspp3], axis=1)
-                data = bn_relu(
+                data = fluid.layers.concat(concat_logits, axis=1)
+                if self.aspp_with_concat_projection:
+                    data = bn_relu(
+                        conv(
+                            data,
+                            channel,
+                            1,
+                            1,
+                            groups=1,
+                            padding=0,
+                            param_attr=param_attr))
+                    data = fluid.layers.dropout(data, 0.9)
+            if self.aspp_with_se:
+                data = data * image_avg
+            return data
+
+    def _decoder_with_sum_merge(self, encode_data, decode_shortcut,
+                                param_attr):
+        decode_shortcut_shape = fluid.layers.shape(decode_shortcut)
+        encode_data = fluid.layers.resize_bilinear(encode_data,
+                                                   decode_shortcut_shape[2:])
+
+        encode_data = conv(
+            encode_data,
+            self.conv_filters,
+            1,
+            1,
+            groups=1,
+            padding=0,
+            param_attr=param_attr)
+
+        with scope('merge'):
+            decode_shortcut = conv(
+                decode_shortcut,
+                self.conv_filters,
+                1,
+                1,
+                groups=1,
+                padding=0,
+                param_attr=param_attr)
+
+            return encode_data + decode_shortcut
+
+    def _decoder_with_concat(self, encode_data, decode_shortcut, param_attr):
+        with scope('concat'):
+            decode_shortcut = bn_relu(
+                conv(
+                    decode_shortcut,
+                    48,
+                    1,
+                    1,
+                    groups=1,
+                    padding=0,
+                    param_attr=param_attr))
+
+            decode_shortcut_shape = fluid.layers.shape(decode_shortcut)
+            encode_data = fluid.layers.resize_bilinear(
+                encode_data, decode_shortcut_shape[2:])
+            encode_data = fluid.layers.concat(
+                [encode_data, decode_shortcut], axis=1)
+        if self.decoder_use_sep_conv:
+            with scope("separable_conv1"):
+                encode_data = separate_conv(
+                    encode_data, self.conv_filters, 1, 3, dilation=1, act=relu)
+            with scope("separable_conv2"):
+                encode_data = separate_conv(
+                    encode_data, self.conv_filters, 1, 3, dilation=1, act=relu)
+        else:
+            with scope("decoder_conv1"):
+                encode_data = bn_relu(
                     conv(
-                        data,
-                        channel,
-                        1,
-                        1,
-                        groups=1,
-                        padding=0,
+                        encode_data,
+                        self.conv_filters,
+                        stride=1,
+                        filter_size=3,
+                        dilation=1,
+                        padding=1,
                         param_attr=param_attr))
-                data = fluid.layers.dropout(data, 0.9)
-            return data
+            with scope("decoder_conv2"):
+                encode_data = bn_relu(
+                    conv(
+                        encode_data,
+                        self.conv_filters,
+                        stride=1,
+                        filter_size=3,
+                        dilation=1,
+                        padding=1,
+                        param_attr=param_attr))
+        return encode_data
 
     def _decoder(self, encode_data, decode_shortcut):
         # Decoder configuration
@@ -246,52 +369,14 @@ class DeepLabv3p(object):
             regularizer=None,
             initializer=fluid.initializer.TruncatedNormal(
                 loc=0.0, scale=0.06))
+
         with scope('decoder'):
-            with scope('concat'):
-                decode_shortcut = bn_relu(
-                    conv(
-                        decode_shortcut,
-                        48,
-                        1,
-                        1,
-                        groups=1,
-                        padding=0,
-                        param_attr=param_attr))
+            if self.use_sum_merge:
+                return self._decoder_with_sum_merge(
+                    encode_data, decode_shortcut, param_attr)
 
-                decode_shortcut_shape = fluid.layers.shape(decode_shortcut)
-                encode_data = fluid.layers.resize_bilinear(
-                    encode_data, decode_shortcut_shape[2:])
-                encode_data = fluid.layers.concat(
-                    [encode_data, decode_shortcut], axis=1)
-            if self.decoder_use_sep_conv:
-                with scope("separable_conv1"):
-                    encode_data = separate_conv(
-                        encode_data, 256, 1, 3, dilation=1, act=relu)
-                with scope("separable_conv2"):
-                    encode_data = separate_conv(
-                        encode_data, 256, 1, 3, dilation=1, act=relu)
-            else:
-                with scope("decoder_conv1"):
-                    encode_data = bn_relu(
-                        conv(
-                            encode_data,
-                            256,
-                            stride=1,
-                            filter_size=3,
-                            dilation=1,
-                            padding=1,
-                            param_attr=param_attr))
-                with scope("decoder_conv2"):
-                    encode_data = bn_relu(
-                        conv(
-                            encode_data,
-                            256,
-                            stride=1,
-                            filter_size=3,
-                            dilation=1,
-                            padding=1,
-                            param_attr=param_attr))
-            return encode_data
+            return self._decoder_with_concat(encode_data, decode_shortcut,
+                                             param_attr)
 
     def _get_loss(self, logit, label, mask):
         avg_loss = 0
@@ -335,8 +420,11 @@ class DeepLabv3p(object):
             self.num_classes = 1
 
         image = inputs['image']
-        data, decode_shortcuts = self.backbone(image)
-        decode_shortcut = decode_shortcuts[self.backbone.decode_points]
+        if 'MobileNetV3' in self.backbone.__class__.__name__:
+            data, decode_shortcut = self.backbone(image)
+        else:
+            data, decode_shortcuts = self.backbone(image)
+            decode_shortcut = decode_shortcuts[self.backbone.decode_points]
 
         # Encoder and decoder setup
         if self.encoder_with_aspp:
@@ -351,18 +439,22 @@ class DeepLabv3p(object):
                 regularization_coeff=0.0),
             initializer=fluid.initializer.TruncatedNormal(
                 loc=0.0, scale=0.01))
-        with scope('logit'):
-            with fluid.name_scope('last_conv'):
-                logit = conv(
-                    data,
-                    self.num_classes,
-                    1,
-                    stride=1,
-                    padding=0,
-                    bias_attr=True,
-                    param_attr=param_attr)
-            image_shape = fluid.layers.shape(image)
-            logit = fluid.layers.resize_bilinear(logit, image_shape[2:])
+        if not self.output_is_logits:
+            with scope('logit'):
+                with fluid.name_scope('last_conv'):
+                    logit = conv(
+                        data,
+                        self.num_classes,
+                        1,
+                        stride=1,
+                        padding=0,
+                        bias_attr=True,
+                        param_attr=param_attr)
+        else:
+            logit = data
+
+        image_shape = fluid.layers.shape(image)
+        logit = fluid.layers.resize_bilinear(logit, image_shape[2:])
 
         if self.num_classes == 1:
             out = sigmoid_to_softmax(logit)
diff --git a/paddlex/cv/nets/segmentation/model_utils/libs.py b/paddlex/cv/nets/segmentation/model_utils/libs.py
index a0eb9c639d5cb79f7962c6b7376d51be3bd57f8b..68ddd35beff56697135d4a8b3ffb1862426ca07d 100644
--- a/paddlex/cv/nets/segmentation/model_utils/libs.py
+++ b/paddlex/cv/nets/segmentation/model_utils/libs.py
@@ -112,6 +112,10 @@ def bn_relu(data, norm_type='bn', eps=1e-5):
     return fluid.layers.relu(bn(data, norm_type=norm_type, eps=eps))
 
 
+def qsigmoid(data):
+    return fluid.layers.relu6(data + 3) * 0.16667
+
+
 def relu(data):
     return fluid.layers.relu(data)
 
diff --git a/tutorials/train/README.md b/tutorials/train/README.md
index 637be22374a591e6abe90c5627fe55fc509574f2..e480378a19c75cb622d149ca1da89be7d85baf84 100644
--- a/tutorials/train/README.md
+++ b/tutorials/train/README.md
@@ -20,6 +20,7 @@
 |instance_segmentation/mask_rcnn_r18_fpn.py | Instance segmentation MaskRCNN | Xiaoduxiong sorting |
 |instance_segmentation/mask_rcnn_f50_fpn.py | Instance segmentation MaskRCNN | Xiaoduxiong sorting |
 |semantic_segmentation/deeplabv3p_mobilenetv2.py | Semantic segmentation DeepLabV3 | Optic disc segmentation |
+|semantic_segmentation/deeplabv3p_mobilenetv3_large_ssld.py | Semantic segmentation DeepLabV3 | Optic disc segmentation |
 |semantic_segmentation/deeplabv3p_mobilenetv2_x0.25.py | Semantic segmentation DeepLabV3 | Optic disc segmentation |
 |semantic_segmentation/deeplabv3p_xception65.py | Semantic segmentation DeepLabV3 | Optic disc segmentation |
 |semantic_segmentation/fast_scnn.py | Semantic segmentation FastSCNN | Optic disc segmentation |
diff --git a/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv3_large_ssld.py b/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv3_large_ssld.py
new file mode 100644
index 0000000000000000000000000000000000000000..9be782cde1394115feea973eea483b4bc2b24ea0
--- /dev/null
+++ b/tutorials/train/semantic_segmentation/deeplabv3p_mobilenetv3_large_ssld.py
@@ -0,0 +1,58 @@
+# Environment variable configuration, used to control whether the GPU is used
+# Documentation: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+import paddlex as pdx
+from paddlex.seg import transforms
+
+# Download and decompress the optic disc segmentation dataset
+optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz'
+pdx.utils.download_and_decompress(optic_dataset, path='./')
+
+# Define the transforms used for training and validation
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/seg_transforms.html
+train_transforms = transforms.Compose([
+    transforms.RandomHorizontalFlip(), transforms.ResizeRangeScaling(),
+    transforms.RandomPaddingCrop(crop_size=512), transforms.Normalize()
+])
+
+eval_transforms = transforms.Compose([
+    transforms.ResizeByLong(long_size=512),
+    transforms.Padding(target_size=512), transforms.Normalize()
+])
+
+# Define the datasets used for training and validation
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-segdataset
+train_dataset = pdx.datasets.SegDataset(
+    data_dir='optic_disc_seg',
+    file_list='optic_disc_seg/train_list.txt',
+    label_list='optic_disc_seg/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.SegDataset(
+    data_dir='optic_disc_seg',
+    file_list='optic_disc_seg/val_list.txt',
+    label_list='optic_disc_seg/labels.txt',
+    transforms=eval_transforms)
+
+# Initialize the model and train it
+# Training metrics can be viewed with VisualDL, see https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
+num_classes = len(train_dataset.labels)
+
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#paddlex-seg-deeplabv3p
+model = pdx.seg.DeepLabv3p(
+    num_classes=num_classes,
+    backbone='MobileNetV3_large_x1_0_ssld',
+    pooling_crop_size=(512, 512))
+
+# API reference: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/semantic_segmentation.html#train
+# Parameter descriptions and tuning notes: https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
+model.train(
+    num_epochs=40,
+    train_dataset=train_dataset,
+    train_batch_size=4,
+    eval_dataset=eval_dataset,
+    learning_rate=0.01,
+    save_dir='output/deeplabv3p_mobilenetv3_large_ssld',
+    use_vdl=True)
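
Reviewer note: the shape arithmetic introduced by this patch can be sanity-checked without Paddle installed. The following is a minimal standalone sketch, not code from the patch. `make_divisible` and the stride-capping loop are re-typed from the `mobilenet_v3.py` hunks as free functions. The names `cap_strides`, `aspp_pool_size`, `qsigmoid_scalar` and `LARGE_STRIDES` are hypothetical and exist only for illustration. `qsigmoid_scalar` is a plain-float stand-in for the tensor expression `fluid.layers.relu6(data + 3) * 0.16667`.

```python
def make_divisible(v, divisor=8, min_value=None):
    """Round v to a multiple of divisor, bumping up if that loses more than 10%."""
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


def cap_strides(strides, output_stride):
    """Mirror modify_bottle_params on a plain list of per-block strides (assumed helper)."""
    if output_stride is None:
        return list(strides)
    if output_stride % 2 != 0:
        raise ValueError("output stride must be an even number")
    capped = list(strides)
    total = 2  # the first convolution (conv1) already uses stride 2
    for i, s in enumerate(strides):
        total *= s  # accumulate with the original stride, as the patch does
        if total > output_stride:
            capped[i] = 1
    return capped


def aspp_pool_size(pooling_crop_size, output_stride):
    """Return (pool_h, pool_w) as computed in the encoder; crop size is [W, H]."""
    pool_w = int((pooling_crop_size[0] - 1.0) / output_stride + 1.0)
    pool_h = int((pooling_crop_size[1] - 1.0) / output_stride + 1.0)
    return pool_h, pool_w


def qsigmoid_scalar(x):
    """Scalar stand-in for relu6(x + 3) * 0.16667 (assumed helper)."""
    return min(max(x + 3.0, 0.0), 6.0) * 0.16667


# Last column (stride) of the 'large' configuration rows in the patch.
LARGE_STRIDES = [1, 2, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1]

if __name__ == "__main__":
    print(make_divisible(16), make_divisible(10))
    print(cap_strides(LARGE_STRIDES, 16))
    print(aspp_pool_size([512, 512], 32))
    print(qsigmoid_scalar(0.0))
```

With the tutorial's `pooling_crop_size=(512, 512)` and the `output_stride` of 32 that the patch sets for MobileNetV3 backbones, the pooling window works out to 16x16. Requesting an output stride of 16 on the large configuration turns the last stride-2 block into a stride-1 block, while an output stride of 32 leaves the configuration unchanged.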