diff --git a/configs/ResNeSt/ResNeSt50.yaml b/configs/ResNeSt/ResNeSt50.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..01e056da05796b91374caed9bac0f652bc9752e6
--- /dev/null
+++ b/configs/ResNeSt/ResNeSt50.yaml
@@ -0,0 +1,78 @@
+mode: 'train'
+ARCHITECTURE:
+ name: 'ResNeSt50'
+
+pretrained_model: ""
+model_save_dir: "./output/"
+classes_num: 1000
+total_images: 1281167
+save_interval: 1
+validate: True
+valid_interval: 1
+epochs: 300
+topk: 5
+image_shape: [3, 224, 224]
+
+use_mix: True
+ls_epsilon: 0.1
+
+LEARNING_RATE:
+ function: 'CosineWarmup'
+ params:
+ lr: 0.1
+
+OPTIMIZER:
+ function: 'Momentum'
+ params:
+ momentum: 0.9
+ regularizer:
+ function: 'L2'
+ factor: 0.000070
+
+TRAIN:
+ batch_size: 256
+ num_workers: 4
+ file_list: "./dataset/ILSVRC2012/train_list.txt"
+ data_dir: "./dataset/ILSVRC2012/"
+ shuffle_seed: 0
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ to_np: False
+ channel_first: False
+ - RandCropImage:
+ size: 224
+ - RandFlipImage:
+ flip_code: 1
+ - AutoAugment:
+ - NormalizeImage:
+ scale: 1./255.
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
+ mix:
+ - CutmixOperator:
+ alpha: 0.2
+
+VALID:
+ batch_size: 64
+ num_workers: 4
+ file_list: "./dataset/ILSVRC2012/val_list.txt"
+ data_dir: "./dataset/ILSVRC2012/"
+ shuffle_seed: 0
+ transforms:
+ - DecodeImage:
+ to_rgb: True
+ to_np: False
+ channel_first: False
+ - ResizeImage:
+ resize_short: 256
+ - CropImage:
+ size: 224
+ - NormalizeImage:
+ scale: 1.0/255.0
+ mean: [0.485, 0.456, 0.406]
+ std: [0.229, 0.224, 0.225]
+ order: ''
+ - ToCHWImage:
diff --git a/docs/en/models/ResNeSt_RegNet_en.md b/docs/en/models/ResNeSt_RegNet_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..3952b1155775bdaf6df245e0fb2ab626b736294a
--- /dev/null
+++ b/docs/en/models/ResNeSt_RegNet_en.md
@@ -0,0 +1,9 @@
+## Overview
+
+The ResNeSt series was proposed in 2020. It improves on the original ResNet architecture by splitting each residual block into K cardinal groups and applying a split-attention module (similar to an SE block) across the splits of each group. Accuracy improves markedly over the corresponding ResNet baseline while the parameter count and FLOPs stay almost unchanged.
+
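+The core idea can be illustrated with a short, self-contained NumPy sketch of the split-attention weighting. This is only a sketch: `split_attention`, its shapes, and the placeholder logits (which stand in for the two learned 1x1 convolutions of the real block) are illustrative, not the repository's implementation.
+
+```python
+import numpy as np
+
+def split_attention(feat, radix=2):
+    # feat: [N, radix * C, H, W] -> attention-weighted sum of the radix splits, [N, C, H, W]
+    splits = np.split(feat, radix, axis=1)               # radix tensors of shape [N, C, H, W]
+    gap = sum(splits).mean(axis=(2, 3))                  # pooled global context, [N, C]
+    # The real block maps `gap` through two 1x1 convs to get per-split logits;
+    # the pooled features stand in for those logits to keep the sketch short.
+    logits = np.stack([gap] * radix, axis=1)             # [N, radix, C]
+    attn = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax over radix
+    return sum(attn[:, r][:, :, None, None] * splits[r] for r in range(radix))
+
+out = split_attention(np.random.rand(2, 2 * 64, 8, 8), radix=2)  # -> shape (2, 64, 8, 8)
+```
+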
+## Accuracy, FLOPs and Parameters
+
+| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Parameters<br>(M) |
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+| ResNeSt50 | 0.8102 | 0.9542 | 0.8113 | - | 5.39 | 27.5 |
diff --git a/docs/zh_CN/models/ResNeSt_RegNet.md b/docs/zh_CN/models/ResNeSt_RegNet.md
new file mode 100644
index 0000000000000000000000000000000000000000..2b12d73945a18e5242be234586134c9a2831080d
--- /dev/null
+++ b/docs/zh_CN/models/ResNeSt_RegNet.md
@@ -0,0 +1,14 @@
+# ResNeSt and RegNet series
+
+## Overview
+
+The ResNeSt series was proposed in 2020. It improves on the original ResNet architecture by introducing K cardinal groups and adding a split-attention module (similar to an SE block) within each group, which raises accuracy substantially over the ResNet baseline while keeping the parameter count and FLOPs essentially the same.
+
+
+## Accuracy, FLOPS and Parameters
+
+| Models | Top1 | Top5 | Reference<br>top1 | Reference<br>top5 | FLOPS<br>(G) | Parameters<br>(M) |
+|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
+| ResNeSt50 | 0.8102 | 0.9542 | 0.8113 | - | 5.39 | 27.5 |
+
+
+
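+A minimal usage sketch with PaddlePaddle's static `fluid` API (the variable names below are illustrative; `ResNeSt50` and `net()` come from this repository):
+
+```python
+import paddle.fluid as fluid
+
+from ppcls.modeling.architectures import ResNeSt50
+
+# Build the static graph for a 1000-class classifier on 224x224 RGB input.
+image = fluid.data(name="image", shape=[None, 3, 224, 224], dtype="float32")
+model = ResNeSt50()
+logits = model.net(input=image, class_dim=1000)
+```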
diff --git a/ppcls/modeling/architectures/__init__.py b/ppcls/modeling/architectures/__init__.py
index 8b288ed099acd9a849366ad138294b64876a5700..71f25c87238035f458ae845af1a4c863bd75a78f 100644
--- a/ppcls/modeling/architectures/__init__.py
+++ b/ppcls/modeling/architectures/__init__.py
@@ -47,6 +47,7 @@ from .hrnet import HRNet_W18_C, HRNet_W30_C, HRNet_W32_C, HRNet_W40_C, HRNet_W44
from .darts_gs import DARTS_GS_6M, DARTS_GS_4M
from .resnet_acnet import ResNet18_ACNet, ResNet34_ACNet, ResNet50_ACNet, ResNet101_ACNet, ResNet152_ACNet
from .ghostnet import GhostNet_x0_5, GhostNet_x1_0, GhostNet_x1_3
+from .resnest import ResNeSt50, ResNeSt101, ResNeSt200, ResNeSt269, ResNeSt50_fast_1s1x64d, ResNeSt50_fast_2s1x64d, ResNeSt50_fast_4s1x64d, ResNeSt50_fast_1s2x40d, ResNeSt50_fast_2s2x40d, ResNeSt50_fast_4s2x40d, ResNeSt50_fast_1s4x24d
# distillation model
from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0, ResNeXt101_32x16d_wsl_distill_ResNet50_vd
diff --git a/ppcls/modeling/architectures/resnest.py b/ppcls/modeling/architectures/resnest.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5f1933b1c35410d2ab87c8f6598df3c20453ec6
--- /dev/null
+++ b/ppcls/modeling/architectures/resnest.py
@@ -0,0 +1,648 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle.fluid as fluid
+from paddle.fluid.initializer import MSRA, ConstantInitializer
+from paddle.fluid.param_attr import ParamAttr
+from paddle.fluid.regularizer import L2DecayRegularizer
+import math
+
+__all__ = [
+ 'ResNeSt50', 'ResNeSt101', 'ResNeSt200', 'ResNeSt269',
+ 'ResNeSt50_fast_1s1x64d', 'ResNeSt50_fast_2s1x64d',
+ 'ResNeSt50_fast_4s1x64d', 'ResNeSt50_fast_1s2x40d',
+ 'ResNeSt50_fast_2s2x40d', 'ResNeSt50_fast_4s2x40d',
+ 'ResNeSt50_fast_1s4x24d'
+]
+
+
+class ResNeSt():
+ def __init__(self,
+ layers,
+ radix=1,
+ groups=1,
+ bottleneck_width=64,
+ dilated=False,
+ dilation=1,
+ deep_stem=False,
+ stem_width=64,
+ avg_down=False,
+ rectify_avg=False,
+ avd=False,
+ avd_first=False,
+ final_drop=0.0,
+ last_gamma=False,
+ bn_decay=0.0):
+ self.cardinality = groups
+ self.bottleneck_width = bottleneck_width
+ # ResNet-D params
+ self.inplanes = stem_width * 2 if deep_stem else 64
+ self.avg_down = avg_down
+ self.last_gamma = last_gamma
+ # ResNeSt params
+ self.radix = radix
+ self.avd = avd
+ self.avd_first = avd_first
+
+ self.deep_stem = deep_stem
+ self.stem_width = stem_width
+ self.layers = layers
+ self.final_drop = final_drop
+ self.dilated = dilated
+ self.dilation = dilation
+ self.bn_decay = bn_decay
+
+ self.rectify_avg = rectify_avg
+
+ def net(self, input, class_dim=1000):
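+ # deep_stem replaces the single 7x7 stride-2 stem convolution with three
+ # 3x3 convolutions (ResNet-D style stem).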
+ if self.deep_stem:
+ x = self.conv_bn_layer(
+ x=input,
+ num_filters=self.stem_width,
+ filters_size=3,
+ stride=2,
+ groups=1,
+ act="relu",
+ name="conv1")
+ x = self.conv_bn_layer(
+ x=x,
+ num_filters=self.stem_width,
+ filters_size=3,
+ stride=1,
+ groups=1,
+ act="relu",
+ name="conv2")
+ x = self.conv_bn_layer(
+ x=x,
+ num_filters=self.stem_width * 2,
+ filters_size=3,
+ stride=1,
+ groups=1,
+ act="relu",
+ name="conv3")
+ else:
+ x = self.conv_bn_layer(
+ x=input,
+ num_filters=64,
+ filters_size=7,
+ stride=2,
+ act="relu",
+ name="conv1")
+
+ x = fluid.layers.pool2d(
+ input=x,
+ pool_size=3,
+ pool_type="max",
+ pool_stride=2,
+ pool_padding=1)
+
+ x = self.resnest_layer(
+ x=x,
+ planes=64,
+ blocks=self.layers[0],
+ is_first=False,
+ name="layer1")
+ x = self.resnest_layer(
+ x=x,
+ planes=128,
+ blocks=self.layers[1],
+ stride=2,
+ name="layer2")
+ if self.dilated or self.dilation == 4:
+ x = self.resnest_layer(
+ x=x,
+ planes=256,
+ blocks=self.layers[2],
+ stride=1,
+ dilation=2,
+ name="layer3")
+ x = self.resnest_layer(
+ x=x,
+ planes=512,
+ blocks=self.layers[3],
+ stride=1,
+ dilation=4,
+ name="layer4")
+ elif self.dilation == 2:
+ x = self.resnest_layer(
+ x=x,
+ planes=256,
+ blocks=self.layers[2],
+ stride=2,
+ dilation=1,
+ name="layer3")
+ x = self.resnest_layer(
+ x=x,
+ planes=512,
+ blocks=self.layers[3],
+ stride=1,
+ dilation=2,
+ name="layer4")
+ else:
+ x = self.resnest_layer(
+ x=x,
+ planes=256,
+ blocks=self.layers[2],
+ stride=2,
+ name="layer3")
+ x = self.resnest_layer(
+ x=x,
+ planes=512,
+ blocks=self.layers[3],
+ stride=2,
+ name="layer4")
+ x = fluid.layers.pool2d(
+ input=x, pool_type="avg", global_pooling=True)
+ x = fluid.layers.dropout(
+ x=x, dropout_prob=self.final_drop)
+ stdv = 1.0 / math.sqrt(x.shape[1] * 1.0)
+ x = fluid.layers.fc(
+ input=x,
+ size=class_dim,
+ param_attr=ParamAttr(
+ name="fc_weights",
+ initializer=fluid.initializer.Uniform(-stdv, stdv)),
+ bias_attr=ParamAttr(name="fc_offset"))
+ return x
+
+ def conv_bn_layer(self,
+ x,
+ num_filters,
+ filters_size,
+ stride=1,
+ groups=1,
+ padding=None,
+ dilation=1,
+ act=None,
+ name=None):
+ # default to "same" padding (scaled by dilation) when none is given
+ padding = (filters_size - 1) // 2 * dilation if padding is None else padding
+ x = fluid.layers.conv2d(
+ input=x,
+ num_filters=num_filters,
+ filter_size=filters_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ act=None,
+ param_attr=ParamAttr(
+ initializer=MSRA(), name=name + "_weight"),
+ bias_attr=False)
+ x = fluid.layers.batch_norm(
+ input=x,
+ act=act,
+ param_attr=ParamAttr(
+ name=name + "_scale",
+ regularizer=L2DecayRegularizer(
+ regularization_coeff=self.bn_decay)),
+ bias_attr=ParamAttr(
+ name=name + "_offset",
+ regularizer=L2DecayRegularizer(
+ regularization_coeff=self.bn_decay)),
+ moving_mean_name=name + "_mean",
+ moving_variance_name=name + "_variance")
+ return x
+
+ def rsoftmax(self, x, radix, cardinality):
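+ # r-SoftMax: reshape to (cardinality, radix, ...) and apply softmax over the
+ # radix axis so the radix splits inside each cardinal group compete with each
+ # other; with radix == 1 a plain sigmoid gate is used instead.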
+ batch, r, h, w = x.shape
+ if radix > 1:
+ x = fluid.layers.reshape(
+ x=x,
+ shape=[
+ 0, cardinality, radix, int(r * h * w / cardinality / radix)
+ ])
+ x = fluid.layers.transpose(x=x, perm=[0, 2, 1, 3])
+ x = fluid.layers.softmax(input=x, axis=1)
+ x = fluid.layers.reshape(x=x, shape=[0, r * h * w])
+ else:
+ x = fluid.layers.sigmoid(x=x)
+ return x
+
+ def splat_conv(self,
+ x,
+ in_channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ radix=2,
+ reduction_factor=4,
+ rectify_avg=False,
+ name=None):
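+ # Split-Attention convolution: a grouped conv produces radix * channels maps,
+ # global average pooling of the summed splits gives the context vector, and
+ # the r-softmax attention re-weights each split before summing them back.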
+ x = self.conv_bn_layer(
+ x=x,
+ num_filters=channels * radix,
+ filters_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups * radix,
+ act="relu",
+ name=name + "_splat1")
+
+ batch, rchannel = x.shape[:2]
+ if radix > 1:
+ splited = fluid.layers.split(input=x, num_or_sections=radix, dim=1)
+ gap = fluid.layers.sum(x=splited)
+ else:
+ gap = x
+ gap = fluid.layers.pool2d(
+ input=gap, pool_type="avg", global_pooling=True)
+ inter_channels = int(max(in_channels * radix // reduction_factor, 32))
+ gap = self.conv_bn_layer(
+ x=gap,
+ num_filters=inter_channels,
+ filters_size=1,
+ groups=groups,
+ act="relu",
+ name=name + "_splat2")
+
+ atten = fluid.layers.conv2d(
+ input=gap,
+ num_filters=channels * radix,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ groups=groups,
+ act=None,
+ param_attr=ParamAttr(
+ name=name + "_splat_weights", initializer=MSRA()),
+ bias_attr=False)
+ atten = self.rsoftmax(
+ x=atten, radix=radix, cardinality=groups)
+ atten = fluid.layers.reshape(x=atten, shape=[-1, atten.shape[1], 1, 1])
+
+ if radix > 1:
+ attens = fluid.layers.split(
+ input=atten, num_or_sections=radix, dim=1)
+ out = fluid.layers.sum([
+ fluid.layers.elementwise_mul(
+ x=att, y=split) for (att, split) in zip(attens, splited)
+ ])
+ else:
+ out = fluid.layers.elementwise_mul(atten, x)
+ return out
+
+ def bottleneck(self,
+ x,
+ inplanes,
+ planes,
+ stride=1,
+ radix=1,
+ cardinality=1,
+ bottleneck_width=64,
+ avd=False,
+ avd_first=False,
+ dilation=1,
+ is_first=False,
+ rectify_avg=False,
+ last_gamma=False,
+ name=None):
+
+ short = x
+
+ group_width = int(planes * (bottleneck_width / 64.)) * cardinality
+ x = self.conv_bn_layer(
+ x=x,
+ num_filters=group_width,
+ filters_size=1,
+ stride=1,
+ groups=1,
+ act="relu",
+ name=name + "_conv1")
+ if avd and avd_first and (stride > 1 or is_first):
+ x = fluid.layers.pool2d(
+ input=x,
+ pool_size=3,
+ pool_type="avg",
+ pool_stride=stride,
+ pool_padding=1)
+ if radix >= 1:
+ x = self.splat_conv(
+ x=x,
+ in_channels=group_width,
+ channels=group_width,
+ kernel_size=3,
+ stride=1,
+ padding=dilation,
+ dilation=dilation,
+ groups=cardinality,
+ bias=False,
+ radix=radix,
+ rectify_avg=rectify_avg,
+ name=name + "_splatconv")
+ else:
+ x = self.conv_bn_layer(
+ x=x,
+ num_filters=group_width,
+ filters_size=3,
+ stride=1,
+ padding=dilation,
+ dilation=dilation,
+ groups=cardinality,
+ act="relu",
+ name=name + "_conv2")
+
+ if avd and not avd_first and (stride > 1 or is_first):
+ x = fluid.layers.pool2d(
+ input=x,
+ pool_size=3,
+ pool_type="avg",
+ pool_stride=stride,
+ pool_padding=1)
+ x = self.conv_bn_layer(
+ x=x,
+ num_filters=planes * 4,
+ filters_size=1,
+ stride=1,
+ groups=1,
+ act=None,
+ name=name + "_conv3")
+
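+ # Projection shortcut; with avg_down the downsampling is done by average
+ # pooling followed by a 1x1 conv (ResNet-D trick) instead of a strided 1x1 conv.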
+ if stride != 1 or self.inplanes != planes * 4:
+ if self.avg_down:
+ if dilation == 1:
+ short = fluid.layers.pool2d(
+ input=short,
+ pool_size=stride,
+ pool_type="avg",
+ pool_stride=stride,
+ ceil_mode=True)
+ else:
+ short = fluid.layers.pool2d(
+ input=short,
+ pool_size=1,
+ pool_type="avg",
+ pool_stride=1,
+ ceil_mode=True)
+ short = fluid.layers.conv2d(
+ input=short,
+ num_filters=planes * 4,
+ filter_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ act=None,
+ param_attr=ParamAttr(
+ name=name + "_weights", initializer=MSRA()),
+ bias_attr=False)
+ else:
+ short = fluid.layers.conv2d(
+ input=short,
+ num_filters=planes * 4,
+ filter_size=1,
+ stride=stride,
+ param_attr=ParamAttr(
+ name=name + "_shortcut_weights", initializer=MSRA()),
+ bias_attr=False)
+
+ short = fluid.layers.batch_norm(
+ input=short,
+ act=None,
+ param_attr=ParamAttr(
+ name=name + "_shortcut_scale",
+ regularizer=L2DecayRegularizer(
+ regularization_coeff=self.bn_decay)),
+ bias_attr=ParamAttr(
+ name=name + "_shortcut_offset",
+ regularizer=L2DecayRegularizer(
+ regularization_coeff=self.bn_decay)),
+ moving_mean_name=name + "_shortcut_mean",
+ moving_variance_name=name + "_shortcut_variance")
+
+ return fluid.layers.elementwise_add(x=short, y=x, act="relu")
+
+ def resnest_layer(self,
+ x,
+ planes,
+ blocks,
+ stride=1,
+ dilation=1,
+ is_first=True,
+ name=None):
+ if dilation == 1 or dilation == 2:
+ x = self.bottleneck(
+ x=x,
+ inplanes=self.inplanes,
+ planes=planes,
+ stride=stride,
+ radix=self.radix,
+ cardinality=self.cardinality,
+ bottleneck_width=self.bottleneck_width,
+ avd=self.avd,
+ avd_first=self.avd_first,
+ dilation=1,
+ is_first=is_first,
+ rectify_avg=self.rectify_avg,
+ last_gamma=self.last_gamma,
+ name=name + "_bottleneck_0")
+ elif dilation == 4:
+ x = self.bottleneck(
+ x=x,
+ inplanes=self.inplanes,
+ planes=planes,
+ stride=stride,
+ radix=self.radix,
+ cardinality=self.cardinality,
+ bottleneck_width=self.bottleneck_width,
+ avd=self.avd,
+ avd_first=self.avd_first,
+ dilation=2,
+ is_first=is_first,
+ rectify_avg=self.rectify_avg,
+ last_gamma=self.last_gamma,
+ name=name + "_bottleneck_0")
+ else:
+ raise RuntimeError("=>unknown dilation size")
+
+ self.inplanes = planes * 4
+ for i in range(1, blocks):
+ name = name + "_bottleneck_" + str(i)
+ x = self.bottleneck(
+ x=x,
+ inplanes=self.inplanes,
+ planes=planes,
+ radix=self.radix,
+ cardinality=self.cardinality,
+ bottleneck_width=self.bottleneck_width,
+ avd=self.avd,
+ avd_first=self.avd_first,
+ dilation=dilation,
+ rectify_avg=self.rectify_avg,
+ last_gamma=self.last_gamma,
+ name=name)
+ return x
+
+
+def ResNeSt50(**args):
+ model = ResNeSt(
+ layers=[3, 4, 6, 3],
+ radix=2,
+ groups=1,
+ bottleneck_width=64,
+ deep_stem=True,
+ stem_width=32,
+ avg_down=True,
+ avd=True,
+ avd_first=False,
+ final_drop=0.0,
+ **args)
+ return model
+
+
+def ResNeSt101(**args):
+ model = ResNeSt(
+ layers=[3, 4, 23, 3],
+ radix=2,
+ groups=1,
+ bottleneck_width=64,
+ deep_stem=True,
+ stem_width=64,
+ avg_down=True,
+ avd=True,
+ avd_first=False,
+ final_drop=0.0,
+ **args)
+ return model
+
+
+def ResNeSt200(**args):
+ model = ResNeSt(
+ layers=[3, 24, 36, 3],
+ radix=2,
+ groups=1,
+ bottleneck_width=64,
+ deep_stem=True,
+ stem_width=64,
+ avg_down=True,
+ avd=True,
+ avd_first=False,
+ final_drop=0.2,
+ **args)
+ return model
+
+
+def ResNeSt269(**args):
+ model = ResNeSt(
+ layers=[3, 30, 48, 8],
+ radix=2,
+ groups=1,
+ bottleneck_width=64,
+ deep_stem=True,
+ stem_width=64,
+ avg_down=True,
+ avd=True,
+ avd_first=False,
+ final_drop=0.2,
+ **args)
+ return model
+
+
+def ResNeSt50_fast_1s1x64d(**args):
+ model = ResNeSt(
+ layers=[3, 4, 6, 3],
+ radix=1,
+ groups=1,
+ bottleneck_width=64,
+ deep_stem=True,
+ stem_width=32,
+ avg_down=True,
+ avd=True,
+ avd_first=True,
+ final_drop=0.0,
+ **args)
+ return model
+
+
+def ResNeSt50_fast_2s1x64d(**args):
+ model = ResNeSt(
+ layers=[3, 4, 6, 3],
+ radix=2,
+ groups=1,
+ bottleneck_width=64,
+ deep_stem=True,
+ stem_width=32,
+ avg_down=True,
+ avd=True,
+ avd_first=True,
+ final_drop=0.0,
+ **args)
+ return model
+
+
+def ResNeSt50_fast_4s1x64d(**args):
+ model = ResNeSt(
+ layers=[3, 4, 6, 3],
+ radix=4,
+ groups=1,
+ bottleneck_width=64,
+ deep_stem=True,
+ stem_width=32,
+ avg_down=True,
+ avd=True,
+ avd_first=True,
+ final_drop=0.0,
+ **args)
+ return model
+
+
+def ResNeSt50_fast_1s2x40d(**args):
+ model = ResNeSt(
+ layers=[3, 4, 6, 3],
+ radix=1,
+ groups=2,
+ bottleneck_width=40,
+ deep_stem=True,
+ stem_width=32,
+ avg_down=True,
+ avd=True,
+ avd_first=True,
+ final_drop=0.0,
+ **args)
+ return model
+
+
+def ResNeSt50_fast_2s2x40d(**args):
+ model = ResNeSt(
+ layers=[3, 4, 6, 3],
+ radix=2,
+ groups=2,
+ bottleneck_width=40,
+ deep_stem=True,
+ stem_width=32,
+ avg_down=True,
+ avd=True,
+ avd_first=True,
+ final_drop=0.0,
+ **args)
+ return model
+
+
+def ResNeSt50_fast_4s2x40d(**args):
+ model = ResNeSt(
+ layers=[3, 4, 6, 3],
+ radix=4,
+ groups=2,
+ bottleneck_width=40,
+ deep_stem=True,
+ stem_width=32,
+ avg_down=True,
+ avd=True,
+ avd_first=True,
+ final_drop=0.0,
+ **args)
+ return model
+
+
+def ResNeSt50_fast_1s4x24d(**args):
+ model = ResNeSt(
+ layers=[3, 4, 6, 3],
+ radix=1,
+ groups=4,
+ bottleneck_width=24,
+ deep_stem=True,
+ stem_width=32,
+ avg_down=True,
+ avd=True,
+ avd_first=True,
+ final_drop=0.0,
+ **args)
+ return model