Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b079e34e
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b079e34e
编写于
8月 11, 2020
作者:
Q
qujianwei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add se block for resnet50
上级
0ae5eeb3
变更
9
隐藏空白更改
内联
并排
Showing
9 changed files
with
286 additions
and
72 deletions
+286
-72
model_zoo/official/cv/resnet/eval.py
model_zoo/official/cv/resnet/eval.py
+5
-2
model_zoo/official/cv/resnet/scripts/run_distribute_train.sh
model_zoo/official/cv/resnet/scripts/run_distribute_train.sh
+8
-3
model_zoo/official/cv/resnet/scripts/run_eval.sh
model_zoo/official/cv/resnet/scripts/run_eval.sh
+8
-3
model_zoo/official/cv/resnet/scripts/run_standalone_train.sh
model_zoo/official/cv/resnet/scripts/run_standalone_train.sh
+8
-3
model_zoo/official/cv/resnet/src/config.py
model_zoo/official/cv/resnet/src/config.py
+25
-3
model_zoo/official/cv/resnet/src/dataset.py
model_zoo/official/cv/resnet/src/dataset.py
+53
-1
model_zoo/official/cv/resnet/src/lr_generator.py
model_zoo/official/cv/resnet/src/lr_generator.py
+12
-0
model_zoo/official/cv/resnet/src/resnet.py
model_zoo/official/cv/resnet/src/resnet.py
+156
-46
model_zoo/official/cv/resnet/train.py
model_zoo/official/cv/resnet/train.py
+11
-11
未找到文件。
model_zoo/official/cv/resnet/eval.py
浏览文件 @
b079e34e
...
...
@@ -38,17 +38,20 @@ de.config.set_seed(1)
if
args_opt
.
net
==
"resnet50"
:
from
src.resnet
import
resnet50
as
resnet
if
args_opt
.
dataset
==
"cifar10"
:
from
src.config
import
config1
as
config
from
src.dataset
import
create_dataset1
as
create_dataset
else
:
from
src.config
import
config2
as
config
from
src.dataset
import
create_dataset2
as
create_dataset
el
se
:
el
if
args_opt
.
net
==
"resnet101"
:
from
src.resnet
import
resnet101
as
resnet
from
src.config
import
config3
as
config
from
src.dataset
import
create_dataset3
as
create_dataset
else
:
from
src.resnet
import
se_resnet50
as
resnet
from
src.config
import
config4
as
config
from
src.dataset
import
create_dataset4
as
create_dataset
if
__name__
==
'__main__'
:
target
=
args_opt
.
device_target
...
...
model_zoo/official/cv/resnet/scripts/run_distribute_train.sh
浏览文件 @
b079e34e
...
...
@@ -16,13 +16,13 @@
if
[
$#
!=
4
]
&&
[
$#
!=
5
]
then
echo
"Usage: sh run_distribute_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo
"Usage: sh run_distribute_train.sh [resnet50|resnet101
|se-resnet50
] [cifar10|imagenet2012] [RANK_TABLE_FILE] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit
1
fi
if
[
$1
!=
"resnet50"
]
&&
[
$1
!=
"resnet101"
]
if
[
$1
!=
"resnet50"
]
&&
[
$1
!=
"resnet101"
]
&&
[
$1
!=
"se-resnet50"
]
then
echo
"error: the selected net is neither resnet50 nor resnet101"
echo
"error: the selected net is neither resnet50 nor resnet101
and se-resnet50
"
exit
1
fi
...
...
@@ -38,6 +38,11 @@ then
exit
1
fi
if
[
$1
==
"se-resnet50"
]
&&
[
$2
==
"cifar10"
]
then
echo
"error: evaluating se-resnet50 with cifar10 dataset is unsupported now!"
exit
1
fi
get_real_path
(){
if
[
"
${
1
:0:1
}
"
==
"/"
]
;
then
...
...
model_zoo/official/cv/resnet/scripts/run_eval.sh
浏览文件 @
b079e34e
...
...
@@ -16,13 +16,13 @@
if
[
$#
!=
4
]
then
echo
"Usage: sh run_eval.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
echo
"Usage: sh run_eval.sh [resnet50|resnet101
|se-resnet50
] [cifar10|imagenet2012] [DATASET_PATH] [CHECKPOINT_PATH]"
exit
1
fi
if
[
$1
!=
"resnet50"
]
&&
[
$1
!=
"resnet101"
]
if
[
$1
!=
"resnet50"
]
&&
[
$1
!=
"resnet101"
]
&&
[
$1
!=
"se-resnet50"
]
then
echo
"error: the selected net is neither resnet50 nor resnet101"
echo
"error: the selected net is neither resnet50 nor resnet101
nor se-resnet50
"
exit
1
fi
...
...
@@ -38,6 +38,11 @@ then
exit
1
fi
if
[
$1
==
"se-resnet50"
]
&&
[
$2
==
"cifar10"
]
then
echo
"error: evaluating se-resnet50 with cifar10 dataset is unsupported now!"
exit
1
fi
get_real_path
(){
if
[
"
${
1
:0:1
}
"
==
"/"
]
;
then
...
...
model_zoo/official/cv/resnet/scripts/run_standalone_train.sh
浏览文件 @
b079e34e
...
...
@@ -16,13 +16,13 @@
if
[
$#
!=
3
]
&&
[
$#
!=
4
]
then
echo
"Usage: sh run_standalone_train.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
echo
"Usage: sh run_standalone_train.sh [resnet50|resnet101
|se-resnet50
] [cifar10|imagenet2012] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
exit
1
fi
if
[
$1
!=
"resnet50"
]
&&
[
$1
!=
"resnet101"
]
if
[
$1
!=
"resnet50"
]
&&
[
$1
!=
"resnet101"
]
&&
[
$1
!=
"se-resnet50"
]
then
echo
"error: the selected net is neither resnet50 nor resnet101"
echo
"error: the selected net is neither resnet50 nor resnet101
and se-resnet50
"
exit
1
fi
...
...
@@ -38,6 +38,11 @@ then
exit
1
fi
if
[
$1
==
"se-resnet50"
]
&&
[
$2
==
"cifar10"
]
then
echo
"error: evaluating se-resnet50 with cifar10 dataset is unsupported now!"
exit
1
fi
get_real_path
(){
if
[
"
${
1
:0:1
}
"
==
"/"
]
;
then
...
...
model_zoo/official/cv/resnet/src/config.py
浏览文件 @
b079e34e
...
...
@@ -50,12 +50,12 @@ config2 = ed({
"keep_checkpoint_max"
:
10
,
"save_checkpoint_path"
:
"./"
,
"warmup_epochs"
:
0
,
"lr_decay_mode"
:
"
cosine
"
,
"lr_decay_mode"
:
"
linear
"
,
"use_label_smooth"
:
True
,
"label_smooth_factor"
:
0.1
,
"lr_init"
:
0
,
"lr_max"
:
0.1
"lr_max"
:
0.1
,
"lr_end"
:
0.0
})
# config for resnet101, imagenet2012
...
...
@@ -77,3 +77,25 @@ config3 = ed({
"label_smooth_factor"
:
0.1
,
"lr"
:
0.1
})
# config for se-resnet50, imagenet2012
config4 = ed({
    # model / optimisation settings
    "class_num": 1001,
    "batch_size": 32,
    "loss_scale": 1024,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "epoch_size": 28,
    "pretrain_epoch_size": 1,

    # checkpoint settings
    "save_checkpoint": True,
    "save_checkpoint_epochs": 4,
    "keep_checkpoint_max": 10,
    "save_checkpoint_path": "./",

    # label smoothing
    "use_label_smooth": True,
    "label_smooth_factor": 0.1,

    # learning-rate schedule inputs (train.py forwards these to get_lr)
    "warmup_epochs": 3,
    "lr_decay_mode": "cosine",
    "lr_init": 0.0,
    "lr_max": 0.3,
    "lr_end": 0.0001,
})
model_zoo/official/cv/resnet/src/dataset.py
浏览文件 @
b079e34e
...
...
@@ -22,7 +22,6 @@ import mindspore.dataset.transforms.vision.c_transforms as C
import
mindspore.dataset.transforms.c_transforms
as
C2
from
mindspore.communication.management
import
init
,
get_rank
,
get_group_size
def
create_dataset1
(
dataset_path
,
do_train
,
repeat_num
=
1
,
batch_size
=
32
,
target
=
"Ascend"
):
"""
create a train or evaluate cifar10 dataset for resnet50
...
...
@@ -191,6 +190,59 @@ def create_dataset3(dataset_path, do_train, repeat_num=1, batch_size=32, target=
return
ds
def create_dataset4(dataset_path, do_train, repeat_num=1, batch_size=32, target="Ascend"):
    """
    create a train or eval imagenet2012 dataset for se-resnet50

    Args:
        dataset_path(string): the path of dataset.
        do_train(bool): whether dataset is used for train or eval.
        repeat_num(int): the repeat times of dataset. Default: 1
        batch_size(int): the batch size of dataset. Default: 32
        target(str): the device target. Only "Ascend" is handled here. Default: Ascend

    Returns:
        dataset

    Raises:
        ValueError: if `target` is not "Ascend".
    """
    if target != "Ascend":
        # Rank information is only resolved for Ascend. Without this guard any other
        # target reached `if device_num == 1` with `device_num` unbound and failed
        # with an UnboundLocalError.
        raise ValueError("create_dataset4 only supports target 'Ascend', but got {!r}".format(target))

    device_num, rank_id = _get_rank_info()

    if device_num == 1:
        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=12, shuffle=True)
    else:
        # more than one device: each rank reads its own shard
        ds = de.ImageFolderDatasetV2(dataset_path, num_parallel_workers=12, shuffle=True,
                                     num_shards=device_num, shard_id=rank_id)

    image_size = 224
    mean = [123.68, 116.78, 103.94]
    std = [1.0, 1.0, 1.0]

    # define map operations
    if do_train:
        trans = [
            C.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        # NOTE(review): eval resizes to 292 and center-crops to 256 instead of using
        # image_size (224); kept exactly as in the original -- confirm this is intended.
        trans = [
            C.Decode(),
            C.Resize(292),
            C.CenterCrop(256),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="image", num_parallel_workers=12, operations=trans)
    ds = ds.map(input_columns="label", num_parallel_workers=12, operations=type_cast_op)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds
def
_get_rank_info
():
"""
...
...
model_zoo/official/cv/resnet/src/lr_generator.py
浏览文件 @
b079e34e
...
...
@@ -62,6 +62,18 @@ def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch
if
lr
<
0.0
:
lr
=
0.0
lr_each_step
.
append
(
lr
)
elif
lr_decay_mode
==
'cosine'
:
decay_steps
=
total_steps
-
warmup_steps
for
i
in
range
(
total_steps
):
if
i
<
warmup_steps
:
lr_inc
=
(
float
(
lr_max
)
-
float
(
lr_init
))
/
float
(
warmup_steps
)
lr
=
float
(
lr_init
)
+
lr_inc
*
(
i
+
1
)
else
:
linear_decay
=
(
total_steps
-
i
)
/
decay_steps
cosine_decay
=
0.5
*
(
1
+
math
.
cos
(
math
.
pi
*
2
*
0.47
*
i
/
decay_steps
))
decayed
=
linear_decay
*
cosine_decay
+
0.00001
lr
=
lr_max
*
decayed
lr_each_step
.
append
(
lr
)
else
:
for
i
in
range
(
total_steps
):
if
i
<
warmup_steps
:
...
...
model_zoo/official/cv/resnet/src/resnet.py
浏览文件 @
b079e34e
...
...
@@ -15,32 +15,53 @@
"""ResNet."""
import
numpy
as
np
import
mindspore.nn
as
nn
import
mindspore.common.dtype
as
mstype
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
functional
as
F
from
mindspore.common.tensor
import
Tensor
from
scipy.stats
import
truncnorm
def _conv_variance_scaling_initializer(in_channel, out_channel, kernel_size):
    """
    Sample an initial convolution weight from a truncated normal scaled by fan-in.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        kernel_size (int): Side length of the square kernel.

    Returns:
        Tensor, float32 weight of shape (out_channel, in_channel, kernel_size, kernel_size).
    """
    fan_in = in_channel * kernel_size * kernel_size
    variance = 1.0 / max(1., fan_in)
    # NOTE(review): the divisor presumably compensates for the spread lost by
    # truncating the normal at +/-2 -- confirm against the reference initializer.
    stddev = (variance ** 0.5) / .87962566103423978

    # truncated normal on [-2, 2] in standardized units, centred at 0
    distribution = truncnorm(-2, 2, loc=0, scale=stddev)
    flat_weight = distribution.rvs(out_channel * in_channel * kernel_size * kernel_size)
    shaped_weight = np.reshape(flat_weight, (out_channel, in_channel, kernel_size, kernel_size))
    return Tensor(shaped_weight, dtype=mstype.float32)
def _weight_variable(shape, factor=0.01):
    """
    Build a weight Tensor of standard-normal samples scaled by `factor`.

    Args:
        shape (tuple): Shape of the weight to generate.
        factor (float): Multiplier applied to the samples. Default: 0.01.

    Returns:
        Tensor, the scaled random weight (sampled as float32).
    """
    samples = np.random.randn(*shape).astype(np.float32)
    init_value = samples * factor
    return Tensor(init_value)
def
_conv3x3
(
in_channel
,
out_channel
,
stride
=
1
):
weight_shape
=
(
out_channel
,
in_channel
,
3
,
3
)
weight
=
_weight_variable
(
weight_shape
)
def
_conv3x3
(
in_channel
,
out_channel
,
stride
=
1
,
use_se
=
False
):
if
use_se
:
weight
=
_conv_variance_scaling_initializer
(
in_channel
,
out_channel
,
kernel_size
=
3
)
else
:
weight_shape
=
(
out_channel
,
in_channel
,
3
,
3
)
weight
=
_weight_variable
(
weight_shape
)
return
nn
.
Conv2d
(
in_channel
,
out_channel
,
kernel_size
=
3
,
stride
=
stride
,
padding
=
0
,
pad_mode
=
'same'
,
weight_init
=
weight
)
def
_conv1x1
(
in_channel
,
out_channel
,
stride
=
1
):
weight_shape
=
(
out_channel
,
in_channel
,
1
,
1
)
weight
=
_weight_variable
(
weight_shape
)
def
_conv1x1
(
in_channel
,
out_channel
,
stride
=
1
,
use_se
=
False
):
if
use_se
:
weight
=
_conv_variance_scaling_initializer
(
in_channel
,
out_channel
,
kernel_size
=
1
)
else
:
weight_shape
=
(
out_channel
,
in_channel
,
1
,
1
)
weight
=
_weight_variable
(
weight_shape
)
return
nn
.
Conv2d
(
in_channel
,
out_channel
,
kernel_size
=
1
,
stride
=
stride
,
padding
=
0
,
pad_mode
=
'same'
,
weight_init
=
weight
)
def
_conv7x7
(
in_channel
,
out_channel
,
stride
=
1
):
weight_shape
=
(
out_channel
,
in_channel
,
7
,
7
)
weight
=
_weight_variable
(
weight_shape
)
def
_conv7x7
(
in_channel
,
out_channel
,
stride
=
1
,
use_se
=
False
):
if
use_se
:
weight
=
_conv_variance_scaling_initializer
(
in_channel
,
out_channel
,
kernel_size
=
7
)
else
:
weight_shape
=
(
out_channel
,
in_channel
,
7
,
7
)
weight
=
_weight_variable
(
weight_shape
)
return
nn
.
Conv2d
(
in_channel
,
out_channel
,
kernel_size
=
7
,
stride
=
stride
,
padding
=
0
,
pad_mode
=
'same'
,
weight_init
=
weight
)
...
...
@@ -55,9 +76,13 @@ def _bn_last(channel):
gamma_init
=
0
,
beta_init
=
0
,
moving_mean_init
=
0
,
moving_var_init
=
1
)
def
_fc
(
in_channel
,
out_channel
):
weight_shape
=
(
out_channel
,
in_channel
)
weight
=
_weight_variable
(
weight_shape
)
def
_fc
(
in_channel
,
out_channel
,
use_se
=
False
):
if
use_se
:
weight
=
np
.
random
.
normal
(
loc
=
0
,
scale
=
0.01
,
size
=
out_channel
*
in_channel
)
weight
=
Tensor
(
np
.
reshape
(
weight
,
(
out_channel
,
in_channel
)),
dtype
=
mstype
.
float32
)
else
:
weight_shape
=
(
out_channel
,
in_channel
)
weight
=
_weight_variable
(
weight_shape
)
return
nn
.
Dense
(
in_channel
,
out_channel
,
has_bias
=
True
,
weight_init
=
weight
,
bias_init
=
0
)
...
...
@@ -69,6 +94,8 @@ class ResidualBlock(nn.Cell):
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer. Default: 1.
use_se (bool): enable SE-ResNet50 net. Default: False.
se_block(bool): use se block in SE-ResNet50 net. Default: False.
Returns:
Tensor, output tensor.
...
...
@@ -81,19 +108,30 @@ class ResidualBlock(nn.Cell):
def
__init__
(
self
,
in_channel
,
out_channel
,
stride
=
1
):
stride
=
1
,
use_se
=
False
,
se_block
=
False
):
super
(
ResidualBlock
,
self
).
__init__
()
self
.
stride
=
stride
self
.
use_se
=
use_se
self
.
se_block
=
se_block
channel
=
out_channel
//
self
.
expansion
self
.
conv1
=
_conv1x1
(
in_channel
,
channel
,
stride
=
1
)
self
.
conv1
=
_conv1x1
(
in_channel
,
channel
,
stride
=
1
,
use_se
=
self
.
use_se
)
self
.
bn1
=
_bn
(
channel
)
self
.
conv2
=
_conv3x3
(
channel
,
channel
,
stride
=
stride
)
self
.
bn2
=
_bn
(
channel
)
self
.
conv3
=
_conv1x1
(
channel
,
out_channel
,
stride
=
1
)
if
self
.
use_se
and
self
.
stride
!=
1
:
self
.
e2
=
nn
.
SequentialCell
([
_conv3x3
(
channel
,
channel
,
stride
=
1
,
use_se
=
True
),
_bn
(
channel
),
nn
.
ReLU
(),
nn
.
MaxPool2d
(
kernel_size
=
2
,
stride
=
2
,
pad_mode
=
'same'
)])
else
:
self
.
conv2
=
_conv3x3
(
channel
,
channel
,
stride
=
stride
,
use_se
=
self
.
use_se
)
self
.
bn2
=
_bn
(
channel
)
self
.
conv3
=
_conv1x1
(
channel
,
out_channel
,
stride
=
1
,
use_se
=
self
.
use_se
)
self
.
bn3
=
_bn_last
(
out_channel
)
if
self
.
se_block
:
self
.
se_global_pool
=
P
.
ReduceMean
(
keep_dims
=
False
)
self
.
se_dense_0
=
_fc
(
out_channel
,
int
(
out_channel
/
4
),
use_se
=
self
.
use_se
)
self
.
se_dense_1
=
_fc
(
int
(
out_channel
/
4
),
out_channel
,
use_se
=
self
.
use_se
)
self
.
se_sigmoid
=
nn
.
Sigmoid
()
self
.
se_mul
=
P
.
Mul
()
self
.
relu
=
nn
.
ReLU
()
self
.
down_sample
=
False
...
...
@@ -103,8 +141,17 @@ class ResidualBlock(nn.Cell):
self
.
down_sample_layer
=
None
if
self
.
down_sample
:
self
.
down_sample_layer
=
nn
.
SequentialCell
([
_conv1x1
(
in_channel
,
out_channel
,
stride
),
_bn
(
out_channel
)])
if
self
.
use_se
:
if
stride
==
1
:
self
.
down_sample_layer
=
nn
.
SequentialCell
([
_conv1x1
(
in_channel
,
out_channel
,
stride
,
use_se
=
self
.
use_se
),
_bn
(
out_channel
)])
else
:
self
.
down_sample_layer
=
nn
.
SequentialCell
([
nn
.
MaxPool2d
(
kernel_size
=
2
,
stride
=
2
,
pad_mode
=
'same'
),
_conv1x1
(
in_channel
,
out_channel
,
1
,
use_se
=
self
.
use_se
),
_bn
(
out_channel
)])
else
:
self
.
down_sample_layer
=
nn
.
SequentialCell
([
_conv1x1
(
in_channel
,
out_channel
,
stride
,
use_se
=
self
.
use_se
),
_bn
(
out_channel
)])
self
.
add
=
P
.
TensorAdd
()
def
construct
(
self
,
x
):
...
...
@@ -113,13 +160,23 @@ class ResidualBlock(nn.Cell):
out
=
self
.
conv1
(
x
)
out
=
self
.
bn1
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
out
=
self
.
relu
(
out
)
if
self
.
use_se
and
self
.
stride
!=
1
:
out
=
self
.
e2
(
out
)
else
:
out
=
self
.
conv2
(
out
)
out
=
self
.
bn2
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
conv3
(
out
)
out
=
self
.
bn3
(
out
)
if
self
.
se_block
:
out_se
=
out
out
=
self
.
se_global_pool
(
out
,
(
2
,
3
))
out
=
self
.
se_dense_0
(
out
)
out
=
self
.
relu
(
out
)
out
=
self
.
se_dense_1
(
out
)
out
=
self
.
se_sigmoid
(
out
)
out
=
F
.
reshape
(
out
,
F
.
shape
(
out
)
+
(
1
,
1
))
out
=
self
.
se_mul
(
out
,
out_se
)
if
self
.
down_sample
:
identity
=
self
.
down_sample_layer
(
identity
)
...
...
@@ -141,6 +198,8 @@ class ResNet(nn.Cell):
out_channels (list): Output channel in each layer.
strides (list): Stride size in each layer.
num_classes (int): The number of classes that the training images are belonging to.
use_se (bool): enable SE-ResNet50 net. Default: False.
se_block(bool): use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
Returns:
Tensor, output tensor.
...
...
@@ -159,43 +218,60 @@ class ResNet(nn.Cell):
in_channels
,
out_channels
,
strides
,
num_classes
):
num_classes
,
use_se
=
False
):
super
(
ResNet
,
self
).
__init__
()
if
not
len
(
layer_nums
)
==
len
(
in_channels
)
==
len
(
out_channels
)
==
4
:
raise
ValueError
(
"the length of layer_num, in_channels, out_channels list must be 4!"
)
self
.
conv1
=
_conv7x7
(
3
,
64
,
stride
=
2
)
self
.
use_se
=
use_se
self
.
se_block
=
False
if
self
.
use_se
:
self
.
se_block
=
True
if
self
.
use_se
:
self
.
conv1_0
=
_conv3x3
(
3
,
32
,
stride
=
2
,
use_se
=
self
.
use_se
)
self
.
bn1_0
=
_bn
(
32
)
self
.
conv1_1
=
_conv3x3
(
32
,
32
,
stride
=
1
,
use_se
=
self
.
use_se
)
self
.
bn1_1
=
_bn
(
32
)
self
.
conv1_2
=
_conv3x3
(
32
,
64
,
stride
=
1
,
use_se
=
self
.
use_se
)
else
:
self
.
conv1
=
_conv7x7
(
3
,
64
,
stride
=
2
)
self
.
bn1
=
_bn
(
64
)
self
.
relu
=
P
.
ReLU
()
self
.
maxpool
=
nn
.
MaxPool2d
(
kernel_size
=
3
,
stride
=
2
,
pad_mode
=
"same"
)
self
.
layer1
=
self
.
_make_layer
(
block
,
layer_nums
[
0
],
in_channel
=
in_channels
[
0
],
out_channel
=
out_channels
[
0
],
stride
=
strides
[
0
])
stride
=
strides
[
0
],
use_se
=
self
.
use_se
)
self
.
layer2
=
self
.
_make_layer
(
block
,
layer_nums
[
1
],
in_channel
=
in_channels
[
1
],
out_channel
=
out_channels
[
1
],
stride
=
strides
[
1
])
stride
=
strides
[
1
],
use_se
=
self
.
use_se
)
self
.
layer3
=
self
.
_make_layer
(
block
,
layer_nums
[
2
],
in_channel
=
in_channels
[
2
],
out_channel
=
out_channels
[
2
],
stride
=
strides
[
2
])
stride
=
strides
[
2
],
use_se
=
self
.
use_se
,
se_block
=
self
.
se_block
)
self
.
layer4
=
self
.
_make_layer
(
block
,
layer_nums
[
3
],
in_channel
=
in_channels
[
3
],
out_channel
=
out_channels
[
3
],
stride
=
strides
[
3
])
stride
=
strides
[
3
],
use_se
=
self
.
use_se
,
se_block
=
self
.
se_block
)
self
.
mean
=
P
.
ReduceMean
(
keep_dims
=
True
)
self
.
flatten
=
nn
.
Flatten
()
self
.
end_point
=
_fc
(
out_channels
[
3
],
num_classes
)
self
.
end_point
=
_fc
(
out_channels
[
3
],
num_classes
,
use_se
=
self
.
use_se
)
def
_make_layer
(
self
,
block
,
layer_num
,
in_channel
,
out_channel
,
stride
):
def
_make_layer
(
self
,
block
,
layer_num
,
in_channel
,
out_channel
,
stride
,
use_se
=
False
,
se_block
=
False
):
"""
Make stage network of ResNet.
...
...
@@ -205,7 +281,7 @@ class ResNet(nn.Cell):
in_channel (int): Input channel.
out_channel (int): Output channel.
stride (int): Stride size for the first convolutional layer.
se_block(bool): use se block in SE-ResNet50 net. Default: False.
Returns:
SequentialCell, the output layer.
...
...
@@ -214,17 +290,31 @@ class ResNet(nn.Cell):
"""
layers
=
[]
resnet_block
=
block
(
in_channel
,
out_channel
,
stride
=
stride
)
resnet_block
=
block
(
in_channel
,
out_channel
,
stride
=
stride
,
use_se
=
use_se
)
layers
.
append
(
resnet_block
)
for
_
in
range
(
1
,
layer_num
):
resnet_block
=
block
(
out_channel
,
out_channel
,
stride
=
1
)
if
se_block
:
for
_
in
range
(
1
,
layer_num
-
1
):
resnet_block
=
block
(
out_channel
,
out_channel
,
stride
=
1
,
use_se
=
use_se
)
layers
.
append
(
resnet_block
)
resnet_block
=
block
(
out_channel
,
out_channel
,
stride
=
1
,
use_se
=
use_se
,
se_block
=
se_block
)
layers
.
append
(
resnet_block
)
else
:
for
_
in
range
(
1
,
layer_num
):
resnet_block
=
block
(
out_channel
,
out_channel
,
stride
=
1
,
use_se
=
use_se
)
layers
.
append
(
resnet_block
)
return
nn
.
SequentialCell
(
layers
)
def
construct
(
self
,
x
):
x
=
self
.
conv1
(
x
)
if
self
.
use_se
:
x
=
self
.
conv1_0
(
x
)
x
=
self
.
bn1_0
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
conv1_1
(
x
)
x
=
self
.
bn1_1
(
x
)
x
=
self
.
relu
(
x
)
x
=
self
.
conv1_2
(
x
)
else
:
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
self
.
relu
(
x
)
c1
=
self
.
maxpool
(
x
)
...
...
@@ -261,6 +351,26 @@ def resnet50(class_num=10):
[
1
,
2
,
2
,
2
],
class_num
)
def se_resnet50(class_num=1001):
    """
    Get SE-ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of SE-ResNet50 neural network.

    Examples:
        >>> net = se_resnet50(1001)
    """
    # Same four stage lists as passed by resnet50 in this file's diff context;
    # the only difference visible here is use_se=True.
    return ResNet(ResidualBlock,
                  [3, 4, 6, 3],
                  [64, 256, 512, 1024],
                  [256, 512, 1024, 2048],
                  [1, 2, 2, 2],
                  class_num,
                  use_se=True)
