diff --git a/configs/GhostNet/GhostNet_x0_5.yaml b/configs/GhostNet/GhostNet_x0_5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e787c6bd381fe874029d53600ac2bd9f4c68f028 --- /dev/null +++ b/configs/GhostNet/GhostNet_x0_5.yaml @@ -0,0 +1,74 @@ +mode: 'train' +ARCHITECTURE: + name: 'GhostNet_x0_5' + +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 1000 +total_images: 1281167 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 360 +topk: 5 +image_shape: [3, 224, 224] + +use_mix: False +ls_epsilon: 0.1 + +LEARNING_RATE: + function: 'CosineWarmup' + params: + lr: 0.8 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.0000400 + +TRAIN: + batch_size: 2048 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/train_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. 
+ mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +VALID: + batch_size: 64 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/val_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/configs/GhostNet/GhostNet_x1_0.yaml b/configs/GhostNet/GhostNet_x1_0.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c1a61dca4c6e820649fc2c764688d88d0d36cf30 --- /dev/null +++ b/configs/GhostNet/GhostNet_x1_0.yaml @@ -0,0 +1,74 @@ +mode: 'train' +ARCHITECTURE: + name: 'GhostNet_x1_0' + +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 1000 +total_images: 1281167 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 360 +topk: 5 +image_shape: [3, 224, 224] + +use_mix: False +ls_epsilon: 0.1 + +LEARNING_RATE: + function: 'CosineWarmup' + params: + lr: 0.4 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.0000400 + +TRAIN: + batch_size: 1024 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/train_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. 
+ mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +VALID: + batch_size: 64 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/val_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/configs/GhostNet/GhostNet_x1_3.yaml b/configs/GhostNet/GhostNet_x1_3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..452c1f49804d39264b4efa1c90cce280a0de3ef2 --- /dev/null +++ b/configs/GhostNet/GhostNet_x1_3.yaml @@ -0,0 +1,75 @@ +mode: 'train' +ARCHITECTURE: + name: 'GhostNet_x1_3' + +pretrained_model: "" +model_save_dir: "./output/" +classes_num: 1000 +total_images: 1281167 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 360 +topk: 5 +image_shape: [3, 224, 224] + +use_mix: False +ls_epsilon: 0.1 + +LEARNING_RATE: + function: 'CosineWarmup' + params: + lr: 0.4 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.0000400 + +TRAIN: + batch_size: 1024 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/train_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - AutoAugment: + - NormalizeImage: + scale: 1./255. 
+ mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +VALID: + batch_size: 64 + num_workers: 4 + file_list: "./dataset/ILSVRC2012/val_list.txt" + data_dir: "./dataset/ILSVRC2012/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/ppcls/modeling/architectures/__init__.py b/ppcls/modeling/architectures/__init__.py index ac57a786aa16dcdda9c418a130fd4423c721e421..bc7d7593f2eca994fc1d0133697a9f428b27a29a 100644 --- a/ppcls/modeling/architectures/__init__.py +++ b/ppcls/modeling/architectures/__init__.py @@ -42,8 +42,9 @@ from .res2net_vd import Res2Net50_vd_48w_2s, Res2Net50_vd_26w_4s, Res2Net50_vd_1 from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44_C, HRNet_W48_C, HRNet_W60_C, HRNet_W64_C, SE_HRNet_W18_C, SE_HRNet_W30_C, SE_HRNet_W32_C, SE_HRNet_W40_C, SE_HRNet_W44_C, SE_HRNet_W48_C, SE_HRNet_W60_C, SE_HRNet_W64_C from .darts_gs import DARTS_GS_6M, DARTS_GS_4M from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet101_ACNet, ResNet152_ACNet +from .ghostnet import GhostNet_x0_5, GhostNet_x1_0, GhostNet_x1_3 # distillation model from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd -from .csp_resnet import CSPResNet50_leaky \ No newline at end of file +from .csp_resnet import CSPResNet50_leaky diff --git a/ppcls/modeling/architectures/ghostnet.py b/ppcls/modeling/architectures/ghostnet.py new file mode 100644 index 0000000000000000000000000000000000000000..038e2f39fd5f77dc93d47790c9edb0447191a6ad --- /dev/null +++ b/ppcls/modeling/architectures/ghostnet.py @@ -0,0 +1,265 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ 
__all__ = ["GhostNet", "GhostNet_x0_5", "GhostNet_x1_0", "GhostNet_x1_3"]


