diff --git a/new_tutorials/train/segmentation/fast_scnn.py b/new_tutorials/train/segmentation/fast_scnn.py new file mode 100644 index 0000000000000000000000000000000000000000..53f1a528a090d6d4f278e47b54b2660dccde2e0d --- /dev/null +++ b/new_tutorials/train/segmentation/fast_scnn.py @@ -0,0 +1,48 @@ +import os +# Use GPU card 0 (set before paddlex is imported) +os.environ['CUDA_VISIBLE_DEVICES'] = '0' + +import paddlex as pdx +from paddlex.seg import transforms + +# Download and decompress the optic disc segmentation dataset +optic_dataset = 'https://bj.bcebos.com/paddlex/datasets/optic_disc_seg.tar.gz' +pdx.utils.download_and_decompress(optic_dataset, path='./') + +# Define the transforms used for training and evaluation +# API reference: https://paddlex.readthedocs.io/zh_CN/latest/apis/transforms/seg_transforms.html#composedsegtransforms +train_transforms = transforms.ComposedSegTransforms( + mode='train', train_crop_size=[769, 769]) +eval_transforms = transforms.ComposedSegTransforms(mode='eval') + +# Define the datasets used for training and evaluation +# API reference: https://paddlex.readthedocs.io/zh_CN/latest/apis/datasets/semantic_segmentation.html#segdataset +train_dataset = pdx.datasets.SegDataset( + data_dir='optic_disc_seg', + file_list='optic_disc_seg/train_list.txt', + label_list='optic_disc_seg/labels.txt', + transforms=train_transforms, + shuffle=True) +eval_dataset = pdx.datasets.SegDataset( + data_dir='optic_disc_seg', + file_list='optic_disc_seg/val_list.txt', + label_list='optic_disc_seg/labels.txt', + transforms=eval_transforms) + +# Initialize the model and train it +# Training metrics can be viewed with VisualDL +# Start VisualDL with: visualdl --logdir output/fastscnn/vdl_log --port 8001 +# then open https://0.0.0.0:8001 in a browser +# 0.0.0.0 is for local access; for a remote server, replace it with that machine's IP + +# API reference (NOTE(review): anchor below is #hrnet, not FastSCNN - confirm): https://paddlex.readthedocs.io/zh_CN/latest/apis/models/semantic_segmentation.html#hrnet +num_classes = len(train_dataset.labels) +model = pdx.seg.FastSCNN(num_classes=num_classes) +model.train( + num_epochs=20, + train_dataset=train_dataset, + train_batch_size=4, + eval_dataset=eval_dataset, + learning_rate=0.01, + save_dir='output/fastscnn', + use_vdl=True) diff --git a/paddlex/cv/models/base.py 
b/paddlex/cv/models/base.py index d15459c0bc318207b5bcf9593dfaaf676437fe27..e30a2529c5a7ff9cbcafb4a05d58f53ea5476e7e 100644 --- a/paddlex/cv/models/base.py +++ b/paddlex/cv/models/base.py @@ -1,11 +1,11 @@ # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -194,9 +194,8 @@ class BaseAPI: if os.path.exists(pretrain_dir): os.remove(pretrain_dir) os.makedirs(pretrain_dir) - if pretrain_weights is not None and \ - not os.path.isdir(pretrain_weights) \ - and not os.path.isfile(pretrain_weights): + if pretrain_weights is not None and not os.path.exists( + pretrain_weights): if self.model_type == 'classifier': if pretrain_weights not in ['IMAGENET']: logging.warning( @@ -245,8 +244,8 @@ class BaseAPI: logging.info( "Load pretrain weights from {}.".format(pretrain_weights), use_color=True) - paddlex.utils.utils.load_pretrain_weights(self.exe, self.train_prog, - pretrain_weights, fuse_bn) + paddlex.utils.utils.load_pretrain_weights( + self.exe, self.train_prog, pretrain_weights, fuse_bn) # 进行裁剪 if sensitivities_file is not None: import paddleslim @@ -350,7 +349,9 @@ class BaseAPI: logging.info("Model saved in {}.".format(save_dir)) def export_inference_model(self, save_dir): - test_input_names = [var.name for var in list(self.test_inputs.values())] + test_input_names = [ + var.name for var in list(self.test_inputs.values()) + ] test_outputs = list(self.test_outputs.values()) if self.__class__.__name__ == 'MaskRCNN': from paddlex.utils.save import save_mask_inference_model @@ -387,7 +388,8 @@ class BaseAPI: # 
模型保存成功的标志 open(osp.join(save_dir, '.success'), 'w').close() - logging.info("Model for inference deploy saved in {}.".format(save_dir)) + logging.info("Model for inference deploy saved in {}.".format( + save_dir)) def train_loop(self, num_epochs, @@ -511,11 +513,13 @@ class BaseAPI: eta = ((num_epochs - i) * total_num_steps - step - 1 ) * avg_step_time if time_eval_one_epoch is not None: - eval_eta = (total_eval_times - i // save_interval_epochs - ) * time_eval_one_epoch + eval_eta = ( + total_eval_times - i // save_interval_epochs + ) * time_eval_one_epoch else: - eval_eta = (total_eval_times - i // save_interval_epochs - ) * total_num_steps_eval * avg_step_time + eval_eta = ( + total_eval_times - i // save_interval_epochs + ) * total_num_steps_eval * avg_step_time eta_str = seconds_to_hms(eta + eval_eta) logging.info( diff --git a/paddlex/cv/models/fast_scnn.py b/paddlex/cv/models/fast_scnn.py new file mode 100644 index 0000000000000000000000000000000000000000..5f66e4df6ede1b48c0363b5b8a496b23021454ef --- /dev/null +++ b/paddlex/cv/models/fast_scnn.py @@ -0,0 +1,169 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class FastSCNN(DeepLabv3p):
    """Build a Fast-SCNN network and train, evaluate, predict with and export it.

    Args:
        num_classes (int): Number of classes. Defaults to 2.
        use_bce_loss (bool): Whether to use BCE loss as the network loss.
            Only valid for two-class segmentation; may be combined with
            dice loss. Defaults to False.
        use_dice_loss (bool): Whether to use dice loss as the network loss.
            Only valid for two-class segmentation; may be combined with
            BCE loss. When both use_bce_loss and use_dice_loss are False,
            the cross-entropy loss is used. Defaults to False.
        class_weight (list/str): Per-class weights of the cross-entropy
            loss. When it is a list, its length must equal num_classes.
            When it is a str, class_weight.lower() must be 'dynamic'; the
            weights are then computed every round from the share of pixels
            of each class (per-class ratio * num_classes). With the
            default None every class has weight 1, i.e. the ordinary
            cross-entropy loss.
        ignore_index (int): Label value to ignore; pixels whose label
            equals ignore_index take no part in the loss. Defaults to 255.
        multi_loss_weight (list): Loss weights of the branches. By default
            the loss is computed on one branch only ([1.0]). Two or three
            branches are also supported, ordered as
            [fusion_branch_weight, higher_branch_weight,
            lower_branch_weight]: the weight of the branch that fuses the
            spatial-detail and global-context branches, the weight of the
            spatial-detail branch, and the weight of the global-context
            branch. Branches without a weight contribute no loss.

    Raises:
        ValueError: use_bce_loss or use_dice_loss is true while
            num_classes > 2.
        ValueError: class_weight is a list whose length differs from
            num_classes, or a str whose lower() is not 'dynamic'.
        TypeError: class_weight is neither None, a list nor a str.
        TypeError: multi_loss_weight is not a list.
        ValueError: multi_loss_weight is empty or has more than 3 entries.
    """

    def __init__(self,
                 num_classes=2,
                 use_bce_loss=False,
                 use_dice_loss=False,
                 class_weight=None,
                 ignore_index=255,
                 multi_loss_weight=[1.0]):
        # Snapshot of the constructor arguments; taken first, before any
        # other local name is bound.
        self.init_params = locals()
        # Starts the MRO lookup *after* DeepLabv3p, so DeepLabv3p.__init__
        # itself is bypassed and only its parent initializer runs.
        super(DeepLabv3p, self).__init__('segmenter')
        # dice loss / bce loss are only applicable to two-class segmentation
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss is only applicable to binary classfication"
            )

        if class_weight is not None:
            if isinstance(class_weight, list):
                if len(class_weight) != num_classes:
                    raise ValueError(
                        "Length of class_weight should be equal to number of classes"
                    )
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError(
                        "if class_weight is string, must be dynamic!")
            else:
                raise TypeError(
                    'Expect class_weight is a list or string but receive {}'.
                    format(type(class_weight)))

        if not isinstance(multi_loss_weight, list):
            raise TypeError(
                'Expect multi_loss_weight is a list but receive {}'.format(
                    type(multi_loss_weight)))
        # The previous lower bound was `len(...) < 0`, which can never be
        # true, so an empty list slipped through despite the message below.
        if not 1 <= len(multi_loss_weight) <= 3:
            raise ValueError(
                "Length of multi_loss_weight should be lower than or equal to 3 but greater than 0."
            )

        self.num_classes = num_classes
        self.use_bce_loss = use_bce_loss
        self.use_dice_loss = use_dice_loss
        self.class_weight = class_weight
        # Copy so the instance never aliases the shared default list (or a
        # list the caller may mutate later).
        self.multi_loss_weight = list(multi_loss_weight)
        self.ignore_index = ignore_index
        self.labels = None
        self.fixed_input_shape = None

    def build_net(self, mode='train'):
        """Build the network for `mode` and return its inputs and outputs.

        Args:
            mode (str): Mode forwarded to the underlying network. 'train'
                yields a loss; every other value is treated as inference.

        Returns:
            tuple: (inputs, outputs). `outputs` is an OrderedDict holding
                'loss' in train mode, otherwise 'pred' and 'logit'.
        """
        model = paddlex.cv.nets.segmentation.FastSCNN(
            self.num_classes,
            mode=mode,
            use_bce_loss=self.use_bce_loss,
            use_dice_loss=self.use_dice_loss,
            class_weight=self.class_weight,
            ignore_index=self.ignore_index,
            multi_loss_weight=self.multi_loss_weight,
            fixed_input_shape=self.fixed_input_shape)
        inputs = model.generate_inputs()
        model_out = model.build_net(inputs)
        outputs = OrderedDict()
        if mode == 'train':
            # self.optimizer is not set in this class; presumably it is
            # assigned by the inherited train() before the net is built.
            self.optimizer.minimize(model_out)
            outputs['loss'] = model_out
        else:
            # NOTE(review): the network's 'eval' mode returns
            # (loss, pred, label, mask), so these two keys only line up
            # with its inference-mode (pred, logit) return - confirm that
            # callers never pass mode='eval' here.
            outputs['pred'] = model_out[0]
            outputs['logit'] = model_out[1]
        return inputs, outputs

    def train(self,
              num_epochs,
              train_dataset,
              train_batch_size=2,
              eval_dataset=None,
              save_interval_epochs=1,
              log_interval_steps=2,
              save_dir='output',
              pretrain_weights='CITYSCAPES',
              optimizer=None,
              learning_rate=0.01,
              lr_decay_power=0.9,
              use_vdl=False,
              sensitivities_file=None,
              eval_metric_loss=0.05,
              early_stop=False,
              early_stop_patience=5,
              resume_checkpoint=None):
        """Train the model.

        All arguments are forwarded unchanged, positionally, to the
        parent class's train().

        Args:
            num_epochs (int): Number of training epochs.
            train_dataset (paddlex.datasets): Training data reader.
            train_batch_size (int): Training batch size, also used as the
                evaluation batch size. Defaults to 2.
            eval_dataset (paddlex.datasets): Evaluation data reader.
            save_interval_epochs (int): Interval, in epochs, between model
                saves. Defaults to 1.
            log_interval_steps (int): Interval, in steps, between training
                log outputs. Defaults to 2.
            save_dir (str): Directory the model is saved to. Defaults to
                'output'.
            pretrain_weights (str): If a path, load the pretrained model
                found there; if the string 'CITYSCAPES', automatically
                download weights pretrained on CITYSCAPES images; if None,
                use no pretrained model. Defaults to 'CITYSCAPES'.
            optimizer (paddle.fluid.optimizer): Optimizer. When None, the
                default optimizer is used: fluid.optimizer.Momentum with a
                polynomial learning-rate decay.
            learning_rate (float): Initial learning rate of the default
                optimizer. Defaults to 0.01.
            lr_decay_power (float): Polynomial decay power of the default
                optimizer's learning rate. Defaults to 0.9.
            use_vdl (bool): Whether to visualize with VisualDL. Defaults
                to False.
            sensitivities_file (str): If a path, load the sensitivity
                information found there and prune; if the string
                'DEFAULT', automatically download sensitivity information
                obtained on Cityscapes images and prune; if None, do not
                prune. Defaults to None.
            eval_metric_loss (float): Tolerable accuracy loss. Defaults
                to 0.05.
            early_stop (bool): Whether to use the early-stopping strategy.
                Defaults to False.
            early_stop_patience (int): With early stopping enabled,
                training stops when the validation accuracy keeps
                dropping or stays flat for `early_stop_patience` epochs.
                Defaults to 5.
            resume_checkpoint (str): Path of a previously saved model to
                resume training from. If None, training is not resumed.
                Defaults to None.

        Raises:
            ValueError: The model was loaded from an inference model.
        """
        return super(FastSCNN, self).train(
            num_epochs, train_dataset, train_batch_size, eval_dataset,
            save_interval_epochs, log_interval_steps, save_dir,
            pretrain_weights, optimizer, learning_rate, lr_decay_power,
            use_vdl, sensitivities_file, eval_metric_loss, early_stop,
            early_stop_patience, resume_checkpoint)
class FastSCNN(object):
    """Graph builder for the Fast-SCNN segmentation network.

    The network is assembled from four stages, each under its own
    parameter scope: 'learning_to_downsample', 'global_feature_extractor',
    'feature_fusion' and 'classifier'. Depending on the length of
    `multi_loss_weight`, one or two auxiliary heads are added.

    Args:
        num_classes (int): Number of classes.
        mode (str): 'train', 'eval' or anything else for inference.
        use_bce_loss (bool): Use BCE loss (two-class segmentation only).
        use_dice_loss (bool): Use dice loss (two-class segmentation only).
        class_weight (list/str): Per-class cross-entropy weights, or
            'dynamic' (case-insensitive). None means unweighted.
        multi_loss_weight (list): Loss weights for the fused branch and,
            optionally, the higher- and lower-resolution auxiliary heads.
        ignore_index (int): Label value excluded from the loss.
        fixed_input_shape (list): Optional fixed spatial size of the input.

    Raises:
        ValueError: A binary-only loss is requested with num_classes > 2,
            or class_weight is a list of the wrong length or a string
            other than 'dynamic'.
        TypeError: class_weight is neither None, a list nor a str.
    """

    def __init__(self,
                 num_classes,
                 mode='train',
                 use_bce_loss=False,
                 use_dice_loss=False,
                 class_weight=None,
                 multi_loss_weight=[1.0],
                 ignore_index=255,
                 fixed_input_shape=None):
        # dice loss / bce loss are only applicable to two-class segmentation
        if num_classes > 2 and (use_bce_loss or use_dice_loss):
            raise ValueError(
                "dice loss and bce loss is only applicable to binary classfication"
            )

        if class_weight is not None:
            if isinstance(class_weight, list):
                if len(class_weight) != num_classes:
                    raise ValueError(
                        "Length of class_weight should be equal to number of classes"
                    )
            elif isinstance(class_weight, str):
                if class_weight.lower() != 'dynamic':
                    raise ValueError(
                        "if class_weight is string, must be dynamic!")
            else:
                raise TypeError(
                    'Expect class_weight is a list or string but receive {}'.
                    format(type(class_weight)))

        self.num_classes = num_classes
        self.mode = mode
        self.use_bce_loss = use_bce_loss
        self.use_dice_loss = use_dice_loss
        self.class_weight = class_weight
        self.ignore_index = ignore_index
        # Copy so instances never share the mutable default list (or a
        # list the caller may mutate later).
        self.multi_loss_weight = list(multi_loss_weight)
        self.fixed_input_shape = fixed_input_shape

    def build_net(self, inputs):
        """Assemble the network on top of `inputs`.

        Args:
            inputs (dict): Must contain 'image'; in train and eval mode
                also 'label'.

        Returns:
            In train mode, the loss. In eval mode, the tuple
            (loss, pred, label, mask). Otherwise the tuple (pred, logit),
            where logit has been passed through softmax (or
            sigmoid_to_softmax for the single-channel case).
        """
        # With dice/bce loss the head emits a single channel.
        if self.use_dice_loss or self.use_bce_loss:
            self.num_classes = 1
        image = inputs['image']
        # Spatial size of the input; every logit map is resized back to it.
        size = fluid.layers.shape(image)[2:]
        with scope('learning_to_downsample'):
            higher_res_features = self._learning_to_downsample(image, 32, 48,
                                                               64)
        with scope('global_feature_extractor'):
            lower_res_feature = self._global_feature_extractor(
                higher_res_features, 64, [64, 96, 128], 128, 6, [3, 3, 3])
        with scope('feature_fusion'):
            x = self._feature_fusion(higher_res_features, lower_res_feature,
                                     64, 128, 128)
        with scope('classifier'):
            logit = self._classifier(x, 128)
            logit = fluid.layers.resize_bilinear(logit, size, align_mode=0)

        # Auxiliary heads exist only for exactly 2 or 3 loss weights; any
        # other length falls back to the fused branch alone.
        num_branches = len(self.multi_loss_weight)
        logits = [logit]
        if num_branches in (2, 3):
            with scope('aux_layer_higher'):
                higher_logit = self._aux_layer(higher_res_features,
                                               self.num_classes)
                higher_logit = fluid.layers.resize_bilinear(
                    higher_logit, size, align_mode=0)
            logits.append(higher_logit)
        if num_branches == 3:
            with scope('aux_layer_lower'):
                lower_logit = self._aux_layer(lower_res_feature,
                                              self.num_classes)
                lower_logit = fluid.layers.resize_bilinear(
                    lower_logit, size, align_mode=0)
            logits.append(lower_logit)
        logit = tuple(logits)

        # The prediction always comes from the fused branch (logit[0]).
        if self.num_classes == 1:
            out = sigmoid_to_softmax(logit[0])
            out = fluid.layers.transpose(out, [0, 2, 3, 1])
        else:
            out = fluid.layers.transpose(logit[0], [0, 2, 3, 1])

        pred = fluid.layers.argmax(out, axis=3)
        pred = fluid.layers.unsqueeze(pred, axes=[3])

        if self.mode == 'train':
            label = inputs['label']
            return self._get_loss(logit, label)
        elif self.mode == 'eval':
            label = inputs['label']
            # `mask` used to be returned without ever being defined, which
            # raised NameError in eval mode. It marks the pixels whose
            # label is not ignore_index.
            # NOTE(review): confirm the dtype the consumer of `mask`
            # expects (comparison result vs. int32).
            mask = label != self.ignore_index
            loss = self._get_loss(logit, label)
            return loss, pred, label, mask
        else:
            if self.num_classes == 1:
                logit = sigmoid_to_softmax(logit[0])
            else:
                logit = fluid.layers.softmax(logit[0], axis=1)
            return pred, logit

    def generate_inputs(self):
        """Declare the input variables of the network.

        Returns:
            OrderedDict: 'image' (float32, 3 channels) and, in train and
                eval mode, 'label' (int32, 1 channel).
        """
        inputs = OrderedDict()
        if self.fixed_input_shape is not None:
            # fixed_input_shape[1] fills dim 2 and fixed_input_shape[0]
            # fills dim 3 - presumably [width, height]; TODO confirm.
            input_shape = [
                None, 3, self.fixed_input_shape[1], self.fixed_input_shape[0]
            ]
            inputs['image'] = fluid.data(
                dtype='float32', shape=input_shape, name='image')
        else:
            inputs['image'] = fluid.data(
                dtype='float32', shape=[None, 3, None, None], name='image')
        if self.mode in ('train', 'eval'):
            inputs['label'] = fluid.data(
                dtype='int32', shape=[None, 1, None, None], name='label')
        return inputs

    def _get_loss(self, logits, label):
        """Return the weighted sum of the losses of every branch.

        Args:
            logits (tuple): One logit map per branch; branch i is weighted
                by multi_loss_weight[i].
            label: Ground-truth label variable.

        Returns:
            The accumulated loss. Cross-entropy is used unless dice and/or
            BCE loss is enabled, in which case the enabled ones are summed.
        """
        # 1 where the label takes part in the loss, 0 where it equals
        # ignore_index. It is the same for every branch, so build it once.
        logit_mask = (
            label.astype('int32') != self.ignore_index).astype('int32')
        avg_loss = 0
        if not (self.use_dice_loss or self.use_bce_loss):
            for i, logit in enumerate(logits):
                loss = softmax_with_loss(
                    logit,
                    label,
                    logit_mask,
                    num_classes=self.num_classes,
                    weight=self.class_weight,
                    ignore_index=self.ignore_index)
                avg_loss += self.multi_loss_weight[i] * loss
        else:
            if self.use_dice_loss:
                for i, logit in enumerate(logits):
                    loss = dice_loss(logit, label, logit_mask)
                    avg_loss += self.multi_loss_weight[i] * loss
            if self.use_bce_loss:
                for i, logit in enumerate(logits):
                    loss = bce_loss(
                        logit,
                        label,
                        logit_mask,
                        ignore_index=self.ignore_index)
                    avg_loss += self.multi_loss_weight[i] * loss
        return avg_loss

    def _learning_to_downsample(self,
                                x,
                                dw_channels1=32,
                                dw_channels2=48,
                                out_channels=64):
        """Apply one stride-2 conv and two stride-2 separable convs."""
        x = relu(bn(conv(x, dw_channels1, 3, 2)))
        with scope('dsconv1'):
            x = separate_conv(
                x, dw_channels2, stride=2, filter=3, act=fluid.layers.relu)
        with scope('dsconv2'):
            x = separate_conv(
                x, out_channels, stride=2, filter=3, act=fluid.layers.relu)
        return x

    def _shortcut(self, input, data_residual):
        """Add a residual branch to its input element-wise."""
        return fluid.layers.elementwise_add(input, data_residual)

    def _dropout2d(self, input, prob, is_train=False):
        """Zero whole channels with probability `prob` while training.

        Outside training the input is returned untouched. During training
        the kept channels are rescaled by 1 / (1 - prob).
        """
        if not is_train:
            return input
        keep_prob = 1.0 - prob
        shape = fluid.layers.shape(input)
        channels = shape[1]
        # keep_prob + U[0, 1) floors to 1 with probability keep_prob and to
        # 0 otherwise; one draw per (sample, channel).
        random_tensor = keep_prob + fluid.layers.uniform_random(
            [shape[0], channels, 1, 1], min=0., max=1.)
        binary_tensor = fluid.layers.floor(random_tensor)
        output = input / keep_prob * binary_tensor
        return output

    def _inverted_residual_unit(self,
                                input,
                                num_in_filter,
                                num_filters,
                                ifshortcut,
                                stride,
                                filter_size,
                                padding,
                                expansion_factor,
                                name=None):
        """Build one expand -> depthwise -> linear-projection unit.

        `name` prefixes the layer names and must be a string; the None
        default cannot be concatenated and is never used by callers here.

        Returns:
            tuple: (unit output, depthwise conv output). The unit output
                includes the residual sum when `ifshortcut` is true.
        """
        num_expfilter = int(round(num_in_filter * expansion_factor))

        channel_expand = conv_bn_layer(
            input=input,
            num_filters=num_expfilter,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=True,
            name=name + '_expand')

        # num_groups == num_filters makes this a depthwise convolution.
        bottleneck_conv = conv_bn_layer(
            input=channel_expand,
            num_filters=num_expfilter,
            filter_size=filter_size,
            stride=stride,
            padding=padding,
            num_groups=num_expfilter,
            if_act=True,
            name=name + '_dwise',
            use_cudnn=False)

        depthwise_output = bottleneck_conv

        linear_out = conv_bn_layer(
            input=bottleneck_conv,
            num_filters=num_filters,
            filter_size=1,
            stride=1,
            padding=0,
            num_groups=1,
            if_act=False,
            name=name + '_linear')

        if ifshortcut:
            out = self._shortcut(input=input, data_residual=linear_out)
            return out, depthwise_output
        return linear_out, depthwise_output

    def _inverted_blocks(self, input, in_c, t, c, n, s, name=None):
        """Stack `n` inverted residual units.

        The first unit uses stride `s` and no shortcut; the remaining
        n - 1 units use stride 1 with a shortcut. `t` is the expansion
        factor and `c` the output channel count.

        Returns:
            tuple: (output of the last unit, its depthwise conv output).
        """
        first_block, depthwise_output = self._inverted_residual_unit(
            input=input,
            num_in_filter=in_c,
            num_filters=c,
            ifshortcut=False,
            stride=s,
            filter_size=3,
            padding=1,
            expansion_factor=t,
            name=name + '_1')

        last_residual_block = first_block
        last_c = c

        for i in range(1, n):
            last_residual_block, depthwise_output = self._inverted_residual_unit(
                input=last_residual_block,
                num_in_filter=last_c,
                num_filters=c,
                ifshortcut=True,
                stride=1,
                filter_size=3,
                padding=1,
                expansion_factor=t,
                name=name + '_' + str(i + 1))
        return last_residual_block, depthwise_output

    def _psp_module(self, input, out_features):
        """Pool at bin sizes 1, 2, 3 and 6, then concat with the input.

        Each pooled map goes through a 1x1 conv and bn+relu and is resized
        back to the spatial size of `input` before concatenation on axis 1.
        """
        cat_layers = []
        sizes = (1, 2, 3, 6)
        for size in sizes:
            psp_name = "psp" + str(size)
            with scope(psp_name):
                pool = fluid.layers.adaptive_pool2d(
                    input,
                    pool_size=[size, size],
                    pool_type='avg',
                    name=psp_name + '_adapool')
                data = conv(
                    pool,
                    out_features,
                    filter_size=1,
                    bias_attr=False,
                    name=psp_name + '_conv')
                data_bn = bn(data, act='relu')
                interp = fluid.layers.resize_bilinear(
                    data_bn,
                    out_shape=fluid.layers.shape(input)[2:],
                    name=psp_name + '_interp',
                    align_mode=0)
            cat_layers.append(interp)
        cat_layers = [input] + cat_layers
        out = fluid.layers.concat(cat_layers, axis=1, name='psp_cat')

        return out

    def _aux_layer(self, x, num_classes):
        """Auxiliary head: 3x3 conv, channel dropout, 1x1 'logit' conv."""
        x = relu(bn(conv(x, 32, 3, padding=1)))
        x = self._dropout2d(x, 0.1, is_train=(self.mode == 'train'))
        with scope('logit'):
            x = conv(x, num_classes, 1, bias_attr=True)
        return x

    def _feature_fusion(self,
                        higher_res_feature,
                        lower_res_feature,
                        higher_in_channels,
                        lower_in_channels,
                        out_channels,
                        scale_factor=4):
        """Fuse the two feature maps by addition followed by relu.

        The lower-resolution map is resized to the spatial size of the
        higher-resolution one first. `higher_in_channels`,
        `lower_in_channels` and `scale_factor` are accepted but not read.
        """
        shape = fluid.layers.shape(higher_res_feature)
        w = shape[-1]
        h = shape[-2]
        lower_res_feature = fluid.layers.resize_bilinear(
            lower_res_feature, [h, w], align_mode=0)

        with scope('dwconv'):
            lower_res_feature = relu(
                bn(conv(lower_res_feature, out_channels, 1)))
        with scope('conv_lower_res'):
            lower_res_feature = bn(
                conv(
                    lower_res_feature, out_channels, 1, bias_attr=True))
        with scope('conv_higher_res'):
            higher_res_feature = bn(
                conv(
                    higher_res_feature, out_channels, 1, bias_attr=True))
        out = higher_res_feature + lower_res_feature

        return relu(out)

    def _global_feature_extractor(self,
                                  x,
                                  in_channels=64,
                                  block_channels=(64, 96, 128),
                                  out_channels=128,
                                  t=6,
                                  num_blocks=(3, 3, 3)):
        """Three inverted-block groups, a PSP module and a 1x1 'out' conv.

        The groups use strides 2, 2 and 1; `t` is their expansion factor.
        """
        x, _ = self._inverted_blocks(x, in_channels, t, block_channels[0],
                                     num_blocks[0], 2, 'inverted_block_1')
        x, _ = self._inverted_blocks(x, block_channels[0], t,
                                     block_channels[1], num_blocks[1], 2,
                                     'inverted_block_2')
        x, _ = self._inverted_blocks(x, block_channels[1], t,
                                     block_channels[2], num_blocks[2], 1,
                                     'inverted_block_3')
        x = self._psp_module(x, block_channels[2] // 4)

        with scope('out'):
            x = relu(bn(conv(x, out_channels, 1)))

        return x

    def _classifier(self, x, dw_channels, stride=1):
        """Main head: two separable convs, channel dropout, 1x1 conv."""
        with scope('dsconv1'):
            x = separate_conv(
                x, dw_channels, stride=stride, filter=3, act=fluid.layers.relu)
        with scope('dsconv2'):
            x = separate_conv(
                x, dw_channels, stride=stride, filter=3, act=fluid.layers.relu)

        x = self._dropout2d(x, 0.1, is_train=self.mode == 'train')
        x = conv(x, self.num_classes, 1, bias_attr=True)
        return x