Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
eb9e2ed7
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
eb9e2ed7
编写于
8月 14, 2020
作者:
D
dessyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add an example of training NASNet in MindSpore
fix pylint
上级
6ce06b77
变更
12
展开全部
显示空白变更内容
内联
并排
Showing
12 changed file
with
1519 addition
and
0 deletion
+1519
-0
model_zoo/official/cv/nasnet/README.md
model_zoo/official/cv/nasnet/README.md
+111
-0
model_zoo/official/cv/nasnet/eval.py
model_zoo/official/cv/nasnet/eval.py
+53
-0
model_zoo/official/cv/nasnet/export.py
model_zoo/official/cv/nasnet/export.py
+39
-0
model_zoo/official/cv/nasnet/scripts/run_distribute_train_for_gpu.sh
...fficial/cv/nasnet/scripts/run_distribute_train_for_gpu.sh
+17
-0
model_zoo/official/cv/nasnet/scripts/run_eval_for_gpu.sh
model_zoo/official/cv/nasnet/scripts/run_eval_for_gpu.sh
+19
-0
model_zoo/official/cv/nasnet/scripts/run_standalone_train_for_gpu.sh
...fficial/cv/nasnet/scripts/run_standalone_train_for_gpu.sh
+19
-0
model_zoo/official/cv/nasnet/src/config.py
model_zoo/official/cv/nasnet/src/config.py
+56
-0
model_zoo/official/cv/nasnet/src/dataset.py
model_zoo/official/cv/nasnet/src/dataset.py
+70
-0
model_zoo/official/cv/nasnet/src/loss.py
model_zoo/official/cv/nasnet/src/loss.py
+38
-0
model_zoo/official/cv/nasnet/src/lr_generator.py
model_zoo/official/cv/nasnet/src/lr_generator.py
+43
-0
model_zoo/official/cv/nasnet/src/nasnet_a_mobile.py
model_zoo/official/cv/nasnet/src/nasnet_a_mobile.py
+937
-0
model_zoo/official/cv/nasnet/train.py
model_zoo/official/cv/nasnet/train.py
+117
-0
未找到文件。
model_zoo/official/cv/nasnet/README.md
0 → 100755
浏览文件 @
eb9e2ed7
# NASNet Example
## Description
This is an example of training NASNet-A-Mobile in MindSpore.
## Requirements
-
Install
[
Mindspore
](
http://www.mindspore.cn/install/en
)
.
-
Download the dataset.
## Structure
```
shell
.
└─nasnet
├─README.md
├─scripts
├─run_standalone_train_for_gpu.sh
# launch standalone training with gpu platform(1p)
├─run_distribute_train_for_gpu.sh
# launch distributed training with gpu platform(8p)
└─run_eval_for_gpu.sh
# launch evaluating with gpu platform
├─src
├─config.py
# parameter configuration
├─dataset.py
# data preprocessing
├─loss.py
# Customized CrossEntropy loss function
├─lr_generator.py
# learning rate generator
├─nasnet_a_mobile.py
# network definition
├─eval.py
# eval net
├─export.py
# convert checkpoint
└─train.py
# train net
```
## Parameter Configuration
Parameters for both training and evaluating can be set in config.py
```
'random_seed': 1, # fix random seed
'rank': 0, # local rank of distributed
'group_size': 1, # world size of distributed
'work_nums': 8, # number of workers to read the data
'epoch_size': 250, # total epoch numbers
'keep_checkpoint_max': 100, # max numbers to keep checkpoints
'ckpt_path': './checkpoint/', # save checkpoint path
'is_save_on_master': 1 # save checkpoint on rank0, distributed parameters
'batch_size': 32, # input batchsize
'num_classes': 1000, # dataset class numbers
'label_smooth_factor': 0.1, # label smoothing factor
'aux_factor': 0.4, # loss factor of aux logit
'lr_init': 0.04, # initiate learning rate
'lr_decay_rate': 0.97, # decay rate of learning rate
'num_epoch_per_decay': 2.4, # decay epoch number
'weight_decay': 0.00004, # weight decay
'momentum': 0.9, # momentum
'opt_eps': 1.0, # epsilon
'rmsprop_decay': 0.9, # rmsprop decay
'loss_scale': 1, # loss scale
```
## Running the example
### Train
#### Usage
```
# distribute training example(8p)
sh run_distribute_train_for_gpu.sh DATA_DIR
# standalone training
sh run_standalone_train_for_gpu.sh DEVICE_ID DATA_DIR
```
#### Launch
```
bash
# distributed training example(8p) for GPU
sh scripts/run_distribute_train_for_gpu.sh /dataset/train
# standalone training example for GPU
sh scripts/run_standalone_train_for_gpu.sh 0 /dataset/train
```
#### Result
You can find checkpoint file together with result in log.
### Evaluation
#### Usage
```
# Evaluation
sh run_eval_for_gpu.sh DEVICE_ID DATA_DIR PATH_CHECKPOINT
```
#### Launch
```
bash
# Evaluation with checkpoint
sh scripts/run_eval_for_gpu.sh 0 /dataset/val ./checkpoint/nasnet-a-mobile-rank0-248_10009.ckpt
```
> checkpoint can be produced in training process.
#### Result
Evaluation result will be stored in the scripts path. Under this, you can find result like the followings in log.
model_zoo/official/cv/nasnet/eval.py
0 → 100755
浏览文件 @
eb9e2ed7
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""evaluate imagenet"""
import
argparse
import
os
import
mindspore.nn
as
nn
from
mindspore
import
context
from
mindspore.train.model
import
Model
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
from
src.config
import
nasnet_a_mobile_config_gpu
as
cfg
from
src.dataset
import
create_dataset
from
src.nasnet_a_mobile
import
NASNetAMobile
from
src.loss
import
CrossEntropy_Val
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'image classification evaluation'
)
parser
.
add_argument
(
'--checkpoint'
,
type
=
str
,
default
=
''
,
help
=
'checkpoint of nasnet_a_mobile (Default: None)'
)
parser
.
add_argument
(
'--dataset_path'
,
type
=
str
,
default
=
''
,
help
=
'Dataset path'
)
parser
.
add_argument
(
'--platform'
,
type
=
str
,
default
=
'GPU'
,
choices
=
(
'Ascend'
,
'GPU'
),
help
=
'run platform'
)
args_opt
=
parser
.
parse_args
()
if
args_opt
.
platform
==
'Ascend'
:
device_id
=
int
(
os
.
getenv
(
'DEVICE_ID'
))
context
.
set_context
(
device_id
=
device_id
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
args_opt
.
platform
)
net
=
NASNetAMobile
(
num_classes
=
cfg
.
num_classes
,
is_training
=
False
)
ckpt
=
load_checkpoint
(
args_opt
.
checkpoint
)
load_param_into_net
(
net
,
ckpt
)
net
.
set_train
(
False
)
dataset
=
create_dataset
(
args_opt
.
dataset_path
,
cfg
,
False
)
loss
=
CrossEntropy_Val
(
smooth_factor
=
0.1
,
num_classes
=
cfg
.
num_classes
)
eval_metrics
=
{
'Loss'
:
nn
.
Loss
(),
'Top1-Acc'
:
nn
.
Top1CategoricalAccuracy
(),
'Top5-Acc'
:
nn
.
Top5CategoricalAccuracy
()}
model
=
Model
(
net
,
loss
,
optimizer
=
None
,
metrics
=
eval_metrics
)
metrics
=
model
.
eval
(
dataset
)
print
(
"metric: "
,
metrics
)
model_zoo/official/cv/nasnet/export.py
0 → 100755
浏览文件 @
eb9e2ed7
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
##############export checkpoint file into geir and onnx models#################
"""
import
argparse
import
numpy
as
np
import
mindspore
as
ms
from
mindspore
import
Tensor
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
,
export
from
src.config
import
nasnet_a_mobile_config_gpu
as
cfg
from
src.nasnet_a_mobile
import
NASNetAMobile
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'checkpoint export'
)
parser
.
add_argument
(
'--checkpoint'
,
type
=
str
,
default
=
''
,
help
=
'checkpoint of nasnet_a_mobile (Default: None)'
)
args_opt
=
parser
.
parse_args
()
net
=
NASNetAMobile
(
num_classes
=
cfg
.
num_classes
,
is_training
=
False
)
param_dict
=
load_checkpoint
(
args_opt
.
checkpoint
)
load_param_into_net
(
net
,
param_dict
)
input_arr
=
Tensor
(
np
.
random
.
uniform
(
0.0
,
1.0
,
size
=
[
1
,
3
,
cfg
.
image_size
,
cfg
.
image_size
]),
ms
.
float32
)
export
(
net
,
input_arr
,
file_name
=
cfg
.
onnx_filename
,
file_format
=
"ONNX"
)
export
(
net
,
input_arr
,
file_name
=
cfg
.
geir_filename
,
file_format
=
"GEIR"
)
model_zoo/official/cv/nasnet/scripts/run_distribute_train_for_gpu.sh
0 → 100755
浏览文件 @
eb9e2ed7
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DATA_DIR
=
$1
mpirun
--allow-run-as-root
-n
8 python ./train.py
--is_distributed
--platform
'GPU'
--dataset_path
$DATA_DIR
>
train.log 2>&1 &
model_zoo/official/cv/nasnet/scripts/run_eval_for_gpu.sh
0 → 100755
浏览文件 @
eb9e2ed7
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DEVICE_ID
=
$1
DATA_DIR
=
$2
PATH_CHECKPOINT
=
$3
CUDA_VISIBLE_DEVICES
=
$DEVICE_ID
python ./eval.py
--platform
'GPU'
--dataset_path
$DATA_DIR
--checkpoint
$PATH_CHECKPOINT
>
eval.log 2>&1 &
model_zoo/official/cv/nasnet/scripts/run_standalone_train_for_gpu.sh
0 → 100755
浏览文件 @
eb9e2ed7
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
DEVICE_ID
=
$1
DATA_DIR
=
$2
CUDA_VISIBLE_DEVICES
=
$DEVICE_ID
python ./train.py
--platform
'GPU'
--dataset_path
$DATA_DIR
>
train.log 2>&1 &
model_zoo/official/cv/nasnet/src/config.py
0 → 100755
浏览文件 @
eb9e2ed7
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from
easydict
import
EasyDict
as
edict
nasnet_a_mobile_config_gpu
=
edict
({
'random_seed'
:
1
,
'rank'
:
0
,
'group_size'
:
1
,
'work_nums'
:
8
,
'epoch_size'
:
312
,
'keep_checkpoint_max'
:
100
,
'ckpt_path'
:
'./nasnet_a_mobile_checkpoint/'
,
'is_save_on_master'
:
0
,
### Dataset Config
'batch_size'
:
32
,
'image_size'
:
224
,
'num_classes'
:
1000
,
### Loss Config
'label_smooth_factor'
:
0.1
,
'aux_factor'
:
0.4
,
### Learning Rate Config
# 'lr_decay_method': 'exponential',
'lr_init'
:
0.04
,
'lr_decay_rate'
:
0.97
,
'num_epoch_per_decay'
:
2.4
,
### Optimization Config
'weight_decay'
:
0.00004
,
'momentum'
:
0.9
,
'opt_eps'
:
1.0
,
'rmsprop_decay'
:
0.9
,
"loss_scale"
:
1
,
### onnx&air Config
'onnx_filename'
:
'nasnet_a_mobile.onnx'
,
'air_filename'
:
'nasnet_a_mobile.air'
})
model_zoo/official/cv/nasnet/src/dataset.py
0 → 100755
浏览文件 @
eb9e2ed7
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Data operations, will be used in train.py and eval.py
"""
import
mindspore.common.dtype
as
mstype
import
mindspore.dataset.engine
as
de
import
mindspore.dataset.transforms.c_transforms
as
C2
import
mindspore.dataset.transforms.vision.c_transforms
as
C
def
create_dataset
(
dataset_path
,
config
,
do_train
,
repeat_num
=
1
):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
config(dict): config of dataset.
do_train(bool): whether dataset is used for train or eval.
repeat_num(int): the repeat times of dataset. Default: 1.
Returns:
dataset
"""
rank
=
config
.
rank
group_size
=
config
.
group_size
if
group_size
==
1
:
ds
=
de
.
ImageFolderDatasetV2
(
dataset_path
,
num_parallel_workers
=
config
.
work_nums
,
shuffle
=
True
)
else
:
ds
=
de
.
ImageFolderDatasetV2
(
dataset_path
,
num_parallel_workers
=
config
.
work_nums
,
shuffle
=
True
,
num_shards
=
group_size
,
shard_id
=
rank
)
# define map operations
if
do_train
:
trans
=
[
C
.
RandomCropDecodeResize
(
config
.
image_size
),
C
.
RandomHorizontalFlip
(
prob
=
0.5
),
C
.
RandomColorAdjust
(
brightness
=
0.4
,
saturation
=
0.5
)
# fast mode
#C.RandomColorAdjust(brightness=0.4, contrast=0.5, saturation=0.5, hue=0.2)
]
else
:
trans
=
[
C
.
Decode
(),
C
.
Resize
(
int
(
config
.
image_size
/
0.875
)),
C
.
CenterCrop
(
config
.
image_size
)
]
trans
+=
[
C
.
Rescale
(
1.0
/
255.0
,
0.0
),
C
.
Normalize
(
mean
=
[
0.5
,
0.5
,
0.5
],
std
=
[
0.5
,
0.5
,
0.5
]),
C
.
HWC2CHW
()
]
type_cast_op
=
C2
.
TypeCast
(
mstype
.
int32
)
ds
=
ds
.
map
(
input_columns
=
"image"
,
operations
=
trans
,
num_parallel_workers
=
config
.
work_nums
)
ds
=
ds
.
map
(
input_columns
=
"label"
,
operations
=
type_cast_op
,
num_parallel_workers
=
config
.
work_nums
)
# apply batch operations
ds
=
ds
.
batch
(
config
.
batch_size
,
drop_remainder
=
True
)
# apply dataset repeat operation
ds
=
ds
.
repeat
(
repeat_num
)
return
ds
model_zoo/official/cv/nasnet/src/loss.py
0 → 100755
浏览文件 @
eb9e2ed7
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define evaluation loss function for network."""
from
mindspore.nn.loss.loss
import
_Loss
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
functional
as
F
from
mindspore
import
Tensor
from
mindspore.common
import
dtype
as
mstype
import
mindspore.nn
as
nn
class
CrossEntropy_Val
(
_Loss
):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process"""
def
__init__
(
self
,
smooth_factor
=
0
,
num_classes
=
1000
):
super
(
CrossEntropy_Val
,
self
).
__init__
()
self
.
onehot
=
P
.
OneHot
()
self
.
on_value
=
Tensor
(
1.0
-
smooth_factor
,
mstype
.
float32
)
self
.
off_value
=
Tensor
(
1.0
*
smooth_factor
/
(
num_classes
-
1
),
mstype
.
float32
)
self
.
ce
=
nn
.
SoftmaxCrossEntropyWithLogits
()
self
.
mean
=
P
.
ReduceMean
(
False
)
def
construct
(
self
,
logits
,
label
):
one_hot_label
=
self
.
onehot
(
label
,
F
.
shape
(
logits
)[
1
],
self
.
on_value
,
self
.
off_value
)
loss_logit
=
self
.
ce
(
logits
,
one_hot_label
)
loss_logit
=
self
.
mean
(
loss_logit
,
0
)
return
loss_logit
model_zoo/official/cv/nasnet/src/lr_generator.py
0 → 100755
浏览文件 @
eb9e2ed7
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""learning rate exponential decay generator"""
import
math
import
numpy
as
np
def
get_lr
(
lr_init
,
lr_decay_rate
,
num_epoch_per_decay
,
total_epochs
,
steps_per_epoch
,
is_stair
=
False
):
"""
generate learning rate array
Args:
lr_init(float): init learning rate
lr_decay_rate (float):
total_epochs(int): total epoch of training
steps_per_epoch(int): steps of one epoch
is_stair(bool): If `True` decay the learning rate at discrete intervals
Returns:
np.array, learning rate array
"""
lr_each_step
=
[]
total_steps
=
steps_per_epoch
*
total_epochs
decay_steps
=
steps_per_epoch
*
num_epoch_per_decay
for
i
in
range
(
total_steps
):
p
=
i
/
decay_steps
if
is_stair
:
p
=
math
.
floor
(
p
)
lr_each_step
.
append
(
lr_init
*
math
.
pow
(
lr_decay_rate
,
p
))
learning_rate
=
np
.
array
(
lr_each_step
).
astype
(
np
.
float32
)
return
learning_rate
model_zoo/official/cv/nasnet/src/nasnet_a_mobile.py
0 → 100755
浏览文件 @
eb9e2ed7
此差异已折叠。
点击以展开。
model_zoo/official/cv/nasnet/train.py
0 → 100755
浏览文件 @
eb9e2ed7
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train imagenet."""
import
argparse
import
os
import
random
import
numpy
as
np
from
mindspore
import
Tensor
from
mindspore
import
context
from
mindspore
import
ParallelMode
from
mindspore.communication.management
import
init
,
get_rank
,
get_group_size
from
mindspore.nn.optim.rmsprop
import
RMSProp
from
mindspore.train.callback
import
ModelCheckpoint
,
CheckpointConfig
,
LossMonitor
,
TimeMonitor
from
mindspore.train.model
import
Model
from
mindspore.train.serialization
import
load_checkpoint
,
load_param_into_net
from
mindspore
import
dataset
as
de
from
src.config
import
nasnet_a_mobile_config_gpu
as
cfg
from
src.dataset
import
create_dataset
from
src.nasnet_a_mobile
import
NASNetAMobileWithLoss
,
NASNetAMobileTrainOneStepWithClipGradient
from
src.lr_generator
import
get_lr
random
.
seed
(
cfg
.
random_seed
)
np
.
random
.
seed
(
cfg
.
random_seed
)
de
.
config
.
set_seed
(
cfg
.
random_seed
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'image classification training'
)
parser
.
add_argument
(
'--dataset_path'
,
type
=
str
,
default
=
''
,
help
=
'Dataset path'
)
parser
.
add_argument
(
'--resume'
,
type
=
str
,
default
=
''
,
help
=
'resume training with existed checkpoint'
)
parser
.
add_argument
(
'--is_distributed'
,
action
=
'store_true'
,
default
=
False
,
help
=
'distributed training'
)
parser
.
add_argument
(
'--platform'
,
type
=
str
,
default
=
'GPU'
,
choices
=
(
'Ascend'
,
'GPU'
),
help
=
'run platform'
)
args_opt
=
parser
.
parse_args
()
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
args_opt
.
platform
,
save_graphs
=
False
)
if
os
.
getenv
(
'DEVICE_ID'
,
"not_set"
).
isdigit
():
context
.
set_context
(
device_id
=
int
(
os
.
getenv
(
'DEVICE_ID'
)))
# init distributed
if
args_opt
.
is_distributed
:
if
args_opt
.
platform
==
"Ascend"
:
init
()
else
:
init
(
"nccl"
)
cfg
.
rank
=
get_rank
()
cfg
.
group_size
=
get_group_size
()
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
context
.
set_auto_parallel_context
(
parallel_mode
=
parallel_mode
,
device_num
=
cfg
.
group_size
,
parameter_broadcast
=
True
,
mirror_mean
=
True
)
else
:
cfg
.
rank
=
0
cfg
.
group_size
=
1
# dataloader
dataset
=
create_dataset
(
args_opt
.
dataset_path
,
cfg
,
True
)
batches_per_epoch
=
dataset
.
get_dataset_size
()
# network
net_with_loss
=
NASNetAMobileWithLoss
(
cfg
)
if
args_opt
.
resume
:
ckpt
=
load_checkpoint
(
args_opt
.
resume
)
load_param_into_net
(
net_with_loss
,
ckpt
)
# learning rate schedule
lr
=
get_lr
(
lr_init
=
cfg
.
lr_init
,
lr_decay_rate
=
cfg
.
lr_decay_rate
,
num_epoch_per_decay
=
cfg
.
num_epoch_per_decay
,
total_epochs
=
cfg
.
epoch_size
,
steps_per_epoch
=
batches_per_epoch
,
is_stair
=
True
)
lr
=
Tensor
(
lr
)
# optimizer
decayed_params
=
[]
no_decayed_params
=
[]
for
param
in
net_with_loss
.
trainable_params
():
if
'beta'
not
in
param
.
name
and
'gamma'
not
in
param
.
name
and
'bias'
not
in
param
.
name
:
decayed_params
.
append
(
param
)
else
:
no_decayed_params
.
append
(
param
)
group_params
=
[{
'params'
:
decayed_params
,
'weight_decay'
:
cfg
.
weight_decay
},
{
'params'
:
no_decayed_params
},
{
'order_params'
:
net_with_loss
.
trainable_params
()}]
optimizer
=
RMSProp
(
group_params
,
lr
,
decay
=
cfg
.
rmsprop_decay
,
weight_decay
=
cfg
.
weight_decay
,
momentum
=
cfg
.
momentum
,
epsilon
=
cfg
.
opt_eps
,
loss_scale
=
cfg
.
loss_scale
)
net_with_grads
=
NASNetAMobileTrainOneStepWithClipGradient
(
net_with_loss
,
optimizer
)
net_with_grads
.
set_train
()
model
=
Model
(
net_with_grads
)
print
(
"============== Starting Training =============="
)
loss_cb
=
LossMonitor
(
per_print_times
=
batches_per_epoch
)
time_cb
=
TimeMonitor
(
data_size
=
batches_per_epoch
)
callbacks
=
[
loss_cb
,
time_cb
]
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
batches_per_epoch
,
keep_checkpoint_max
=
cfg
.
keep_checkpoint_max
)
ckpoint_cb
=
ModelCheckpoint
(
prefix
=
f
"nasnet-a-mobile-rank
{
cfg
.
rank
}
"
,
directory
=
cfg
.
ckpt_path
,
config
=
config_ck
)
if
args_opt
.
is_distributed
&
cfg
.
is_save_on_master
:
if
cfg
.
rank
==
0
:
callbacks
.
append
(
ckpoint_cb
)
model
.
train
(
cfg
.
epoch_size
,
dataset
,
callbacks
=
callbacks
,
dataset_sink_mode
=
True
)
else
:
callbacks
.
append
(
ckpoint_cb
)
model
.
train
(
cfg
.
epoch_size
,
dataset
,
callbacks
=
callbacks
,
dataset_sink_mode
=
True
)
print
(
"train success"
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录