class GhostNet(object):
    """Builds a GhostNet classification graph from ``paddle.fluid`` layers.

    The network is a stem convolution, a stack of ghost bottlenecks described
    by ``self.cfgs``, a 1x1 convolution, global average pooling and a
    two-layer classification head.

    Args:
        scale: Width multiplier applied to every channel count in ``cfgs``
            (and to the stem width) before rounding with
            ``_make_divisible(..., 4)``.
    """

    def __init__(self, scale):
        # One row per ghost bottleneck:
        #   k  - depthwise kernel size
        #   t  - hidden (expansion) channels before scaling
        #   c  - output channels before scaling
        #   SE - 1 to insert a squeeze-and-excitation block, else 0
        #   s  - stride of the bottleneck
        self.cfgs = [
            # k, t, c, SE, s
            [3, 16, 16, 0, 1],
            [3, 48, 24, 0, 2],
            [3, 72, 24, 0, 1],
            [5, 72, 40, 1, 2],
            [5, 120, 40, 1, 1],
            [3, 240, 80, 0, 2],
            [3, 200, 80, 0, 1],
            [3, 184, 80, 0, 1],
            [3, 184, 80, 0, 1],
            [3, 480, 112, 1, 1],
            [3, 672, 112, 1, 1],
            [5, 672, 160, 1, 2],
            [5, 960, 160, 0, 1],
            [5, 960, 160, 1, 1],
            [5, 960, 160, 0, 1],
            [5, 960, 160, 1, 1],
        ]
        self.scale = scale

    def net(self, input, class_dim=1000):
        """Assemble the full network and return the classification logits.

        Args:
            input: Image variable fed to the first convolution.
            class_dim: Number of output classes of the final fc layer.

        Returns:
            The output variable of the final fully connected layer.
        """
        # Stem: 3x3 stride-2 convolution.
        stem_channels = int(self._make_divisible(16 * self.scale, 4))
        x = self.conv_bn_layer(
            input=input,
            num_filters=stem_channels,
            filter_size=3,
            stride=2,
            groups=1,
            act="relu",
            name="conv1")

        # Ghost bottleneck stack.
        for idx, (k, exp_size, c, use_se, s) in enumerate(self.cfgs):
            output_channel = int(self._make_divisible(c * self.scale, 4))
            hidden_channel = int(
                self._make_divisible(exp_size * self.scale, 4))
            x = self.ghost_bottleneck(
                input=x,
                hidden_dim=hidden_channel,
                output=output_channel,
                kernel_size=k,
                stride=s,
                use_se=use_se,
                name="_ghostbottleneck_" + str(idx))

        # The closing 1x1 convolution is as wide as the hidden size of the
        # last bottleneck; read it from the table instead of relying on the
        # loop variable leaking out of the ``for`` statement.
        last_exp_size = self.cfgs[-1][1]
        last_channels = int(
            self._make_divisible(last_exp_size * self.scale, 4))
        x = self.conv_bn_layer(
            input=x,
            num_filters=last_channels,
            filter_size=1,
            stride=1,
            groups=1,
            act="relu",
            name="conv_last")
        x = fluid.layers.pool2d(input=x, pool_type='avg', global_pooling=True)

        # Classification head: 1x1 conv + BN ("fc_0"), dropout, then fc.
        out = self.conv_bn_layer(
            input=x,
            num_filters=1280,
            filter_size=1,
            stride=1,
            act="relu",
            name="fc_0")
        out = fluid.layers.dropout(x=out, dropout_prob=0.2)
        # Uniform init bound derived from the fc layer's input channel count.
        stdv = 1.0 / math.sqrt(out.shape[1] * 1.0)
        out = fluid.layers.fc(
            input=out,
            size=class_dim,
            param_attr=ParamAttr(
                name="fc_1_weights",
                initializer=fluid.initializer.Uniform(-stdv, stdv)),
            bias_attr=ParamAttr(name="fc_1_offset"))
        return out

    def _make_divisible(self, v, divisor, min_value=None):
        """Round ``v`` to a multiple of ``divisor``.

        This function is taken from the original tf repo:
        https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

        Args:
            v: Requested (possibly fractional) channel count.
            divisor: The result is a multiple of this value.
            min_value: Lower bound for the result; defaults to ``divisor``.

        Returns:
            The nearest multiple of ``divisor`` that is at least
            ``min_value``, bumped up by one ``divisor`` if rounding would
            otherwise lose more than 10% of ``v``.
        """
        if min_value is None:
            min_value = divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        # Make sure that round down does not go down by more than 10%.
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    def conv_bn_layer(self,
                      input,
                      num_filters,
                      filter_size,
                      stride=1,
                      groups=1,
                      act=None,
                      name=None):
        """Bias-free convolution followed by batch norm.

        Args:
            input: Input variable.
            num_filters: Number of convolution output channels.
            filter_size: Square kernel size; padding is
                ``(filter_size - 1) // 2``.
            stride: Convolution stride.
            groups: Convolution groups.
            act: Activation name applied by the batch-norm layer, or None.
            name: Prefix for all parameter names. Although it defaults to
                None, a string is required because it is concatenated with
                suffixes below.

        Returns:
            The batch-norm output variable.
        """
        x = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(
                initializer=fluid.initializer.MSRA(), name=name + "_weights"),
            bias_attr=False)
        bn_name = name + "_bn"
        # Scale and offset are given a zero-coefficient L2 regularizer here.
        # All four batch-norm parameters share the ``bn_name`` prefix; the
        # moving variance previously used ``name`` and therefore ended up
        # as "<name>_variance" next to "<name>_bn_mean".
        x = fluid.layers.batch_norm(
            input=x,
            act=act,
            param_attr=ParamAttr(
                name=bn_name + "_scale",
                regularizer=fluid.regularizer.L2DecayRegularizer(
                    regularization_coeff=0.0)),
            bias_attr=ParamAttr(
                name=bn_name + "_offset",
                regularizer=fluid.regularizer.L2DecayRegularizer(
                    regularization_coeff=0.0)),
            moving_mean_name=bn_name + "_mean",
            moving_variance_name=bn_name + "_variance")
        return x

    def se_block(self, input, num_channels, reduction_ratio=4, name=None):
        """Squeeze-and-excitation: rescale ``input`` channel-wise.

        Args:
            input: Feature map to rescale.
            num_channels: Channel count of ``input``; also the size of the
                excitation vector.
            reduction_ratio: Bottleneck reduction of the first fc layer.
            name: Prefix for parameter names (must be a string).

        Returns:
            ``input`` multiplied element-wise by the excitation vector.
        """
        pool = fluid.layers.pool2d(
            input=input,
            pool_type='avg',
            global_pooling=True,
            use_cudnn=False)
        stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
        squeeze = fluid.layers.fc(
            input=pool,
            size=num_channels // reduction_ratio,
            act='relu',
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name=name + '_1_weights'),
            bias_attr=ParamAttr(name=name + '_1_offset'))
        stdv = 1.0 / math.sqrt(squeeze.shape[1] * 1.0)
        excitation = fluid.layers.fc(
            input=squeeze,
            size=num_channels,
            act=None,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv),
                name=name + '_2_weights'),
            bias_attr=ParamAttr(name=name + '_2_offset'))
        # NOTE(review): the gate is used without any activation or clamp.
        # A clip of the excitation to [0, 1] was present but disabled in the
        # original code; confirm whether leaving it unbounded is intended.
        se_scale = fluid.layers.elementwise_mul(
            x=input, y=excitation, axis=0)
        return se_scale

    def depthwise_conv(self,
                       input,
                       output,
                       kernel_size,
                       stride=1,
                       relu=False,
                       name=None):
        """Depthwise conv + BN: one group per channel of ``input``.

        Args:
            input: Input variable; ``input.shape[1]`` is used as the number
                of groups.
            output: Number of output channels.
            kernel_size: Square kernel size.
            stride: Convolution stride.
            relu: Apply ReLU after batch norm when True.
            name: Prefix for parameter names; "_depthwise" is appended.

        Returns:
            The batch-norm output variable.
        """
        return self.conv_bn_layer(
            input=input,
            num_filters=output,
            filter_size=kernel_size,
            stride=stride,
            groups=input.shape[1],
            act="relu" if relu else None,
            name=name + "_depthwise")

    def ghost_module(self,
                     input,
                     output,
                     kernel_size=1,
                     ratio=2,
                     dw_size=3,
                     stride=1,
                     relu=True,
                     name=None):
        """Ghost module: a primary conv plus cheap grouped-conv features.

        ``ceil(output / ratio)`` channels come from an ordinary convolution;
        the rest are produced from them by a grouped ``dw_size`` convolution,
        and both parts are concatenated along the channel axis.

        Args:
            input: Input variable.
            output: Requested number of output channels.
            kernel_size: Kernel size of the primary convolution.
            ratio: Ratio between total and primary channels.
            dw_size: Kernel size of the cheap (grouped) convolution.
            stride: Stride of the primary convolution.
            relu: Apply ReLU after each batch norm when True.
            name: Prefix for parameter names (must be a string).

        Returns:
            A variable with ``output`` channels.
        """
        init_channels = int(math.ceil(output / ratio))
        new_channels = int(init_channels * (ratio - 1))
        activation = "relu" if relu else None
        primary_conv = self.conv_bn_layer(
            input=input,
            num_filters=init_channels,
            filter_size=kernel_size,
            stride=stride,
            groups=1,
            act=activation,
            name=name + "_primary_conv")
        cheap_operation = self.conv_bn_layer(
            input=primary_conv,
            num_filters=new_channels,
            filter_size=dw_size,
            stride=1,
            groups=init_channels,
            act=activation,
            name=name + "_cheap_operation")
        out = fluid.layers.concat([primary_conv, cheap_operation], axis=1)
        # ceil() can overshoot when ``output`` is not a multiple of
        # ``ratio``; keep only the requested channels so that callers (for
        # example the residual add in ghost_bottleneck) see matching shapes.
        if init_channels + new_channels > output:
            out = fluid.layers.slice(
                out, axes=[1], starts=[0], ends=[output])
        return out

    def ghost_bottleneck(self,
                         input,
                         hidden_dim,
                         output,
                         kernel_size,
                         stride,
                         use_se,
                         name=None):
        """Ghost bottleneck: expand, optional downsample/SE, project, add.

        Args:
            input: Input variable.
            hidden_dim: Channels produced by the first ghost module.
            output: Channels produced by the second ghost module.
            kernel_size: Kernel size of the depthwise convolutions.
            stride: 2 inserts a strided depthwise conv in the main branch;
                any stride other than 1 also forces a projection shortcut.
            use_se: Truthy to insert an SE block before the projection.
            name: Prefix for parameter names (must be a string).

        Returns:
            Sum of the main branch and the shortcut branch.
        """
        inp_channels = input.shape[1]
        x = self.ghost_module(
            input=input,
            output=hidden_dim,
            kernel_size=1,
            stride=1,
            relu=True,
            name=name + "_ghost_module_1")
        if stride == 2:
            x = self.depthwise_conv(
                input=x,
                output=hidden_dim,
                kernel_size=kernel_size,
                stride=stride,
                relu=False,
                name=name + "_depthwise")
        if use_se:
            x = self.se_block(
                input=x, num_channels=hidden_dim, name=name + "_se")
        x = self.ghost_module(
            input=x,
            output=output,
            kernel_size=1,
            relu=False,
            name=name + "_ghost_module_2")

        if stride == 1 and inp_channels == output:
            # Shapes already match: identity shortcut.
            shortcut = input
        else:
            # Otherwise project the input with a depthwise conv (carrying
            # the stride) followed by a 1x1 conv to ``output`` channels.
            shortcut = self.depthwise_conv(
                input=input,
                output=inp_channels,
                kernel_size=kernel_size,
                stride=stride,
                relu=False,
                name=name + "_shortcut_depthwise")
            shortcut = self.conv_bn_layer(
                input=shortcut,
                num_filters=output,
                filter_size=1,
                stride=1,
                groups=1,
                act=None,
                name=name + "_shortcut_conv")
        return fluid.layers.elementwise_add(x=x, y=shortcut, axis=-1)


def GhostNet_x0_5():
    """Return a GhostNet builder with width multiplier 0.5."""
    return GhostNet(scale=0.5)


def GhostNet_x1_0():
    """Return a GhostNet builder with width multiplier 1.0."""
    return GhostNet(scale=1.0)


def GhostNet_x1_3():
    """Return a GhostNet builder with width multiplier 1.3."""
    return GhostNet(scale=1.3)