magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 20e5f719
Authored Jul 18, 2020 by zhaoting

add gpu resnext50

Parent: ca6da675
13 changed files with 217 additions and 69 deletions (+217 -69)
model_zoo/official/cv/resnext50/README.md (+10 -4)
model_zoo/official/cv/resnext50/eval.py (+21 -10)
model_zoo/official/cv/resnext50/scripts/run_distribute_train_for_gpu.sh (+30 -0)
model_zoo/official/cv/resnext50/scripts/run_eval.sh (+7 -2)
model_zoo/official/cv/resnext50/scripts/run_standalone_train.sh (+1 -1)
model_zoo/official/cv/resnext50/scripts/run_standalone_train_for_gpu.sh (+30 -0)
model_zoo/official/cv/resnext50/src/backbone/resnet.py (+20 -14)
model_zoo/official/cv/resnext50/src/config.py (+2 -1)
model_zoo/official/cv/resnext50/src/dataset.py (+4 -2)
model_zoo/official/cv/resnext50/src/image_classification.py (+4 -4)
model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py (+56 -0)
model_zoo/official/cv/resnext50/src/utils/cunstom_op.py (+1 -5)
model_zoo/official/cv/resnext50/train.py (+31 -26)
model_zoo/official/cv/resnext50/README.md

@@ -90,10 +90,15 @@ sh run_standalone_train.sh DEVICE_ID DATA_PATH
 #### Launch

 ```bash
-# distributed training example(8p)
+# distributed training example(8p) for Ascend
 sh scripts/run_distribute_train.sh MINDSPORE_HCCL_CONFIG_PATH /dataset/train
-# standalone training example
+# standalone training example for Ascend
 sh scripts/run_standalone_train.sh 0 /dataset/train
+# distributed training example(8p) for GPU
+sh scripts/run_distribute_train_for_gpu.sh /dataset/train
+# standalone training example for GPU
+sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train
 ```

 #### Result

@@ -106,14 +111,15 @@ You can find checkpoint file together with result in log.
 ```
 # Evaluation
-sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH
+sh run_eval.sh DEVICE_ID DATA_PATH PRETRAINED_CKPT_PATH PLATFORM
 ```
+PLATFORM is Ascend or GPU, default is Ascend.

 #### Launch

 ```bash
 # Evaluation with checkpoint
-sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt
+sh scripts/run_eval.sh 0 /opt/npu/datasets/classification/val /resnext50_100.ckpt Ascend
 ```

 > checkpoint can be produced in training process.
model_zoo/official/cv/resnext50/eval.py

@@ -29,15 +29,11 @@
 from mindspore.common import dtype as mstype
 from src.utils.logging import get_logger
+from src.utils.auto_mixed_precision import auto_mixed_precision
 from src.image_classification import get_network
 from src.dataset import classification_dataset
 from src.config import config

-devid = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
-                    device_target="Ascend", save_graphs=False, device_id=devid)

 class ParameterReduce(nn.Cell):
     """ParameterReduce"""

@@ -56,6 +52,7 @@
 def parse_args(cloud_args=None):
     """parse_args"""
     parser = argparse.ArgumentParser('mindspore classification test')
+    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')

     # dataset related
     parser.add_argument('--data_dir', type=str, default='/opt/npu/datasets/classification/val', help='eval data dir')

@@ -108,12 +105,25 @@
 def test(cloud_args=None):
     """test"""
     args = parse_args(cloud_args)
+    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+                        device_target=args.platform, save_graphs=False)
+    if os.getenv('DEVICE_ID', "not_set").isdigit():
+        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

     # init distributed
     if args.is_distributed:
-        init()
+        if args.platform == "Ascend":
+            init()
+        elif args.platform == "GPU":
+            init("nccl")
         args.rank = get_rank()
         args.group_size = get_group_size()
+        parallel_mode = ParallelMode.DATA_PARALLEL
+        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
+                                          parameter_broadcast=True, mirror_mean=True)
     else:
         args.rank = 0
         args.group_size = 1

     args.outputs_dir = os.path.join(args.log_path,
                                     datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))

@@ -140,7 +150,7 @@
                                          max_epoch=1, rank=args.rank, group_size=args.group_size,
                                          mode='eval')
     eval_dataloader = de_dataset.create_tuple_iterator()
-    network = get_network(args.backbone, args.num_classes)
+    network = get_network(args.backbone, args.num_classes, platform=args.platform)
     if network is None:
         raise NotImplementedError('not implement {}'.format(args.backbone))

@@ -157,12 +167,13 @@
         load_param_into_net(network, param_dict_new)
         args.logger.info('load model {} success'.format(model))

-    # must add
-    network.add_flags_recursive(fp16=True)
     img_tot = 0
     top1_correct = 0
     top5_correct = 0
+    if args.platform == "Ascend":
+        network.to_float(mstype.float16)
+    else:
+        auto_mixed_precision(network)
     network.set_train(False)
     t_end = time.time()
     it = 0
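Note: the pattern recurring throughout this commit is visible above. Device setup moves from import time into test(), driven by --platform: DEVICE_ID is applied only when the launch script actually exported it, and the collective backend follows the platform. A minimal sketch of that dispatch (`setup_device` is a hypothetical name, not part of the commit):

```python
# Minimal sketch of the platform dispatch introduced above; `setup_device`
# is a hypothetical helper name, not part of the commit.
import os
from mindspore import context
from mindspore.communication.management import init

def setup_device(platform, is_distributed):
    """Configure the execution context for Ascend or GPU."""
    context.set_context(mode=context.GRAPH_MODE, device_target=platform, save_graphs=False)
    # Pin a device only when the launcher exported DEVICE_ID.
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))
    if is_distributed:
        if platform == "Ascend":
            init()        # HCCL backend on Ascend
        else:
            init("nccl")  # NCCL backend on GPU
```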
model_zoo/official/cv/resnext50/scripts/run_distribute_train_for_gpu.sh (new file, mode 100644)

#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR=$1
export RANK_SIZE=8
PATH_CHECKPOINT=""
if [ $# == 2 ]
then
    PATH_CHECKPOINT=$2
fi

mpirun --allow-run-as-root -n $RANK_SIZE \
    python train.py \
    --is_distribute=1 \
    --platform="GPU" \
    --pretrained=$PATH_CHECKPOINT \
    --data_dir=$DATA_DIR > log.txt 2>&1 &
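Unlike the Ascend launcher, which consumes an HCCL rank-table file (MINDSPORE_HCCL_CONFIG_PATH), this script leaves process spawning to mpirun; each worker then recovers its own rank inside Python. A rough sketch of what each of the eight workers does on startup, assuming the NCCL init path added to train.py:

```python
# Rough per-worker sketch under `mpirun -n 8` (assumes the NCCL path in train.py).
from mindspore.communication.management import init, get_rank, get_group_size

init("nccl")                   # one call in each mpirun-spawned process
rank = get_rank()              # 0..7 for this 8-process job
group_size = get_group_size()  # 8, matching RANK_SIZE
```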
model_zoo/official/cv/resnext50/scripts/run_eval.sh

@@ -14,11 +14,16 @@
 # limitations under the License.
 # ============================================================================
-DEVICE_ID=$1
+export DEVICE_ID=$1
 DATA_DIR=$2
 PATH_CHECKPOINT=$3
+PLATFORM=Ascend
+if [ $# == 4 ]
+then
+    PLATFORM=$4
+fi

 python eval.py \
     --device_id=$DEVICE_ID \
     --pretrained=$PATH_CHECKPOINT \
+    --platform=$PLATFORM \
     --data_dir=$DATA_DIR > log.txt 2>&1 &
model_zoo/official/cv/resnext50/scripts/run_standalone_train.sh

@@ -14,7 +14,7 @@
 # limitations under the License.
 # ============================================================================
-DEVICE_ID=$1
+export DEVICE_ID=$1
 DATA_DIR=$2
 PATH_CHECKPOINT=""
 if [ $# == 3 ]
model_zoo/official/cv/resnext50/scripts/run_standalone_train_for_gpu.sh (new file, mode 100644)

#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=""
if [ $# == 3 ]
then
    PATH_CHECKPOINT=$3
fi

python train.py \
    --is_distribute=0 \
    --pretrained=$PATH_CHECKPOINT \
    --platform="GPU" \
    --data_dir=$DATA_DIR > log.txt 2>&1 &
model_zoo/official/cv/resnext50/src/backbone/resnet.py

@@ -87,7 +87,8 @@
     """
     expansion = 1

-    def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False, **kwargs):
+    def __init__(self, in_channels, out_channels, stride=1, down_sample=None, use_se=False,
+                 platform="Ascend", **kwargs):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(in_channels, out_channels, stride=stride)
         self.bn1 = nn.BatchNorm2d(out_channels)

@@ -142,7 +143,7 @@
     expansion = 4

     def __init__(self, in_channels, out_channels, stride=1, down_sample=None,
-                 base_width=64, groups=1, use_se=False, **kwargs):
+                 base_width=64, groups=1, use_se=False, platform="Ascend", **kwargs):
         super(Bottleneck, self).__init__()
         width = int(out_channels * (base_width / 64.0)) * groups

@@ -153,7 +154,11 @@
         self.conv3x3s = nn.CellList()

-        self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups)
+        if platform == "GPU":
+            self.conv2 = nn.Conv2d(width, width, 3, stride, pad_mode='pad', padding=1,
+                                   group=groups)
+        else:
+            self.conv2 = GroupConv(width, width, 3, stride, pad=1, groups=groups)

         self.op_split = Split(axis=1, output_num=self.groups)
         self.op_concat = Concat(axis=1)
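Note: on GPU, nn.Conv2d supports grouped convolution natively through its group argument, so the custom GroupConv is kept only for Ascend. For readers unfamiliar with GroupConv (defined in src/utils/cunstom_op.py), here is a simplified, purely illustrative sketch of the split/convolve/concat emulation it performs; the class name NaiveGroupConv and its exact argument handling are assumptions, not the repo's code:

```python
# Illustrative only: a simplified emulation of grouped convolution via
# split/conv/concat, the strategy GroupConv uses on the Ascend path.
import mindspore.nn as nn
from mindspore.ops.operations import Split, Concat

class NaiveGroupConv(nn.Cell):
    def __init__(self, in_channels, out_channels, kernel_size, stride, pad, groups):
        super(NaiveGroupConv, self).__init__()
        self.groups = groups
        self.split = Split(axis=1, output_num=groups)   # slice the channel axis
        self.concat = Concat(axis=1)                    # stitch group outputs back
        self.convs = nn.CellList()
        for _ in range(groups):
            self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,
                                        kernel_size, stride, pad_mode='pad', padding=pad))

    def construct(self, x):
        features = self.split(x)
        outputs = ()
        for i in range(self.groups):
            outputs = outputs + (self.convs[i](features[i]),)
        return self.concat(outputs)
```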
@@ -211,7 +216,7 @@
     Examples:
         >>>ResNet()
     """

-    def __init__(self, block, layers, width_per_group=64, groups=1, use_se=False):
+    def __init__(self, block, layers, width_per_group=64, groups=1, use_se=False, platform="Ascend"):
         super(ResNet, self).__init__()
         self.in_channels = 64
         self.groups = groups

@@ -222,10 +227,10 @@
         self.relu = P.ReLU()
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')

-        self.layer1 = self._make_layer(block, 64, layers[0], use_se=use_se)
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_se=use_se)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_se=use_se)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_se=use_se)
+        self.layer1 = self._make_layer(block, 64, layers[0], use_se=use_se, platform=platform)
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, use_se=use_se, platform=platform)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, use_se=use_se, platform=platform)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, use_se=use_se, platform=platform)

         self.out_channels = 512 * block.expansion
         self.cast = P.Cast()

@@ -242,7 +247,7 @@
         return x

-    def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False):
+    def _make_layer(self, block, out_channels, blocks_num, stride=1, use_se=False, platform="Ascend"):
         """_make_layer"""
         down_sample = None
         if stride != 1 or self.in_channels != out_channels * block.expansion:

@@ -257,11 +262,12 @@
                             down_sample=down_sample,
                             base_width=self.base_width,
                             groups=self.groups,
-                            use_se=use_se))
+                            use_se=use_se,
+                            platform=platform))
         self.in_channels = out_channels * block.expansion
         for _ in range(1, blocks_num):
-            layers.append(block(self.in_channels, out_channels, base_width=self.base_width,
-                                groups=self.groups, use_se=use_se))
+            layers.append(block(self.in_channels, out_channels, base_width=self.base_width,
+                                groups=self.groups, use_se=use_se, platform=platform))
         return nn.SequentialCell(layers)

@@ -269,5 +275,5 @@
         return self.out_channels

-def resnext50():
-    return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32)
+def resnext50(platform="Ascend"):
+    return ResNet(Bottleneck, [3, 4, 6, 3], width_per_group=4, groups=32, platform=platform)
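The platform flag is threaded from the factory down through ResNet._make_layer to every block, so callers pick the backend once. A hedged usage sketch:

```python
# Hedged usage sketch of the threaded-through platform flag.
from src.backbone.resnet import resnext50

gpu_net = resnext50(platform="GPU")   # Bottleneck.conv2 becomes nn.Conv2d(..., group=32)
ascend_net = resnext50()              # default "Ascend" keeps the GroupConv path
```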
model_zoo/official/cv/resnext50/src/config.py

@@ -36,7 +36,8 @@
     "label_smooth": 1,
     "label_smooth_factor": 0.1,
-    "ckpt_interval": 1250,
+    "ckpt_interval": 5,
+    "ckpt_save_max": 5,
     "ckpt_path": 'outputs/',
     "is_save_on_master": 1,
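Read together with the CheckpointConfig change in train.py below, ckpt_interval switches from a step count (1250) to an epoch count (5), and the new ckpt_save_max caps how many checkpoints are retained. A quick worked example, with the dataset size purely illustrative:

```python
# Illustrative numbers only: the checkpoint cadence implied by the new config.
steps_per_epoch = 1250                                   # hypothetical, comes from the dataloader
ckpt_interval = 5                                        # now "every 5 epochs"
save_checkpoint_steps = ckpt_interval * steps_per_epoch  # 6250 steps between checkpoints
keep_checkpoint_max = 5                                  # ckpt_save_max: keep only the last 5
```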
model_zoo/official/cv/resnext50/src/dataset.py

@@ -143,8 +143,10 @@ def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank
         de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
         de_dataset.set_dataset_size(len(sampler))

-    de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
-    de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label)
+    de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=num_parallel_workers,
+                                operations=transform_img)
+    de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=num_parallel_workers,
+                                operations=transform_label)

     columns_to_project = ["image", "label"]
     de_dataset = de_dataset.project(columns=columns_to_project)
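This hunk implies classification_dataset gained a num_parallel_workers parameter; its declaration and default sit outside the visible hunk. A hedged sketch of the call site as train.py now uses it:

```python
# Hedged sketch of the new call site (the argument list mirrors the train.py hunk below;
# the parameter's default value is not visible in this diff).
de_dataset = classification_dataset(args.data_dir, args.image_size, args.per_batch_size,
                                    1, args.rank, args.group_size,
                                    num_parallel_workers=8)
```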
model_zoo/official/cv/resnext50/src/image_classification.py

@@ -50,9 +50,9 @@
     Returns:
         Resnet.
     """
-    def __init__(self, backbone_name, num_classes):
+    def __init__(self, backbone_name, num_classes, platform="Ascend"):
         self.backbone_name = backbone_name
-        backbone = backbones.__dict__[self.backbone_name]()
+        backbone = backbones.__dict__[self.backbone_name](platform=platform)
         out_channels = backbone.get_out_channels()
         head = heads.CommonHead(num_classes=num_classes, out_channels=out_channels)
         super(Resnet, self).__init__(backbone, head)

@@ -79,7 +79,7 @@
-def get_network(backbone_name, num_classes):
+def get_network(backbone_name, num_classes, platform="Ascend"):
     if backbone_name in ['resnext50']:
-        return Resnet(backbone_name, num_classes)
+        return Resnet(backbone_name, num_classes, platform)
     return None
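With the signature change, platform flows get_network → Resnet → the backbone factory. A hedged construction sketch (the num_classes value is illustrative):

```python
# Hedged sketch: platform now flows get_network -> Resnet -> backbones.__dict__[...](platform=...).
from src.image_classification import get_network

network = get_network('resnext50', num_classes=1000, platform='GPU')  # 1000 is illustrative
```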
model_zoo/official/cv/resnext50/src/utils/auto_mixed_precision.py (new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Auto mixed precision."""
import mindspore.nn as nn
from mindspore.ops import functional as F
from mindspore._checkparam import Validator as validator
from mindspore.common import dtype as mstype


class OutputTo(nn.Cell):
    "Cast cell output back to float16 or float32"

    def __init__(self, op, to_type=mstype.float16):
        super(OutputTo, self).__init__(auto_prefix=False)
        self._op = op
        validator.check_type_name('to_type', to_type, [mstype.float16, mstype.float32], None)
        self.to_type = to_type

    def construct(self, x):
        return F.cast(self._op(x), self.to_type)


def auto_mixed_precision(network):
    """Do keep batchnorm fp32."""
    cells = network.name_cells()
    change = False
    network.to_float(mstype.float16)
    for name in cells:
        subcell = cells[name]
        if subcell == network:
            continue
        elif name == 'fc':
            network.insert_child_to_cell(name, OutputTo(subcell, mstype.float32))
            change = True
        elif name == 'conv2':
            subcell.to_float(mstype.float32)
            change = True
        elif isinstance(subcell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            network.insert_child_to_cell(name, OutputTo(subcell.to_float(mstype.float32), mstype.float16))
            change = True
        else:
            auto_mixed_precision(subcell)
    if isinstance(network, nn.SequentialCell) and change:
        network.cell_list = list(network.cells())
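The helper mutates the network in place: everything is first cast to float16, then the fc head, conv2 (the grouped convolution), and BatchNorm cells are pinned back to float32, with OutputTo re-casting their outputs so neighboring float16 cells still compose. A hedged usage sketch mirroring the GPU branches of train.py and eval.py:

```python
# Hedged usage sketch, mirroring the GPU branches in train.py / eval.py.
from src.backbone.resnet import resnext50
from src.utils.auto_mixed_precision import auto_mixed_precision

network = resnext50(platform="GPU")
auto_mixed_precision(network)  # in place: float16 overall, BatchNorm/conv2/fc kept in float32
network.set_train(False)       # e.g. before evaluation
```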
model_zoo/official/cv/resnext50/src/utils/cunstom_op.py

@@ -29,14 +29,10 @@
     """

     def __init__(self):
         super(GlobalAvgPooling, self).__init__()
-        self.mean = P.ReduceMean(True)
-        self.shape = P.Shape()
-        self.reshape = P.Reshape()
+        self.mean = P.ReduceMean(False)

     def construct(self, x):
         x = self.mean(x, (2, 3))
-        b, c, _, _ = self.shape(x)
-        x = self.reshape(x, (b, c))
         return x
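The simplification works because ReduceMean with keep_dims=False already collapses the spatial axes, so the Shape/Reshape pair became dead weight. A small shape check (shapes are illustrative; should run eagerly in PyNative mode):

```python
# Small shape check for the simplification above (illustrative shapes).
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

x = Tensor(np.ones((2, 4, 7, 7), np.float32))
print(P.ReduceMean(keep_dims=True)(x, (2, 3)).shape)   # (2, 4, 1, 1): needed the old Reshape
print(P.ReduceMean(keep_dims=False)(x, (2, 3)).shape)  # (2, 4): already the pooled output
```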
model_zoo/official/cv/resnext50/train.py

@@ -36,11 +36,9 @@
 from src.utils.logging import get_logger
 from src.utils.optimizers__init__ import get_param_groups
 from src.image_classification import get_network
+from src.utils.auto_mixed_precision import auto_mixed_precision
 from src.config import config

-devid = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
-                    device_target="Ascend", save_graphs=False, device_id=devid)

 class BuildTrainNetwork(nn.Cell):
     """build training network"""

@@ -109,6 +107,7 @@
 def parse_args(cloud_args=None):
     """parameters"""
     parser = argparse.ArgumentParser('mindspore classification training')
+    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')

     # dataset related
     parser.add_argument('--data_dir', type=str, default='', help='train data dir')

@@ -141,6 +140,7 @@
     args.label_smooth = config.label_smooth
     args.label_smooth_factor = config.label_smooth_factor
     args.ckpt_interval = config.ckpt_interval
+    args.ckpt_save_max = config.ckpt_save_max
     args.ckpt_path = config.ckpt_path
     args.is_save_on_master = config.is_save_on_master
     args.rank = config.rank

@@ -166,12 +166,25 @@
 def train(cloud_args=None):
     """training process"""
     args = parse_args(cloud_args)
+    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+                        device_target=args.platform, save_graphs=False)
+    if os.getenv('DEVICE_ID', "not_set").isdigit():
+        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

     # init distributed
     if args.is_distributed:
-        init()
+        if args.platform == "Ascend":
+            init()
+        else:
+            init("nccl")
         args.rank = get_rank()
         args.group_size = get_group_size()
+        parallel_mode = ParallelMode.DATA_PARALLEL
+        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
+                                          parameter_broadcast=True, mirror_mean=True)
     else:
         args.rank = 0
         args.group_size = 1

     if args.is_dynamic_loss_scale == 1:
         args.loss_scale = 1  # for dynamic loss scale can not set loss scale in momentum opt

@@ -192,7 +205,7 @@
     # dataloader
     de_dataset = classification_dataset(args.data_dir, args.image_size, args.per_batch_size,
-                                        1, args.rank, args.group_size)
+                                        1, args.rank, args.group_size, num_parallel_workers=8)
     de_dataset.map_model = 4  # !!!important
     args.steps_per_epoch = de_dataset.get_dataset_size()

@@ -201,15 +214,9 @@
     # network
     args.logger.important_info('start create network')
     # get network and init
-    network = get_network(args.backbone, args.num_classes)
+    network = get_network(args.backbone, args.num_classes, platform=args.platform)
     if network is None:
         raise NotImplementedError('not implement {}'.format(args.backbone))
-    network.add_flags_recursive(fp16=True)

-    # loss
-    if not args.label_smooth:
-        args.label_smooth_factor = 0.0
-    criterion = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)

     # load pretrain model
     if os.path.isfile(args.pretrained):

@@ -252,31 +259,29 @@
                            loss_scale=args.loss_scale)
-    criterion.add_flags_recursive(fp32=True)

+    # loss
+    if not args.label_smooth:
+        args.label_smooth_factor = 0.0
+    loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)

-    # package training process, adjust lr + forward + backward + optimizer
-    train_net = BuildTrainNetwork(network, criterion)
-    if args.is_distributed:
-        parallel_mode = ParallelMode.DATA_PARALLEL
-    else:
-        parallel_mode = ParallelMode.STAND_ALONE
     if args.is_dynamic_loss_scale == 1:
         loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
     else:
         loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)

-    # Model api changed since TR5_branch 2020/03/09
-    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=args.group_size,
-                                      parameter_broadcast=True, mirror_mean=True)
-    model = Model(train_net, optimizer=opt, metrics=None, loss_scale_manager=loss_scale_manager)
+    if args.platform == "Ascend":
+        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
+                      metrics={'acc'}, amp_level="O3")
+    else:
+        auto_mixed_precision(network)
+        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager,
+                      metrics={'acc'})

     # checkpoint save
     progress_cb = ProgressMonitor(args)
     callbacks = [progress_cb,]
     if args.rank_save_ckpt_flag:
-        ckpt_max_num = args.max_epoch * args.steps_per_epoch // args.ckpt_interval
-        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval,
-                                       keep_checkpoint_max=ckpt_max_num)
+        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
+                                       keep_checkpoint_max=args.ckpt_save_max)
         ckpt_cb = ModelCheckpoint(config=ckpt_config,
                                   directory=args.outputs_dir,
                                   prefix='{}'.format(args.rank))
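The last hunk drops the hand-rolled BuildTrainNetwork wrapper from the training path in favor of Model's built-in loss_fn wiring; on Ascend, amp_level="O3" also takes over the blanket float16 cast that add_flags_recursive(fp16=True) used to perform, while the GPU path applies auto_mixed_precision instead. For context, a sketch of the equivalence (the class body below is paraphrased from the standard model_zoo pattern, not copied from this diff):

```python
# Sketch: BuildTrainNetwork chained backbone and criterion by hand...
import mindspore.nn as nn

class BuildTrainNetwork(nn.Cell):
    """Forward the backbone, then apply the criterion."""
    def __init__(self, network, criterion):
        super(BuildTrainNetwork, self).__init__()
        self.network = network
        self.criterion = criterion

    def construct(self, input_data, label):
        output = self.network(input_data)
        return self.criterion(output, label)

# ...which Model now does internally when given a loss function:
# model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=..., metrics={'acc'})
```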