def
resnet101
(
class_num
=
1001
):
"""
...
...
model_zoo/official/cv/resnet/train.py
浏览文件 @
b079e34e
...
...
@@ -50,17 +50,21 @@ de.config.set_seed(1)
if
args_opt
.
net
==
"resnet50"
:
from
src.resnet
import
resnet50
as
resnet
if
args_opt
.
dataset
==
"cifar10"
:
from
src.config
import
config1
as
config
from
src.dataset
import
create_dataset1
as
create_dataset
else
:
from
src.config
import
config2
as
config
from
src.dataset
import
create_dataset2
as
create_dataset
el
se
:
el
if
args_opt
.
net
==
"resnet101"
:
from
src.resnet
import
resnet101
as
resnet
from
src.config
import
config3
as
config
from
src.dataset
import
create_dataset3
as
create_dataset
else
:
from
src.resnet
import
se_resnet50
as
resnet
from
src.config
import
config4
as
config
from
src.dataset
import
create_dataset4
as
create_dataset
if
__name__
==
'__main__'
:
target
=
args_opt
.
device_target
...
...
@@ -74,7 +78,7 @@ if __name__ == '__main__':
context
.
set_context
(
device_id
=
device_id
,
enable_auto_mixed_precision
=
True
)
context
.
set_auto_parallel_context
(
device_num
=
args_opt
.
device_num
,
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
if
args_opt
.
net
==
"resnet50"
:
if
args_opt
.
net
==
"resnet50"
or
args_opt
.
net
==
"se-resnet50"
:
auto_parallel_context
().
set_all_reduce_fusion_split_indices
([
85
,
160
])
else
:
auto_parallel_context
().
set_all_reduce_fusion_split_indices
([
180
,
313
])
...
...
@@ -112,14 +116,10 @@ if __name__ == '__main__':
cell
.
weight
.
dtype
)
# init lr
if
args_opt
.
net
==
"resnet50"
:
if
args_opt
.
dataset
==
"cifar10"
:
lr
=
get_lr
(
lr_init
=
config
.
lr_init
,
lr_end
=
config
.
lr_end
,
lr_max
=
config
.
lr_max
,
warmup_epochs
=
config
.
warmup_epochs
,
total_epochs
=
config
.
epoch_size
,
steps_per_epoch
=
step_size
,
lr_decay_mode
=
'poly'
)
else
:
lr
=
get_lr
(
lr_init
=
config
.
lr_init
,
lr_end
=
0.0
,
lr_max
=
config
.
lr_max
,
warmup_epochs
=
config
.
warmup_epochs
,
total_epochs
=
config
.
epoch_size
,
steps_per_epoch
=
step_size
,
lr_decay_mode
=
'cosine'
)
if
args_opt
.
net
==
"resnet50"
or
args_opt
.
net
==
"se-resnet50"
:
lr
=
get_lr
(
lr_init
=
config
.
lr_init
,
lr_end
=
config
.
lr_end
,
lr_max
=
config
.
lr_max
,
warmup_epochs
=
config
.
warmup_epochs
,
total_epochs
=
config
.
epoch_size
,
steps_per_epoch
=
step_size
,
lr_decay_mode
=
config
.
lr_decay_mode
)
else
:
lr
=
warmup_cosine_annealing_lr
(
config
.
lr
,
step_size
,
config
.
warmup_epochs
,
config
.
epoch_size
,
config
.
pretrain_epoch_size
*
step_size
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录