Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
e1943f9a
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e1943f9a
编写于
4月 28, 2022
作者:
F
flytocc
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ConvNeXt code
上级
fea9522a
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
459 addition
and
0 deletion
+459
-0
ppcls/arch/backbone/__init__.py
ppcls/arch/backbone/__init__.py
+1
-0
ppcls/arch/backbone/model_zoo/convnext.py
ppcls/arch/backbone/model_zoo/convnext.py
+240
-0
ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml
ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml
+164
-0
test_tipc/config/ConvNeXt/ConvNeXt_tiny_train_infer_python.txt
...tipc/config/ConvNeXt/ConvNeXt_tiny_train_infer_python.txt
+54
-0
未找到文件。
ppcls/arch/backbone/__init__.py
浏览文件 @
e1943f9a
...
...
@@ -65,6 +65,7 @@ from ppcls.arch.backbone.model_zoo.pvt_v2 import PVT_V2_B0, PVT_V2_B1, PVT_V2_B2
from
ppcls.arch.backbone.model_zoo.mobilevit
import
MobileViT_XXS
,
MobileViT_XS
,
MobileViT_S
from
ppcls.arch.backbone.model_zoo.repvgg
import
RepVGG_A0
,
RepVGG_A1
,
RepVGG_A2
,
RepVGG_B0
,
RepVGG_B1
,
RepVGG_B2
,
RepVGG_B1g2
,
RepVGG_B1g4
,
RepVGG_B2g4
,
RepVGG_B3g4
from
ppcls.arch.backbone.model_zoo.van
import
VAN_tiny
from
ppcls.arch.backbone.model_zoo.convnext
import
ConvNext_tiny
from
ppcls.arch.backbone.variant_models.resnet_variant
import
ResNet50_last_stage_stride1
from
ppcls.arch.backbone.variant_models.vgg_variant
import
VGG19Sigmoid
from
ppcls.arch.backbone.variant_models.pp_lcnet_variant
import
PPLCNet_x2_5_Tanh
...
...
ppcls/arch/backbone/model_zoo/convnext.py
0 → 100644
浏览文件 @
e1943f9a
# MIT License
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Code was heavily based on https://github.com/facebookresearch/ConvNeXt
import
paddle
import
paddle.nn
as
nn
from
paddle.nn.initializer
import
TruncatedNormal
,
Constant
from
ppcls.utils.save_load
import
load_dygraph_pretrain
,
load_dygraph_pretrain_from_url
MODEL_URLS
=
{
"ConvNext_tiny"
:
""
,
# TODO
}
__all__
=
list
(
MODEL_URLS
.
keys
())
trunc_normal_
=
TruncatedNormal
(
std
=
.
02
)
zeros_
=
Constant
(
value
=
0.
)
ones_
=
Constant
(
value
=
1.
)
def
drop_path
(
x
,
drop_prob
=
0.
,
training
=
False
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
"""
if
drop_prob
==
0.
or
not
training
:
return
x
keep_prob
=
paddle
.
to_tensor
(
1
-
drop_prob
)
shape
=
(
paddle
.
shape
(
x
)[
0
],
)
+
(
1
,
)
*
(
x
.
ndim
-
1
)
random_tensor
=
keep_prob
+
paddle
.
rand
(
shape
,
dtype
=
x
.
dtype
)
random_tensor
=
paddle
.
floor
(
random_tensor
)
# binarize
output
=
x
.
divide
(
keep_prob
)
*
random_tensor
return
output
class
DropPath
(
nn
.
Layer
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def
__init__
(
self
,
drop_prob
=
None
):
super
(
DropPath
,
self
).
__init__
()
self
.
drop_prob
=
drop_prob
def
forward
(
self
,
x
):
return
drop_path
(
x
,
self
.
drop_prob
,
self
.
training
)
class
ChannelsFirstLayerNorm
(
nn
.
Layer
):
r
""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
with shape (batch_size, channels, height, width).
"""
def
__init__
(
self
,
normalized_shape
,
epsilon
=
1e-5
):
super
().
__init__
()
self
.
weight
=
self
.
create_parameter
(
shape
=
[
normalized_shape
],
default_initializer
=
ones_
)
self
.
bias
=
self
.
create_parameter
(
shape
=
[
normalized_shape
],
default_initializer
=
zeros_
)
self
.
epsilon
=
epsilon
self
.
normalized_shape
=
[
normalized_shape
]
def
forward
(
self
,
x
):
u
=
x
.
mean
(
1
,
keepdim
=
True
)
s
=
(
x
-
u
).
pow
(
2
).
mean
(
1
,
keepdim
=
True
)
x
=
(
x
-
u
)
/
paddle
.
sqrt
(
s
+
self
.
epsilon
)
x
=
self
.
weight
[:,
None
,
None
]
*
x
+
self
.
bias
[:,
None
,
None
]
return
x
class
Block
(
nn
.
Layer
):
r
""" ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
We use (2) as we find it slightly faster in PyTorch
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""
def
__init__
(
self
,
dim
,
drop_path
=
0.
,
layer_scale_init_value
=
1e-6
):
super
().
__init__
()
self
.
dwconv
=
nn
.
Conv2D
(
dim
,
dim
,
7
,
padding
=
3
,
groups
=
dim
)
# depthwise conv
self
.
norm
=
nn
.
LayerNorm
(
dim
,
epsilon
=
1e-6
)
# pointwise/1x1 convs, implemented with linear layers
self
.
pwconv1
=
nn
.
Linear
(
dim
,
4
*
dim
)
self
.
act
=
nn
.
GELU
()
self
.
pwconv2
=
nn
.
Linear
(
4
*
dim
,
dim
)
if
layer_scale_init_value
>
0
:
self
.
gamma
=
self
.
create_parameter
(
shape
=
[
dim
],
default_initializer
=
Constant
(
value
=
layer_scale_init_value
))
else
:
self
.
gamma
=
None
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.
else
nn
.
Identity
()
def
forward
(
self
,
x
):
input
=
x
x
=
self
.
dwconv
(
x
)
x
=
x
.
transpose
([
0
,
2
,
3
,
1
])
# (N, C, H, W) -> (N, H, W, C)
x
=
self
.
norm
(
x
)
x
=
self
.
pwconv1
(
x
)
x
=
self
.
act
(
x
)
x
=
self
.
pwconv2
(
x
)
if
self
.
gamma
is
not
None
:
x
=
self
.
gamma
*
x
x
=
x
.
transpose
([
0
,
3
,
1
,
2
])
# (N, H, W, C) -> (N, C, H, W)
x
=
input
+
self
.
drop_path
(
x
)
return
x
class
ConvNeXt
(
nn
.
Layer
):
r
""" ConvNeXt
A PyTorch impl of : `A ConvNet for the 2020s` -
https://arxiv.org/pdf/2201.03545.pdf
Args:
in_chans (int): Number of input image channels. Default: 3
class_num (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_path_rate (float): Stochastic depth rate. Default: 0.
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""
def
__init__
(
self
,
in_chans
=
3
,
class_num
=
1000
,
depths
=
[
3
,
3
,
9
,
3
],
dims
=
[
96
,
192
,
384
,
768
],
drop_path_rate
=
0.
,
layer_scale_init_value
=
1e-6
,
head_init_scale
=
1.
):
super
().
__init__
()
# stem and 3 intermediate downsampling conv layers
self
.
downsample_layers
=
nn
.
LayerList
()
stem
=
nn
.
Sequential
(
nn
.
Conv2D
(
in_chans
,
dims
[
0
],
4
,
stride
=
4
),
ChannelsFirstLayerNorm
(
dims
[
0
],
epsilon
=
1e-6
))
self
.
downsample_layers
.
append
(
stem
)
for
i
in
range
(
3
):
downsample_layer
=
nn
.
Sequential
(
ChannelsFirstLayerNorm
(
dims
[
i
],
epsilon
=
1e-6
),
nn
.
Conv2D
(
dims
[
i
],
dims
[
i
+
1
],
2
,
stride
=
2
),
)
self
.
downsample_layers
.
append
(
downsample_layer
)
# 4 feature resolution stages, each consisting of multiple residual blocks
self
.
stages
=
nn
.
LayerList
()
dp_rates
=
[
x
.
item
()
for
x
in
paddle
.
linspace
(
0
,
drop_path_rate
,
sum
(
depths
))
]
cur
=
0
for
i
in
range
(
4
):
stage
=
nn
.
Sequential
(
*
[
Block
(
dim
=
dims
[
i
],
drop_path
=
dp_rates
[
cur
+
j
],
layer_scale_init_value
=
layer_scale_init_value
)
for
j
in
range
(
depths
[
i
])
])
self
.
stages
.
append
(
stage
)
cur
+=
depths
[
i
]
self
.
norm
=
nn
.
LayerNorm
(
dims
[
-
1
],
epsilon
=
1e-6
)
# final norm layer
self
.
head
=
nn
.
Linear
(
dims
[
-
1
],
class_num
)
self
.
apply
(
self
.
_init_weights
)
self
.
head
.
weight
.
set_value
(
self
.
head
.
weight
*
head_init_scale
)
self
.
head
.
bias
.
set_value
(
self
.
head
.
bias
*
head_init_scale
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
(
nn
.
Conv2D
,
nn
.
Linear
)):
trunc_normal_
(
m
.
weight
)
if
m
.
bias
is
not
None
:
zeros_
(
m
.
bias
)
def
forward_features
(
self
,
x
):
for
i
in
range
(
4
):
x
=
self
.
downsample_layers
[
i
](
x
)
x
=
self
.
stages
[
i
](
x
)
# global average pooling, (N, C, H, W) -> (N, C)
return
self
.
norm
(
x
.
mean
([
-
2
,
-
1
]))
def
forward
(
self
,
x
):
x
=
self
.
forward_features
(
x
)
x
=
self
.
head
(
x
)
return
x
def
_load_pretrained
(
pretrained
,
model
,
model_url
,
use_ssld
=
False
):
if
pretrained
is
False
:
pass
elif
pretrained
is
True
:
load_dygraph_pretrain_from_url
(
model
,
model_url
,
use_ssld
=
use_ssld
)
elif
isinstance
(
pretrained
,
str
):
load_dygraph_pretrain
(
model
,
pretrained
)
else
:
raise
RuntimeError
(
"pretrained type is not available. Please use `string` or `boolean` type."
)
def
ConvNext_tiny
(
pretrained
=
False
,
use_ssld
=
False
,
**
kwargs
):
model
=
ConvNeXt
(
depths
=
[
3
,
3
,
9
,
3
],
dims
=
[
96
,
192
,
384
,
768
],
**
kwargs
)
_load_pretrained
(
pretrained
,
model
,
MODEL_URLS
[
"ConvNext_tiny"
],
use_ssld
=
use_ssld
)
return
model
ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml
0 → 100644
浏览文件 @
e1943f9a
# global configs
Global
:
checkpoints
:
null
pretrained_model
:
null
output_dir
:
./output/
device
:
gpu
save_interval
:
1
eval_during_train
:
True
eval_interval
:
1
epochs
:
300
print_batch_step
:
10
use_visualdl
:
False
# used for static mode and model export
image_shape
:
[
3
,
224
,
224
]
save_inference_dir
:
./inference
# training model under @to_static
to_static
:
False
# model architecture
Arch
:
name
:
ConvNext_tiny
class_num
:
1000
drop_path_rate
:
0.1
layer_scale_init_value
:
1e-6
head_init_scale
:
1.0
# loss function config for traing/eval process
Loss
:
Train
:
-
CELoss
:
weight
:
1.0
epsilon
:
0.1
Eval
:
-
CELoss
:
weight
:
1.0
Optimizer
:
name
:
AdamW
beta1
:
0.9
beta2
:
0.999
epsilon
:
1e-8
weight_decay
:
0.05
one_dim_param_no_weight_decay
:
True
lr
:
name
:
Cosine
learning_rate
:
5e-4
eta_min
:
1e-6
warmup_epoch
:
20
warmup_start_lr
:
0
# data loader for train and eval
DataLoader
:
Train
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/train_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
RandCropImage
:
size
:
224
interpolation
:
bicubic
backend
:
pil
-
RandFlipImage
:
flip_code
:
1
-
TimmAutoAugment
:
config_str
:
rand-m9-mstd0.5-inc1
interpolation
:
bicubic
img_size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
RandomErasing
:
EPSILON
:
0.25
sl
:
0.02
sh
:
1.0/3.0
r1
:
0.3
attempt
:
10
use_log_aspect
:
True
mode
:
pixel
batch_transform_ops
:
-
OpSampler
:
MixupOperator
:
alpha
:
0.8
prob
:
0.5
CutmixOperator
:
alpha
:
1.0
prob
:
0.5
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
True
shuffle
:
True
loader
:
num_workers
:
4
use_shared_memory
:
True
Eval
:
dataset
:
name
:
ImageNetDataset
image_root
:
./dataset/ILSVRC2012/
cls_label_path
:
./dataset/ILSVRC2012/val_list.txt
transform_ops
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
interpolation
:
bicubic
backend
:
pil
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
sampler
:
name
:
DistributedBatchSampler
batch_size
:
128
drop_last
:
False
shuffle
:
False
loader
:
num_workers
:
4
use_shared_memory
:
True
Infer
:
infer_imgs
:
docs/images/inference_deployment/whl_demo.jpg
batch_size
:
10
transforms
:
-
DecodeImage
:
to_rgb
:
True
channel_first
:
False
-
ResizeImage
:
resize_short
:
256
interpolation
:
bicubic
backend
:
pil
-
CropImage
:
size
:
224
-
NormalizeImage
:
scale
:
1.0/255.0
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
'
-
ToCHWImage
:
PostProcess
:
name
:
Topk
topk
:
5
class_id_map_file
:
ppcls/utils/imagenet1k_label_list.txt
Metric
:
Eval
:
-
TopkAcc
:
topk
:
[
1
,
5
]
test_tipc/config/ConvNeXt/ConvNeXt_tiny_train_infer_python.txt
0 → 100644
浏览文件 @
e1943f9a
===========================train_params===========================
model_name:ConvNeXt_tiny
python:python3.7
gpu_list:0|0,1
-o Global.device:gpu
-o Global.auto_cast:null
-o Global.epochs:lite_train_lite_infer=2|whole_train_whole_infer=120
-o Global.output_dir:./output/
-o DataLoader.Train.sampler.batch_size:8
-o Global.pretrained_model:null
train_model_name:latest
train_infer_img_dir:./dataset/ILSVRC2012/val
null:null
##
trainer:norm_train
norm_train:tools/train.py -c ppcls/configs/ImageNet/ConvNeXt/ConvNeXt_tiny.yaml -o Global.seed=1234 -o DataLoader.Train.sampler.shuffle=False -o DataLoader.Train.loader.num_workers=0 -o DataLoader.Train.loader.use_shared_memory=False
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:tools/eval.py -c ppcls/configs/ImageNet/ConvNeXt/ConvNeXt_tiny.yaml
null:null
##
===========================infer_params==========================
-o Global.save_inference_dir:./inference
-o Global.pretrained_model:
norm_export:tools/export_model.py -c ppcls/configs/ImageNet/ConvNeXt/ConvNeXt_tiny.yaml
quant_export:null
fpgm_export:null
distill_export:null
kl_quant:null
export2:null
inference_dir:null
infer_model:../inference/
infer_export:True
infer_quant:Fasle
inference:python/predict_cls.py -c configs/inference_cls.yaml -o PreProcess.transform_ops.0.ResizeImage.resize_short=256 -o PreProcess.transform_ops.1.CropImage.size=224
-o Global.use_gpu:True|False
-o Global.enable_mkldnn:True|False
-o Global.cpu_num_threads:1|6
-o Global.batch_size:1|16
-o Global.use_tensorrt:True|False
-o Global.use_fp16:True|False
-o Global.inference_model_dir:../inference
-o Global.infer_imgs:../dataset/ILSVRC2012/val
-o Global.save_log_path:null
-o Global.benchmark:True
null:null
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,224,224]}]
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录