diff --git a/dygraph/__init__.py b/dygraph/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9c52eaa7ed1463ca40036bd959610b0d1fd80fea --- /dev/null +++ b/dygraph/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dygraph.models \ No newline at end of file diff --git a/dygraph/models/__init__.py b/dygraph/models/__init__.py index 62af2a91fa22edd69ed9aeae9de33958bd810959..750e77ac3209f2ea6fefbcaeae0ae0cf6426cd94 100644 --- a/dygraph/models/__init__.py +++ b/dygraph/models/__init__.py @@ -12,36 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from .architectures import * from .unet import UNet from .hrnet import * from .deeplab import * -MODELS = { - "UNet": UNet, - "HRNet_W18_Small_V1": HRNet_W18_Small_V1, - "HRNet_W18_Small_V2": HRNet_W18_Small_V2, - "HRNet_W18": HRNet_W18, - "HRNet_W30": HRNet_W30, - "HRNet_W32": HRNet_W32, - "HRNet_W40": HRNet_W40, - "HRNet_W44": HRNet_W44, - "HRNet_W48": HRNet_W48, - "HRNet_W60": HRNet_W48, - "HRNet_W64": HRNet_W64, - "SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1, - "SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2, - "SE_HRNet_W18": SE_HRNet_W18, - "SE_HRNet_W30": SE_HRNet_W30, - "SE_HRNet_W32": SE_HRNet_W30, - "SE_HRNet_W40": SE_HRNet_W40, - "SE_HRNet_W44": SE_HRNet_W44, - "SE_HRNet_W48": SE_HRNet_W48, - "SE_HRNet_W60": SE_HRNet_W60, - "SE_HRNet_W64": SE_HRNet_W64, - "DeepLabV3P": DeepLabV3P, - "deeplabv3p_resnet101_vd": deeplabv3p_resnet101_vd, - "deeplabv3p_resnet101_vd_os8": deeplabv3p_resnet101_vd_os8, - "deeplabv3p_resnet50_vd": deeplabv3p_resnet50_vd, - "deeplabv3p_resnet50_vd_os8": deeplabv3p_resnet50_vd_os8, - "deeplabv3p_xception65_deeplab": deeplabv3p_xception65_deeplab -} +# MODELS = { +# "UNet": UNet, +# "HRNet_W18_Small_V1": HRNet_W18_Small_V1, +# "HRNet_W18_Small_V2": HRNet_W18_Small_V2, +# "HRNet_W18": HRNet_W18, +# "HRNet_W30": HRNet_W30, +# "HRNet_W32": HRNet_W32, +# "HRNet_W40": HRNet_W40, +# "HRNet_W44": HRNet_W44, +# "HRNet_W48": HRNet_W48, +# "HRNet_W60": HRNet_W48, +# "HRNet_W64": HRNet_W64, +# "SE_HRNet_W18_Small_V1": SE_HRNet_W18_Small_V1, +# "SE_HRNet_W18_Small_V2": SE_HRNet_W18_Small_V2, +# "SE_HRNet_W18": SE_HRNet_W18, +# "SE_HRNet_W30": SE_HRNet_W30, +# "SE_HRNet_W32": SE_HRNet_W30, +# "SE_HRNet_W40": SE_HRNet_W40, +# "SE_HRNet_W44": SE_HRNet_W44, +# "SE_HRNet_W48": SE_HRNet_W48, +# "SE_HRNet_W60": SE_HRNet_W60, +# "SE_HRNet_W64": SE_HRNet_W64, +# "DeepLabV3P": DeepLabV3P, +# "deeplabv3p_resnet101_vd": deeplabv3p_resnet101_vd, +# "deeplabv3p_resnet101_vd_os8": deeplabv3p_resnet101_vd_os8, +# "deeplabv3p_resnet50_vd": 
deeplabv3p_resnet50_vd,
+#     "deeplabv3p_resnet50_vd_os8": deeplabv3p_resnet50_vd_os8,
+#     "deeplabv3p_xception65_deeplab": deeplabv3p_xception65_deeplab,
+#     "deeplabv3p_mobilenetv3_large": deeplabv3p_mobilenetv3_large,
+#     "deeplabv3p_mobilenetv3_small": deeplabv3p_mobilenetv3_small
+# }
diff --git a/dygraph/models/architectures/mobilenetv3.py b/dygraph/models/architectures/mobilenetv3.py
new file mode 100644
index 0000000000000000000000000000000000000000..91aa0563ebbca62284f399ffa37100bcca08042c
--- /dev/null
+++ b/dygraph/models/architectures/mobilenetv3.py
@@ -0,0 +1,492 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.layer_helper import LayerHelper
+from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout
+
+import math
+
+from dygraph.cvlibs import manager
+
+__all__ = [
+    "MobileNetV3_small_x0_35", "MobileNetV3_small_x0_5",
+    "MobileNetV3_small_x0_75", "MobileNetV3_small_x1_0",
+    "MobileNetV3_small_x1_25", "MobileNetV3_large_x0_35",
+    "MobileNetV3_large_x0_5", "MobileNetV3_large_x0_75",
+    "MobileNetV3_large_x1_0", "MobileNetV3_large_x1_25"
+]
+
+
+def make_divisible(v, divisor=8, min_value=None):
+    """Round ``v`` to a multiple of ``divisor``.
+
+    The result is at least ``min_value`` (``divisor`` when omitted) and is
+    raised by one ``divisor`` when rounding would drop below 90% of ``v``.
+    """
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+def get_padding_same(kernel_size, dilation_rate):
+    """Return the SAME padding for a kernel size and dilation rate.
+
+    The output size of a convolution follows
+        (F - (k + (k - 1) * (r - 1)) + 2 * p) / s + 1 = F_new
+    where F is the input feature size, k the kernel size, r the dilation
+    rate, p the padding value, s the stride and F_new the output size.
+
+    Args:
+        kernel_size (int): kernel size k.
+        dilation_rate (int): dilation rate r.
+
+    Returns:
+        int: padding value p.
+    """
+    k = kernel_size
+    r = dilation_rate
+    return (k + (k - 1) * (r - 1) - 1) // 2
+
+
+class MobileNetV3(fluid.dygraph.Layer):
+    """MobileNetV3 network that also returns intermediate feature maps.
+
+    Args:
+        scale (float): channel width multiplier. Default 1.0.
+        model_name (str): "small" or "large". Default "small".
+        class_dim (int): output size of the final Linear layer. Default 1000.
+        output_stride (int, optional): when given, strides that would take
+            the accumulated stride above this value are set to 1 and
+            replaced by dilation. Default None (strides left unchanged).
+
+    Raises:
+        NotImplementedError: if ``model_name`` is not "small" or "large".
+        ValueError: if ``output_stride`` is an odd number.
+    """
+
+    def __init__(self,
+                 scale=1.0,
+                 model_name="small",
+                 class_dim=1000,
+                 output_stride=None,
+                 **kwargs):
+        super(MobileNetV3, self).__init__()
+
+        inplanes = 16
+        if model_name == "large":
+            self.cfg = [
+                # k, exp, c, se, nl, s,
+                [3, 16, 16, False, "relu", 1],
+                [3, 64, 24, False, "relu", 2],
+                [3, 72, 24, False, "relu", 1],  # output 1 -> out_index=2
+                [5, 72, 40, True, "relu", 2],
+                [5, 120, 40, True, "relu", 1],
+                [5, 120, 40, True, "relu", 1],  # output 2 -> out_index=5
+                [3, 240, 80, False, "hard_swish", 2],
+                [3, 200, 80, False, "hard_swish", 1],
+                [3, 184, 80, False, "hard_swish", 1],
+                [3, 184, 80, False, "hard_swish", 1],
+                [3, 480, 112, True, "hard_swish", 1],
+                [3, 672, 112, True, "hard_swish", 1],  # output 3 -> out_index=11
+                [5, 672, 160, True, "hard_swish", 2],
+                [5, 960, 160, True, "hard_swish", 1],
+                [5, 960, 160, True, "hard_swish", 1],  # output 4 -> out_index=14
+            ]
+            self.out_indices = [2, 5, 11, 14]
+
+            self.cls_ch_squeeze = 960
+            self.cls_ch_expand = 1280
+        elif model_name == "small":
+            self.cfg = [
+                # k, exp, c, se, nl, s,
+                [3, 16, 16, True, "relu", 2],  # output 1 -> out_index=0
+                [3, 72, 24, False, "relu", 2],
+                [3, 88, 24, False, "relu", 1],
+                [5, 96, 40, True, "hard_swish", 2],  # output 2 -> out_index=3
+                [5, 240, 40, True, "hard_swish", 1],
+                [5, 240, 40, True, "hard_swish", 1],
+                [5, 120, 48, True, "hard_swish", 1],
+                [5, 144, 48, True, "hard_swish", 1],  # output 3 -> out_index=7
+                [5, 288, 96, True, "hard_swish", 2],
+                [5, 576, 96, True, "hard_swish", 1],
+                [5, 576, 96, True, "hard_swish", 1],  # output 4 -> out_index=10
+            ]
+            # NOTE(review): index 3 is itself a stride-2 block, whereas every
+            # other index is the last block before the next stride-2 block
+            # (that would be index 2 here) -- confirm 3 is intended.
+            self.out_indices = [0, 3, 7, 10]
+
+            self.cls_ch_squeeze = 576
+            self.cls_ch_expand = 1280
+        else:
+            raise NotImplementedError(
+                "mode[{}_model] is not implemented!".format(model_name))
+
+        # Replace strides by dilation where output_stride requires it.
+        self.dilation_cfg = [1] * len(self.cfg)
+        self.modify_bottle_params(output_stride=output_stride)
+
+        self.conv1 = ConvBNLayer(
+            in_c=3,
+            out_c=make_divisible(inplanes * scale),
+            filter_size=3,
+            stride=2,
+            padding=1,
+            num_groups=1,
+            if_act=True,
+            act="hard_swish",
+            name="conv1")
+
+        self.block_list = []
+
+        inplanes = make_divisible(inplanes * scale)
+        for i, (k, exp, c, se, nl, s) in enumerate(self.cfg):
+            block_name = "conv" + str(i + 2)
+            block = ResidualUnit(
+                in_c=inplanes,
+                mid_c=make_divisible(scale * exp),
+                out_c=make_divisible(scale * c),
+                filter_size=k,
+                stride=s,
+                dilation=self.dilation_cfg[i],
+                use_se=se,
+                act=nl,
+                name=block_name)
+            self.block_list.append(block)
+            self.add_sublayer(sublayer=block, name=block_name)
+            inplanes = make_divisible(scale * c)
+
+        self.last_second_conv = ConvBNLayer(
+            in_c=inplanes,
+            out_c=make_divisible(scale * self.cls_ch_squeeze),
+            filter_size=1,
+            stride=1,
+            padding=0,
+            num_groups=1,
+            if_act=True,
+            act="hard_swish",
+            name="conv_last")
+
+        self.pool = Pool2D(
+            pool_type="avg", global_pooling=True, use_cudnn=False)
+
+        self.last_conv = Conv2D(
+            num_channels=make_divisible(scale * self.cls_ch_squeeze),
+            num_filters=self.cls_ch_expand,
+            filter_size=1,
+            stride=1,
+            padding=0,
+            act=None,
+            param_attr=ParamAttr(name="last_1x1_conv_weights"),
+            bias_attr=False)
+
+        self.out = Linear(
+            input_dim=self.cls_ch_expand,
+            output_dim=class_dim,
+            param_attr=ParamAttr("fc_weights"),
+            bias_attr=ParamAttr(name="fc_offset"))
+
+    def modify_bottle_params(self, output_stride=None):
+        """Trade strides for dilation once ``output_stride`` is reached.
+
+        Walks ``self.cfg`` in order while accumulating the stride (starting
+        at 2 for ``conv1``). A block that would push the accumulated stride
+        above ``output_stride`` has its stride set to 1 in ``self.cfg``;
+        the accumulated rate is stored per block in ``self.dilation_cfg``.
+        Does nothing when ``output_stride`` is None.
+
+        Raises:
+            ValueError: if ``output_stride`` is an odd number.
+        """
+        if output_stride is None:
+            return
+        if output_stride % 2 != 0:
+            raise ValueError("output stride must be an even number")
+
+        stride = 2
+        rate = 1
+        for i, _cfg in enumerate(self.cfg):
+            stride = stride * _cfg[-1]
+            if stride > output_stride:
+                rate = rate * _cfg[-1]
+                self.cfg[i][-1] = 1
+            self.dilation_cfg[i] = rate
+
+    def forward(self, inputs, label=None, dropout_prob=0.2):
+        """Run the network on ``inputs``.
+
+        Args:
+            inputs: input passed to ``conv1``.
+            label: unused.
+            dropout_prob (float): dropout probability applied before the
+                final Linear layer. Default 0.2.
+
+        Returns:
+            tuple: ``(x, feat_list)``, the output of the final Linear layer
+            and the outputs of the blocks listed in ``self.out_indices``.
+        """
+        x = self.conv1(inputs)
+        # Collect the outputs of the blocks listed in self.out_indices.
+        feat_list = []
+        for i, block in enumerate(self.block_list):
+            x = block(x)
+            if i in self.out_indices:
+                feat_list.append(x)
+        x = self.last_second_conv(x)
+        x = self.pool(x)
+        x = self.last_conv(x)
+        x = fluid.layers.hard_swish(x)
+        x = fluid.layers.dropout(x=x, dropout_prob=dropout_prob)
+        x = fluid.layers.reshape(x, shape=[x.shape[0], x.shape[1]])
+        x = self.out(x)
+
+        return x, feat_list
+
+
+class ConvBNLayer(fluid.dygraph.Layer):
+    """Conv2D followed by BatchNorm and an optional activation.
+
+    When ``if_act`` is True, ``act`` must be "relu" or "hard_swish".
+    """
+
+    def __init__(self,
+                 in_c,
+                 out_c,
+                 filter_size,
+                 stride,
+                 padding,
+                 dilation=1,
+                 num_groups=1,
+                 if_act=True,
+                 act=None,
+                 use_cudnn=True,
+                 name=""):
+        super(ConvBNLayer, self).__init__()
+        self.if_act = if_act
+        self.act = act
+
+        self.conv = fluid.dygraph.Conv2D(
+            num_channels=in_c,
+            num_filters=out_c,
+            filter_size=filter_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=num_groups,
+            param_attr=ParamAttr(name=name + "_weights"),
+            bias_attr=False,
+            use_cudnn=use_cudnn,
+            act=None)
+        self.bn = fluid.dygraph.BatchNorm(
+            num_channels=out_c,
+            act=None,
+            param_attr=ParamAttr(
+                name=name + "_bn_scale",
+                regularizer=fluid.regularizer.L2DecayRegularizer(
+                    regularization_coeff=0.0)),
+            bias_attr=ParamAttr(
+                name=name + "_bn_offset",
+                regularizer=fluid.regularizer.L2DecayRegularizer(
+                    regularization_coeff=0.0)),
+            moving_mean_name=name + "_bn_mean",
+            moving_variance_name=name + "_bn_variance")
+
+    def forward(self, x):
+        """Apply conv, batch norm and, if enabled, the activation.
+
+        Raises:
+            ValueError: if ``if_act`` is True and ``act`` is unsupported.
+        """
+        x = self.conv(x)
+        x = self.bn(x)
+        if self.if_act:
+            if self.act == "relu":
+                x = fluid.layers.relu(x)
+            elif self.act == "hard_swish":
+                x = fluid.layers.hard_swish(x)
+            else:
+                # Raise instead of print() + exit(): killing the whole
+                # process from inside a layer hides the failing call site.
+                raise ValueError(
+                    "The activation function is selected incorrectly: "
+                    "{!r}".format(self.act))
+        return x
+
+
+class ResidualUnit(fluid.dygraph.Layer):
+    """Expand (1x1) -> depthwise -> optional SE -> linear (1x1) block.
+
+    The input is added to the output when ``stride == 1`` and
+    ``in_c == out_c``.
+    """
+
+    def __init__(self,
+                 in_c,
+                 mid_c,
+                 out_c,
+                 filter_size,
+                 stride,
+                 use_se,
+                 dilation=1,
+                 act=None,
+                 name=''):
+        super(ResidualUnit, self).__init__()
+        self.if_shortcut = stride == 1 and in_c == out_c
+        self.if_se = use_se
+
+        self.expand_conv = ConvBNLayer(
+            in_c=in_c,
+            out_c=mid_c,
+            filter_size=1,
+            stride=1,
+            padding=0,
+            if_act=True,
+            act=act,
+            name=name + "_expand")
+
+        self.bottleneck_conv = ConvBNLayer(
+            in_c=mid_c,
+            out_c=mid_c,
+            filter_size=filter_size,
+            stride=stride,
+            padding=get_padding_same(filter_size, dilation),
+            dilation=dilation,
+            num_groups=mid_c,
+            if_act=True,
+            act=act,
+            name=name + "_depthwise")
+        if self.if_se:
+            self.mid_se = SEModule(mid_c, name=name + "_se")
+        self.linear_conv = ConvBNLayer(
+            in_c=mid_c,
+            out_c=out_c,
+            filter_size=1,
+            stride=1,
+            padding=0,
+            if_act=False,
+            act=None,
+            name=name + "_linear")
+        self.dilation = dilation
+
+    def forward(self, inputs):
+        x = self.expand_conv(inputs)
+        x = self.bottleneck_conv(x)
+        if self.if_se:
+            x = self.mid_se(x)
+        x = self.linear_conv(x)
+        if self.if_shortcut:
+            x = fluid.layers.elementwise_add(inputs, x)
+        return x
+
+
+class SEModule(fluid.dygraph.Layer):
+    """Squeeze-and-excitation: rescales ``inputs`` by a per-channel gate.
+
+    The gate is global average pooling, a 1x1 conv with relu down to
+    ``channel // reduction`` channels, a 1x1 conv back to ``channel``
+    channels and a hard sigmoid.
+    """
+
+    def __init__(self, channel, reduction=4, name=""):
+        super(SEModule, self).__init__()
+        self.avg_pool = fluid.dygraph.Pool2D(
+            pool_type="avg", global_pooling=True, use_cudnn=False)
+        self.conv1 = fluid.dygraph.Conv2D(
+            num_channels=channel,
+            num_filters=channel // reduction,
+            filter_size=1,
+            stride=1,
+            padding=0,
+            act="relu",
+            param_attr=ParamAttr(name=name + "_1_weights"),
+            bias_attr=ParamAttr(name=name + "_1_offset"))
+        self.conv2 = fluid.dygraph.Conv2D(
+            num_channels=channel // reduction,
+            num_filters=channel,
+            filter_size=1,
+            stride=1,
+            padding=0,
+            act=None,
+            param_attr=ParamAttr(name + "_2_weights"),
+            bias_attr=ParamAttr(name=name + "_2_offset"))
+
+    def forward(self, inputs):
+        outputs = self.avg_pool(inputs)
+        outputs = self.conv1(outputs)
+        outputs = self.conv2(outputs)
+        outputs = fluid.layers.hard_sigmoid(outputs)
+        return fluid.layers.elementwise_mul(x=inputs, y=outputs, axis=0)
+
+
+def MobileNetV3_small_x0_35(**kwargs):
+    model = MobileNetV3(model_name="small", scale=0.35, **kwargs)
+    return model
+
+
+def MobileNetV3_small_x0_5(**kwargs):
+    model = MobileNetV3(model_name="small", scale=0.5, **kwargs)
+    return model
+
+
+def MobileNetV3_small_x0_75(**kwargs):
+    model = MobileNetV3(model_name="small", scale=0.75, **kwargs)
+    return model
+
+
+@manager.BACKBONES.add_component
+def MobileNetV3_small_x1_0(**kwargs):
+    model = MobileNetV3(model_name="small", scale=1.0, **kwargs)
+    return model
+
+
+def MobileNetV3_small_x1_25(**kwargs):
+    model = MobileNetV3(model_name="small", scale=1.25, **kwargs)
+    return model
+
+
+def MobileNetV3_large_x0_35(**kwargs):
+    model = MobileNetV3(model_name="large", scale=0.35, **kwargs)
+    return model
+
+
+def MobileNetV3_large_x0_5(**kwargs):
+    model = MobileNetV3(model_name="large", scale=0.5, **kwargs)
+    return model
+
+
+def MobileNetV3_large_x0_75(**kwargs):
+    model = MobileNetV3(model_name="large", scale=0.75, **kwargs)
+    return model
+
+
+@manager.BACKBONES.add_component
+def MobileNetV3_large_x1_0(**kwargs):
+    model = MobileNetV3(model_name="large", scale=1.0, **kwargs)
+    return model
+
+
+def MobileNetV3_large_x1_25(**kwargs):
+    model = MobileNetV3(model_name="large", scale=1.25, **kwargs)
+    return model
diff --git a/dygraph/models/architectures/resnet_vd.py b/dygraph/models/architectures/resnet_vd.py
index
6fb1371356c64624d9eb72c40f1fdde0457a0804..b08dcd90c97605ac22a376342d801c8ddc4f378f 100644 --- a/dygraph/models/architectures/resnet_vd.py +++ b/dygraph/models/architectures/resnet_vd.py @@ -28,6 +28,8 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear, Dropout from dygraph.utils import utils +from dygraph.cvlibs import manager + __all__ = [ "ResNet18_vd", "ResNet34_vd", "ResNet50_vd", "ResNet101_vd", "ResNet152_vd" ] @@ -199,9 +201,9 @@ class BasicBlock(fluid.dygraph.Layer): class ResNet_vd(fluid.dygraph.Layer): - def __init__(self, layers=50, class_dim=1000, dilation_dict=None, multi_grid=(1, 2, 4), **kwargs): + def __init__(self, layers=50, class_dim=1000, output_stride=None, multi_grid=(1, 2, 4), **kwargs): super(ResNet_vd, self).__init__() - + self.layers = layers supported_layers = [18, 34, 50, 101, 152, 200] assert layers in supported_layers, \ @@ -222,6 +224,12 @@ class ResNet_vd(fluid.dygraph.Layer): 1024] if layers >= 50 else [64, 64, 128, 256] num_filters = [64, 128, 256, 512] + dilation_dict=None + if output_stride == 8: + dilation_dict = {2: 2, 3: 4} + elif output_stride == 16: + dilation_dict = {3: 2} + self.conv1_1 = ConvBNLayer( num_channels=3, num_filters=32, @@ -359,12 +367,12 @@ def ResNet34_vd(**args): model = ResNet_vd(layers=34, **args) return model - +@manager.BACKBONES.add_component def ResNet50_vd(**args): model = ResNet_vd(layers=50, **args) return model - +@manager.BACKBONES.add_component def ResNet101_vd(**args): model = ResNet_vd(layers=101, **args) return model diff --git a/dygraph/models/architectures/xception_deeplab.py b/dygraph/models/architectures/xception_deeplab.py index 57285a4b1f346aa24cc0c2e3f80f1f1a1e049f2f..1cb0f2a9e9cdb2f2a2406bc36dec8a0ee06ed395 100644 --- a/dygraph/models/architectures/xception_deeplab.py +++ b/dygraph/models/architectures/xception_deeplab.py @@ -4,6 +4,8 @@ from paddle.fluid.param_attr import ParamAttr from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.dygraph.nn 
import Conv2D, Pool2D, BatchNorm, Linear, Dropout +from dygraph.cvlibs import manager + __all__ = ["Xception41_deeplab", "Xception65_deeplab", "Xception71_deeplab"] @@ -400,7 +402,7 @@ def Xception41_deeplab(**args): model = XceptionDeeplab('xception_41', **args) return model - +@manager.BACKBONES.add_component def Xception65_deeplab(**args): model = XceptionDeeplab("xception_65", **args) return model diff --git a/dygraph/models/deeplab.py b/dygraph/models/deeplab.py index 67e35915945393a7a508bfb4bdf4f480e86af9f4..e9a0167f044e658a3c494d2c465f3ed729ef9367 100644 --- a/dygraph/models/deeplab.py +++ b/dygraph/models/deeplab.py @@ -13,22 +13,21 @@ # limitations under the License. - import os -import numpy as np - -import paddle +from dygraph.cvlibs import manager +from dygraph.models.architectures import layer_utils from paddle import fluid from paddle.fluid import dygraph from paddle.fluid.dygraph import Conv2D -from .architectures import layer_utils, xception_deeplab, resnet_vd from dygraph.utils import utils -__all__ = ['DeepLabV3P', "deeplabv3p_resnet101_vd", "deeplabv3p_resnet101_vd_os8", - "deeplabv3p_resnet50_vd", "deeplabv3p_resnet50_vd_os8", - "deeplabv3p_xception65_deeplab"] +__all__ = ['DeepLabV3P', "deeplabv3p_resnet101_vd", "deeplabv3p_resnet101_vd_os8", + "deeplabv3p_resnet50_vd", "deeplabv3p_resnet50_vd_os8", + "deeplabv3p_xception65_deeplab", + "deeplabv3p_mobilenetv3_large", "deeplabv3p_mobilenetv3_small"] + class ImageAverage(dygraph.Layer): """ @@ -42,8 +41,8 @@ class ImageAverage(dygraph.Layer): def __init__(self, num_channels): super(ImageAverage, self).__init__() self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels, - num_filters=256, - filter_size=1) + num_filters=256, + filter_size=1) def forward(self, input): x = fluid.layers.reduce_mean(input, dim=[2, 3], keep_dim=True) @@ -78,8 +77,8 @@ class ASPP(dygraph.Layer): self.aspp1 = layer_utils.ConvBnRelu(num_channels=in_channels, num_filters=256, filter_size=1, - using_sep_conv=False) - + 
using_sep_conv=False) + # The second aspp using 3*3 (separable) conv at dilated rate aspp_ratios[0] self.aspp2 = layer_utils.ConvBnRelu(num_channels=in_channels, num_filters=256, @@ -87,7 +86,7 @@ class ASPP(dygraph.Layer): using_sep_conv=using_sep_conv, dilation=aspp_ratios[0], padding=aspp_ratios[0]) - + # The Third aspp using 3*3 (separable) conv at dilated rate aspp_ratios[1] self.aspp3 = layer_utils.ConvBnRelu(num_channels=in_channels, num_filters=256, @@ -103,22 +102,21 @@ class ASPP(dygraph.Layer): using_sep_conv=using_sep_conv, dilation=aspp_ratios[2], padding=aspp_ratios[2]) - - + # After concat op, using 1*1 conv self.conv_bn_relu = layer_utils.ConvBnRelu(num_channels=1280, - num_filters=256, - filter_size=1) + num_filters=256, + filter_size=1) def forward(self, x): - + x1 = self.image_average(x) x2 = self.aspp1(x) x3 = self.aspp2(x) x4 = self.aspp3(x) x5 = self.aspp4(x) x = fluid.layers.concat([x1, x2, x3, x4, x5], axis=1) - + x = self.conv_bn_relu(x) x = fluid.layers.dropout(x, dropout_prob=0.1) return x @@ -137,11 +135,11 @@ class Decoder(dygraph.Layer): def __init__(self, num_classes, in_channels, using_sep_conv=True): super(Decoder, self).__init__() - + self.conv_bn_relu1 = layer_utils.ConvBnRelu(num_channels=in_channels, num_filters=48, filter_size=1) - + self.conv_bn_relu2 = layer_utils.ConvBnRelu(num_channels=304, num_filters=256, filter_size=3, @@ -152,8 +150,8 @@ class Decoder(dygraph.Layer): filter_size=3, using_sep_conv=using_sep_conv, padding=1) - self.conv = Conv2D(num_channels=256, - num_filters=num_classes, + self.conv = Conv2D(num_channels=256, + num_filters=num_classes, filter_size=1) def forward(self, x, low_level_feat): @@ -169,7 +167,7 @@ class Decoder(dygraph.Layer): class DeepLabV3P(dygraph.Layer): """ The DeepLabV3P consists of three main components, Backbone, ASPP and Decoder - The orginal artile refers to + The orginal artile refers to "Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation" Liang-Chieh 
Chen, Yukun Zhu, George Papandreou, Florian Schroff, Hartwig Adam. (https://arxiv.org/abs/1802.02611) @@ -183,7 +181,7 @@ class DeepLabV3P(dygraph.Layer): backbone_indices (tuple): two values in the tuple indicte the indices of output of backbone. the first index will be taken as a low-level feature in Deconder component; - the second one will be taken as input of ASPP component. + the second one will be taken as input of ASPP component. Usually backbone consists of four downsampling stage, and return an output of each stage, so we set default (0, 3), which means taking feature map of the first stage in backbone as low-level feature used in Decoder, and feature map of the fourth @@ -193,15 +191,16 @@ class DeepLabV3P(dygraph.Layer): ignore_index (int): the value of ground-truth mask would be ignored while doing evaluation. Default 255. - using_sep_conv (bool): a bool value indicates whether using separable convolutions + using_sep_conv (bool): a bool value indicates whether using separable convolutions in ASPP and Decoder components. Default True. pretrained_model (str): the pretrained_model path of backbone. 
""" - def __init__(self, - backbone, - num_classes=2, + + def __init__(self, + backbone, + num_classes=2, output_stride=16, - backbone_indices=(0,3), + backbone_indices=(0, 3), backbone_channels=(256, 2048), ignore_index=255, using_sep_conv=True, @@ -209,7 +208,7 @@ class DeepLabV3P(dygraph.Layer): super(DeepLabV3P, self).__init__() - self.backbone = build_backbone(backbone, output_stride) + self.backbone = manager.BACKBONES[backbone](output_stride=output_stride) self.aspp = ASPP(output_stride, backbone_channels[1], using_sep_conv) self.decoder = Decoder(num_classes, backbone_channels[0], using_sep_conv) self.ignore_index = ignore_index @@ -217,14 +216,15 @@ class DeepLabV3P(dygraph.Layer): self.backbone_indices = backbone_indices self.init_weight(pretrained_model) - def forward(self, input, label=None, mode='train'): + def forward(self, input, label=None): + _, feat_list = self.backbone(input) low_level_feat = feat_list[self.backbone_indices[0]] x = feat_list[self.backbone_indices[1]] x = self.aspp(x) logit = self.decoder(x, low_level_feat) logit = fluid.layers.resize_bilinear(logit, input.shape[2:]) - + if self.training: return self._get_loss(logit, label) else: @@ -233,7 +233,7 @@ class DeepLabV3P(dygraph.Layer): pred = fluid.layers.argmax(score_map, axis=3) pred = fluid.layers.unsqueeze(pred, axes=[3]) return pred, score_map - + def init_weight(self, pretrained_model=None): """ Initialize the parameters of model parts. 
@@ -271,58 +271,71 @@ class DeepLabV3P(dygraph.Layer): loss = loss * mask avg_loss = fluid.layers.mean(loss) / ( - fluid.layers.mean(mask) + self.EPS) + fluid.layers.mean(mask) + self.EPS) label.stop_gradient = True mask.stop_gradient = True - - return avg_loss - - -def build_backbone(backbone, output_stride): - - if output_stride == 8: - dilation_dict = {2: 2, 3: 4} - elif output_stride == 16: - dilation_dict = {3: 2} - else: - raise Exception("deeplab only support stride 8 or 16") - - model_dict = {"ResNet50_vd":resnet_vd.ResNet50_vd, - "ResNet101_vd":resnet_vd.ResNet101_vd, - "Xception65_deeplab": xception_deeplab.Xception65_deeplab} - model = model_dict[backbone] + return avg_loss - return model(dilation_dict=dilation_dict) - def build_aspp(output_stride, using_sep_conv): return ASPP(output_stride=output_stride, using_sep_conv=using_sep_conv) + def build_decoder(num_classes, using_sep_conv): return Decoder(num_classes, using_sep_conv=using_sep_conv) + +@manager.MODELS.add_component def deeplabv3p_resnet101_vd(*args, **kwargs): pretrained_model = None return DeepLabV3P(backbone='ResNet101_vd', pretrained_model=pretrained_model, **kwargs) + +@manager.MODELS.add_component def deeplabv3p_resnet101_vd_os8(*args, **kwargs): pretrained_model = None return DeepLabV3P(backbone='ResNet101_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs) + +@manager.MODELS.add_component def deeplabv3p_resnet50_vd(*args, **kwargs): pretrained_model = None return DeepLabV3P(backbone='ResNet50_vd', pretrained_model=pretrained_model, **kwargs) + +@manager.MODELS.add_component def deeplabv3p_resnet50_vd_os8(*args, **kwargs): pretrained_model = None return DeepLabV3P(backbone='ResNet50_vd', output_stride=8, pretrained_model=pretrained_model, **kwargs) + +@manager.MODELS.add_component def deeplabv3p_xception65_deeplab(*args, **kwargs): pretrained_model = None - return DeepLabV3P(backbone='Xception65_deeplab', + return DeepLabV3P(backbone='Xception65_deeplab', 
pretrained_model=pretrained_model, - backbone_indices=(0,1), + backbone_indices=(0, 1), backbone_channels=(128, 2048), - **kwargs) \ No newline at end of file + **kwargs) + + +@manager.MODELS.add_component +def deeplabv3p_mobilenetv3_large(*args, **kwargs): + pretrained_model = None + return DeepLabV3P(backbone='MobileNetV3_large_x1_0', + pretrained_model=pretrained_model, + backbone_indices=(0, 3), + backbone_channels=(24, 160), + **kwargs) + + +@manager.MODELS.add_component +def deeplabv3p_mobilenetv3_small(*args, **kwargs): + pretrained_model = None + return DeepLabV3P(backbone='MobileNetV3_small_x1_0', + pretrained_model=pretrained_model, + backbone_indices=(0, 3), + backbone_channels=(16, 96), + **kwargs) diff --git a/dygraph/train.py b/dygraph/train.py index 5a66f6073ce842944d777054e45435ca058c135a..073e90a3baeb4e61262ea2be767fd173a49d4872 100644 --- a/dygraph/train.py +++ b/dygraph/train.py @@ -19,7 +19,8 @@ from paddle.fluid.dygraph.parallel import ParallelEnv from dygraph.datasets import DATASETS import dygraph.transforms as T -from dygraph.models import MODELS +#from dygraph.models import MODELS +from dygraph.cvlibs import manager from dygraph.utils import get_environ_info from dygraph.core import train @@ -32,7 +33,7 @@ def parse_args(): '--model_name', dest='model_name', help='Model type for training, which is one of {}'.format( - str(list(MODELS.keys()))), + str(list(manager.MODELS.components_dict.keys()))), type=str, default='UNet') @@ -160,11 +161,8 @@ def main(args): transforms=eval_transforms, mode='val') - if args.model_name not in MODELS: - raise Exception( - '`--model_name` is invalid. it should be one of {}'.format( - str(list(MODELS.keys())))) - model = MODELS[args.model_name](num_classes=train_dataset.num_classes) + + model = manager.MODELS[args.model_name](num_classes=train_dataset.num_classes) # Creat optimizer # todo, may less one than len(loader)