magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit cf3fff89
Authored Jul 25, 2020 by ms_yan
init add vgg16 gpu version
merge the script optimize the script
Parent: 79225e04
Showing 13 changed files with 1208 additions and 123 deletions (+1208 -123)
model_zoo/official/cv/vgg16/eval.py                            +191   -29
model_zoo/official/cv/vgg16/src/config.py                       +39    -4
model_zoo/official/cv/vgg16/src/crossentropy.py                 +39    -0
model_zoo/official/cv/vgg16/src/dataset.py                     +140   -14
model_zoo/official/cv/vgg16/src/linear_warmup.py                 +7   -13
model_zoo/official/cv/vgg16/src/utils/logging.py                +82    -0
model_zoo/official/cv/vgg16/src/utils/sampler.py                +53    -0
model_zoo/official/cv/vgg16/src/utils/util.py                   +36    -0
model_zoo/official/cv/vgg16/src/utils/var_init.py              +213    -0
model_zoo/official/cv/vgg16/src/vgg.py                          +47    -9
model_zoo/official/cv/vgg16/src/warmup_cosine_annealing_lr.py   +40    -0
model_zoo/official/cv/vgg16/src/warmup_step_lr.py               +84    -0
model_zoo/official/cv/vgg16/train.py                           +237   -54
model_zoo/official/cv/vgg16/eval.py
@@ -12,42 +12,204 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""
-##############test vgg16 example on cifar10#################
-python eval.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
-"""
+"""Eval"""
+import os
+import time
 import argparse
+import datetime
+import glob
 import numpy as np
 import mindspore.nn as nn
-from mindspore import context
+from mindspore import Tensor, context
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from src.config import cifar_cfg as cfg
-from src.dataset import vgg_create_dataset
+from mindspore.ops import operations as P
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype
+from src.utils.logging import get_logger
 from src.vgg import vgg16
+from src.dataset import vgg_create_dataset
+from src.dataset import classification_dataset
+
+
+class ParameterReduce(nn.Cell):
+    """ParameterReduce"""
+    def __init__(self):
+        super(ParameterReduce, self).__init__()
+        self.cast = P.Cast()
+        self.reduce = P.AllReduce()
+
+    def construct(self, x):
+        one = self.cast(F.scalar_to_array(1.0), mstype.float32)
+        out = x * one
+        ret = self.reduce(out)
+        return ret
+
+
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Cifar10 classification')
-    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
-                        help='device where the code will be implemented. (Default: Ascend)')
-    parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved')
-    parser.add_argument('--checkpoint_path', type=str, default=None, help='checkpoint file path.')
-    parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
+def parse_args(cloud_args=None):
+    """parse_args"""
+    parser = argparse.ArgumentParser('mindspore classification test')
+    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'],
+                        help='device where the code will be implemented. (Default: Ascend)')
+    # dataset related
+    parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="imagenet2012")
+    parser.add_argument('--data_path', type=str, default='', help='eval data dir')
+    parser.add_argument('--per_batch_size', default=32, type=int, help='batch size for per npu')
+    # network related
+    parser.add_argument('--graph_ckpt', type=int, default=1, help='graph ckpt or feed ckpt')
+    parser.add_argument('--pretrained', default='', type=str,
+                        help='fully path of pretrained model to load. '
+                             'If it is a direction, it will test all ckpt')
+    # logging related
+    parser.add_argument('--log_path', type=str, default='outputs/', help='path to save log')
+    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
+    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
+    # roma obs
+    parser.add_argument('--train_url', type=str, default="", help='train url')
+
+    args_opt = parser.parse_args()
+    args_opt = merge_args(args_opt, cloud_args)
+
+    if args_opt.dataset == "cifar10":
+        from src.config import cifar_cfg as cfg
+    else:
+        from src.config import imagenet_cfg as cfg
+
+    args_opt.image_size = cfg.image_size
+    args_opt.num_classes = cfg.num_classes
+    args_opt.per_batch_size = cfg.batch_size
+    args_opt.buffer_size = cfg.buffer_size
+    args_opt.pad_mode = cfg.pad_mode
+    args_opt.padding = cfg.padding
+    args_opt.has_bias = cfg.has_bias
+    args_opt.batch_norm = cfg.batch_norm
+
+    args_opt.image_size = list(map(int, args_opt.image_size.split(',')))
+
+    return args_opt
+
+
+def get_top5_acc(top5_arg, gt_class):
+    sub_count = 0
+    for top5, gt in zip(top5_arg, gt_class):
+        if gt in top5:
+            sub_count += 1
+    return sub_count
+
+
+def merge_args(args, cloud_args):
+    """merge_args"""
+    args_dict = vars(args)
+    if isinstance(cloud_args, dict):
+        for key in cloud_args.keys():
+            val = cloud_args[key]
+            if key in args_dict and val:
+                arg_type = type(args_dict[key])
+                if arg_type is not type(None):
+                    val = arg_type(val)
+                args_dict[key] = val
+    return args
+
+
+def test(cloud_args=None):
+    """test"""
+    args = parse_args(cloud_args)
+
+    context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True,
+                        device_target=args.device_target, save_graphs=False)
+    if os.getenv('DEVICE_ID', "not_set").isdigit():
+        context.set_context(device_id=int(os.getenv('DEVICE_ID')))
+
+    args.outputs_dir = os.path.join(args.log_path,
+                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
+    args.logger = get_logger(args.outputs_dir, args.rank)
+    args.logger.save_args(args)
+
+    if args.dataset == "cifar10":
+        net = vgg16(num_classes=args.num_classes)
+        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
+                       weight_decay=args.weight_decay)
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
+        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
+        param_dict = load_checkpoint(args.checkpoint_path)
+        load_param_into_net(net, param_dict)
+        net.set_train(False)
+        dataset = vgg_create_dataset(args.data_path, 1, False)
+        res = model.eval(dataset)
+        print("result: ", res)
+    else:
+        # network
+        args.logger.important_info('start create network')
+        if os.path.isdir(args.pretrained):
+            models = list(glob.glob(os.path.join(args.pretrained, '*.ckpt')))
+            print(models)
+            if args.graph_ckpt:
+                f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('-')[-1].split('_')[0])
+            else:
+                f = lambda x: -1 * int(os.path.splitext(os.path.split(x)[-1])[0].split('_')[-1])
+            args.models = sorted(models, key=f)
+        else:
+            args.models = [args.pretrained,]
+
+        for model in args.models:
+            if args.dataset == "cifar10":
+                dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size, training=False)
+            else:
+                dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size)
+
+            eval_dataloader = dataset.create_tuple_iterator()
+            network = vgg16(args.num_classes, args, phase="test")
+
+            # pre_trained
+            load_param_into_net(network, load_checkpoint(model))
+            network.add_flags_recursive(fp16=True)
+
+            img_tot = 0
+            top1_correct = 0
+            top5_correct = 0
+
+            network.set_train(False)
+            t_end = time.time()
+            it = 0
+            for data, gt_classes in eval_dataloader:
+                output = network(Tensor(data, mstype.float32))
+                output = output.asnumpy()
+
+                top1_output = np.argmax(output, (-1))
+                top5_output = np.argsort(output)[:, -5:]
+
+                t1_correct = np.equal(top1_output, gt_classes).sum()
+                top1_correct += t1_correct
+                top5_correct += get_top5_acc(top5_output, gt_classes)
+                img_tot += args.per_batch_size
+
+                if args.rank == 0 and it == 0:
+                    t_end = time.time()
+                    it = 1
+            if args.rank == 0:
+                time_used = time.time() - t_end
+                fps = (img_tot - args.per_batch_size) * args.group_size / time_used
+                args.logger.info('Inference Performance: {:.2f} img/sec'.format(fps))
+            results = [[top1_correct], [top5_correct], [img_tot]]
+            args.logger.info('before results={}'.format(results))
+            results = np.array(results)
+            args.logger.info('after results={}'.format(results))
+
+            top1_correct = results[0, 0]
+            top5_correct = results[1, 0]
+            img_tot = results[2, 0]
+            acc1 = 100.0 * top1_correct / img_tot
+            acc5 = 100.0 * top5_correct / img_tot
+            args.logger.info('after allreduce eval: top1_correct={}, tot={},'
+                             'acc={:.2f}%(TOP1)'.format(top1_correct, img_tot, acc1))
+            args.logger.info('after allreduce eval: top5_correct={}, tot={},'
+                             'acc={:.2f}%(TOP5)'.format(top5_correct, img_tot, acc5))
+
+
-    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    context.set_context(device_id=args_opt.device_id)
-    net = vgg16(num_classes=cfg.num_classes)
-    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
-                   weight_decay=cfg.weight_decay)
-    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
-    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
-    param_dict = load_checkpoint(args_opt.checkpoint_path)
-    load_param_into_net(net, param_dict)
-    net.set_train(False)
-    dataset = vgg_create_dataset(args_opt.data_path, 1, False)
-    res = model.eval(dataset)
-    print("result: ", res)
+if __name__ == "__main__":
+    test()
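The top-1/top-5 bookkeeping in the imagenet2012 branch above is plain numpy. A minimal sketch (toy logits, not part of the commit) showing the same argmax/argsort logic that feeds get_top5_acc:

import numpy as np

# two samples, six classes; ground-truth classes 3 and 2
logits = np.array([[0.1, 0.3, 0.2, 0.9, 0.05, 0.4],
                   [0.8, 0.1, 0.6, 0.2, 0.3, 0.05]])
labels = np.array([3, 2])

top1 = np.argmax(logits, -1)          # best class per row -> [3, 0]
top5 = np.argsort(logits)[:, -5:]     # five highest-scoring classes per row
top1_correct = np.equal(top1, labels).sum()
top5_correct = sum(1 for t5, gt in zip(top5, labels) if gt in t5)  # get_top5_acc logic
print(top1_correct, top5_correct)     # 1 2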
model_zoo/official/cv/vgg16/src/config.py
(mode changed 100644 → 100755)
@@ -13,21 +13,56 @@
 # limitations under the License.
 # ============================================================================
 """
-network config setting, will be used in main.py
+network config setting, will be used in train.py and eval.py
 """
 from easydict import EasyDict as edict


 # config for vgg16, cifar10
 cifar_cfg = edict({
     'num_classes': 10,
+    "lr": 0.01,
     'lr_init': 0.01,
     'lr_max': 0.1,
+    "lr_epochs": '30,60,90,120',
+    "lr_scheduler": "step",
     'warmup_epochs': 5,
     'batch_size': 64,
-    'epoch_size': 70,
+    'max_epoch': 70,
     'momentum': 0.9,
     'weight_decay': 5e-4,
+    "loss_scale": 1.0,
+    "label_smooth": 0,
+    "label_smooth_factor": 0,
     'buffer_size': 10,
-    'image_height': 224,
-    'image_width': 224,
+    "image_size": '224,224',
+    'pad_mode': 'same',
+    'padding': 0,
+    'has_bias': False,
+    "batch_norm": True,
     'keep_checkpoint_max': 10
 })

+# config for vgg16, imagenet2012
+imagenet_cfg = edict({
+    'num_classes': 1000,
+    "lr": 0.01,
+    'lr_init': 0.01,
+    'lr_max': 0.1,
+    "lr_epochs": '30,60,90,120',
+    "lr_scheduler": 'cosine_annealing',
+    'warmup_epochs': 0,
+    'batch_size': 32,
+    'max_epoch': 150,
+    'momentum': 0.9,
+    'weight_decay': 1e-4,
+    "loss_scale": 1024,
+    "label_smooth": 1,
+    "label_smooth_factor": 0.1,
+    'buffer_size': 10,
+    "image_size": '224,224',
+    'pad_mode': 'pad',
+    'padding': 1,
+    'has_bias': True,
+    "batch_norm": False,
+    'keep_checkpoint_max': 10
+})
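Both configs keep image_size and lr_epochs as comma-separated strings; parse_args in train.py/eval.py turns them into int lists. A small sketch of that consumption, using the values above:

from easydict import EasyDict as edict

cfg = edict({"image_size": '224,224', "lr_epochs": '30,60,90,120'})
image_size = list(map(int, cfg.image_size.split(',')))  # [224, 224]
lr_epochs = list(map(int, cfg.lr_epochs.split(',')))    # [30, 60, 90, 120]
print(image_size, lr_epochs)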
model_zoo/official/cv/vgg16/src/crossentropy.py
(new file, mode 100755)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn


class CrossEntropy(_Loss):
    """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
    def __init__(self, smooth_factor=0., num_classes=1001):
        super(CrossEntropy, self).__init__()
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
        self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits()
        self.mean = P.ReduceMean(False)

    def construct(self, logit, label):
        one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, one_hot_label)
        loss = self.mean(loss, 0)
        return loss
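A quick arithmetic check (not in the commit) of the smoothed targets this loss builds: on_value = 1 - smooth_factor and off_value = smooth_factor / (num_classes - 1), so each smoothed one-hot row still sums to 1:

smooth_factor, num_classes = 0.1, 1000
on_value = 1.0 - smooth_factor                       # 0.9
off_value = 1.0 * smooth_factor / (num_classes - 1)  # ~0.0001001
row_sum = on_value + (num_classes - 1) * off_value
print(on_value, off_value, row_sum)                  # row_sum == 1.0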
model_zoo/official/cv/vgg16/src/dataset.py
@@ -13,37 +13,35 @@
 # limitations under the License.
 # ============================================================================
 """
-Data operations, will be used in train.py and eval.py
+dataset processing.
 """
 import os
-import mindspore.common.dtype as mstype
-import mindspore.dataset as ds
+from mindspore.common import dtype as mstype
+import mindspore.dataset as de
 import mindspore.dataset.transforms.c_transforms as C
 import mindspore.dataset.transforms.vision.c_transforms as vision
-from .config import cifar_cfg as cfg
+from PIL import Image, ImageFile
+from src.utils.sampler import DistributedSampler
+ImageFile.LOAD_TRUNCATED_IMAGES = True


-def vgg_create_dataset(data_home, repeat_num=1, training=True):
+def vgg_create_dataset(data_home, image_size, batch_size, rank_id=0, rank_size=1, repeat_num=1, training=True):
     """Data operations."""
-    ds.config.set_seed(1)
+    de.config.set_seed(1)
     data_dir = os.path.join(data_home, "cifar-10-batches-bin")
     if not training:
         data_dir = os.path.join(data_home, "cifar-10-verify-bin")

-    rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else None
-    rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else None
-    data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id)
+    data_set = de.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id)

-    resize_height = cfg.image_height
-    resize_width = cfg.image_width
     rescale = 1.0 / 255.0
     shift = 0.0

     # define map operations
     random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
     random_horizontal_op = vision.RandomHorizontalFlip()
-    resize_op = vision.Resize((resize_height, resize_width))  # interpolation default BILINEAR
+    resize_op = vision.Resize(image_size)  # interpolation default BILINEAR
     rescale_op = vision.Rescale(rescale, shift)
     normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
     changeswap_op = vision.HWC2CHW()
@@ -66,6 +64,134 @@ def vgg_create_dataset(data_home, repeat_num=1, training=True):
     data_set = data_set.shuffle(buffer_size=10)

     # apply batch operations
-    data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)
+    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)

     return data_set
+
+
+def classification_dataset(data_dir, image_size, per_batch_size, rank=0, group_size=1,
+                           mode='train',
+                           input_mode='folder',
+                           root='',
+                           num_parallel_workers=None,
+                           shuffle=None,
+                           sampler=None,
+                           repeat_num=1,
+                           class_indexing=None,
+                           drop_remainder=True,
+                           transform=None,
+                           target_transform=None):
+    """
+    A function that returns a dataset for classification. The mode of input dataset could be "folder" or "txt".
+    If it is "folder", all images within one folder have the same label. If it is "txt", all paths of images
+    are written into a textfile.
+
+    Args:
+        data_dir (str): Path to the root directory that contains the dataset for "input_mode="folder"".
+            Or path of the textfile that contains every image's path of the dataset.
+        image_size (str): Size of the input images.
+        per_batch_size (int): the batch size of every step during training.
+        rank (int): The shard ID within num_shards (default=None).
+        group_size (int): Number of shards that the dataset should be divided into (default=None).
+        mode (str): "train" or others. Default: "train".
+        input_mode (str): The form of the input dataset. "folder" or "txt". Default: "folder".
+        root (str): the images path for "input_mode="txt"". Default: " ".
+        num_parallel_workers (int): Number of workers to read the data. Default: None.
+        shuffle (bool): Whether or not to perform shuffle on the dataset
+            (default=None, performs shuffle).
+        sampler (Sampler): Object used to choose samples from the dataset. Default: None.
+        repeat_num (int): the num of repeat dataset.
+        class_indexing (dict): A str-to-int mapping from folder name to index
+            (default=None, the folder names will be sorted alphabetically and each
+            class will be given a unique index starting from 0).
+
+    Examples:
+        >>> from mindvision.common.datasets.classification import classification_dataset
+        >>> # path to imagefolder directory. This directory needs to contain sub-directories which contain the images
+        >>> dataset_dir = "/path/to/imagefolder_directory"
+        >>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244],
+        >>>                                     per_batch_size=64, rank=0, group_size=4)
+        >>> # Path of the textfile that contains every image's path of the dataset.
+        >>> dataset_dir = "/path/to/dataset/images/train.txt"
+        >>> images_dir = "/path/to/dataset/images"
+        >>> de_dataset = classification_dataset(train_data_dir, image_size=[224, 244],
+        >>>                                     per_batch_size=64, rank=0, group_size=4,
+        >>>                                     input_mode="txt", root=images_dir)
+    """
+    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+
+    if transform is None:
+        if mode == 'train':
+            transform_img = [
+                vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0)),
+                vision.RandomHorizontalFlip(prob=0.5),
+                vision.Normalize(mean=mean, std=std),
+                vision.HWC2CHW()
+            ]
+        else:
+            transform_img = [
+                vision.Decode(),
+                vision.Resize((256, 256)),
+                vision.CenterCrop(image_size),
+                vision.Normalize(mean=mean, std=std),
+                vision.HWC2CHW()
+            ]
+    else:
+        transform_img = transform
+
+    if target_transform is None:
+        transform_label = [C.TypeCast(mstype.int32)]
+    else:
+        transform_label = target_transform
+
+    if input_mode == 'folder':
+        de_dataset = de.ImageFolderDatasetV2(data_dir, num_parallel_workers=num_parallel_workers,
+                                             shuffle=shuffle, sampler=sampler, class_indexing=class_indexing,
+                                             num_shards=group_size, shard_id=rank)
+    else:
+        dataset = TxtDataset(root, data_dir)
+        sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle)
+        de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler)
+        de_dataset.set_dataset_size(len(sampler))
+
+    de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
+    de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label)
+
+    columns_to_project = ["image", "label"]
+    de_dataset = de_dataset.project(columns=columns_to_project)
+
+    de_dataset = de_dataset.batch(per_batch_size, drop_remainder=drop_remainder)
+    de_dataset = de_dataset.repeat(repeat_num)
+
+    return de_dataset
+
+
+class TxtDataset:
+    """
+    create txt dataset.
+
+    Args:
+    Returns:
+        de_dataset.
+    """
+    def __init__(self, root, txt_name):
+        super(TxtDataset, self).__init__()
+        self.imgs = []
+        self.labels = []
+        fin = open(txt_name, "r")
+        for line in fin:
+            img_name, label = line.strip().split(' ')
+            self.imgs.append(os.path.join(root, img_name))
+            self.labels.append(int(label))
+        fin.close()
+
+    def __getitem__(self, index):
+        img = Image.open(self.imgs[index]).convert('RGB')
+        return img, self.labels[index]
+
+    def __len__(self):
+        return len(self.imgs)
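For input_mode='txt', TxtDataset expects one '<image_path> <int_label>' pair per line. A hedged sketch of that parsing step with invented file names (same split/join logic as __init__, no image decoding):

import os

lines = ["class_a/img_0001.JPEG 0",   # made-up entries in the train.txt format
         "class_b/img_0002.JPEG 1"]
imgs, labels = [], []
for line in lines:
    img_name, label = line.strip().split(' ')
    imgs.append(os.path.join("/path/to/dataset/images", img_name))
    labels.append(int(label))
print(imgs, labels)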
tests/ut/python/model/test_vgg.py → model_zoo/official/cv/vgg16/src/linear_warmup.py (renamed)
@@ -12,18 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""test_vgg"""
-import numpy as np
-import pytest
-
-from mindspore import Tensor
-from model_zoo.official.cv.vgg16.src.vgg import vgg16
-from ..ut_filter import non_graph_engine
-
-
-@non_graph_engine
-def test_vgg16():
-    inputs = Tensor(np.random.rand(1, 3, 112, 112).astype(np.float32))
-    net = vgg16()
-    with pytest.raises(ValueError):
-        print(net.construct(inputs))
+"""
+linear warm up learning rate.
+"""
+
+
+def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
+    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
+    lr = float(init_lr) + lr_inc * current_step
+    return lr
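Worked numbers for linear_warmup_lr (the function is repeated here so the snippet runs standalone): with base_lr=0.1, init_lr=0 and five warmup steps, the lr climbs in equal increments and reaches base_lr at the last warmup step:

def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
    return float(init_lr) + lr_inc * current_step

print([round(linear_warmup_lr(s, 5, 0.1, 0.0), 3) for s in range(1, 6)])
# [0.02, 0.04, 0.06, 0.08, 0.1]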
model_zoo/official/cv/vgg16/src/utils/logging.py
(new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
get logger.
"""
import logging
import os
import sys
from datetime import datetime


class LOGGER(logging.Logger):
    """
    set up logging file.

    Args:
        logger_name (string): logger name.
        log_dir (string): path of logger.

    Returns:
        string, logger path
    """
    def __init__(self, logger_name, rank=0):
        super(LOGGER, self).__init__(logger_name)
        if rank % 8 == 0:
            console = logging.StreamHandler(sys.stdout)
            console.setLevel(logging.INFO)
            formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
            console.setFormatter(formatter)
            self.addHandler(console)

    def setup_logging_file(self, log_dir, rank=0):
        """set up log file"""
        self.rank = rank
        if not os.path.exists(log_dir):
            os.makedirs(log_dir, exist_ok=True)
        log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
        self.log_fn = os.path.join(log_dir, log_name)
        fh = logging.FileHandler(self.log_fn)
        fh.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
        fh.setFormatter(formatter)
        self.addHandler(fh)

    def info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO):
            self._log(logging.INFO, msg, args, **kwargs)

    def save_args(self, args):
        self.info('Args:')
        args_dict = vars(args)
        for key in args_dict.keys():
            self.info('--> %s: %s', key, args_dict[key])
        self.info('')

    def important_info(self, msg, *args, **kwargs):
        if self.isEnabledFor(logging.INFO) and self.rank == 0:
            line_width = 2
            important_msg = '\n'
            important_msg += ('*' * 70 + '\n') * line_width
            important_msg += ('*' * line_width + '\n') * 2
            important_msg += '*' * line_width + ' ' * 8 + msg + '\n'
            important_msg += ('*' * line_width + '\n') * 2
            important_msg += ('*' * 70 + '\n') * line_width
            self.info(important_msg, *args, **kwargs)


def get_logger(path, rank):
    logger = LOGGER("mindversion", rank)
    logger.setup_logging_file(path, rank)
    return logger
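Typical use, as in train.py and eval.py (a sketch; assumes it runs from the vgg16 script directory so src.utils.logging resolves):

from src.utils.logging import get_logger

logger = get_logger('./outputs', rank=0)        # stdout handler plus a timestamped rank file
logger.info('Inference Performance: {:.2f} img/sec'.format(123.4))
logger.important_info('start create network')   # starred banner, emitted on rank 0 only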
model_zoo/official/cv/vgg16/src/utils/sampler.py
(new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
choose samples from the dataset
"""
import math
import numpy as np


class DistributedSampler():
    """
    sampling the dataset.

    Args:
    Returns:
        num_samples, number of samples.
    """
    def __init__(self, dataset, rank, group_size, shuffle=True, seed=0):
        self.dataset = dataset
        self.rank = rank
        self.group_size = group_size
        self.dataset_length = len(self.dataset)
        self.num_samples = int(math.ceil(self.dataset_length * 1.0 / self.group_size))
        self.total_size = self.num_samples * self.group_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self):
        if self.shuffle:
            self.seed = (self.seed + 1) & 0xffffffff
            np.random.seed(self.seed)
            indices = np.random.permutation(self.dataset_length).tolist()
        else:
            # dataset_length is already an int; calling len() on it would raise TypeError
            indices = list(range(self.dataset_length))

        indices += indices[:(self.total_size - len(indices))]
        indices = indices[self.rank::self.group_size]
        return iter(indices)

    def __len__(self):
        return self.num_samples
\ No newline at end of file
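What the rank::group_size slicing yields: each rank sees an interleaved, equally sized shard of the wrap-around-padded index list. A pure-python mimic with toy sizes (10 samples, 4 ranks):

import math

dataset_length, group_size = 10, 4
num_samples = int(math.ceil(dataset_length * 1.0 / group_size))  # 3 per rank
total_size = num_samples * group_size                            # 12, so 2 indices are padded
indices = list(range(dataset_length))
indices += indices[:total_size - len(indices)]                   # wrap-around padding
for rank in range(group_size):
    print(rank, indices[rank::group_size])
# 0 [0, 4, 8]   1 [1, 5, 9]   2 [2, 6, 0]   3 [3, 7, 1]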
model_zoo/official/cv/vgg16/src/utils/util.py
(new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Util class or function."""


def get_param_groups(network):
    """Param groups for optimizer."""
    decay_params = []
    no_decay_params = []
    for x in network.trainable_params():
        parameter_name = x.name
        if parameter_name.endswith('.bias'):
            # all bias not using weight decay
            no_decay_params.append(x)
        elif parameter_name.endswith('.gamma'):
            # bn weight bias not using weight decay, be carefully for now x not include BN
            no_decay_params.append(x)
        elif parameter_name.endswith('.beta'):
            # bn weight bias not using weight decay, be carefully for now x not include BN
            no_decay_params.append(x)
        else:
            decay_params.append(x)

    return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
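The grouping above, mimicked over made-up parameter names (no MindSpore needed): everything ending in .bias/.gamma/.beta lands in the zero-weight-decay group, the rest keep the optimizer default:

names = ['conv1.weight', 'conv1.bias', 'bn1.gamma', 'bn1.beta', 'fc.weight']
no_decay = [n for n in names if n.endswith(('.bias', '.gamma', '.beta'))]
decay = [n for n in names if n not in no_decay]
print([{'params': no_decay, 'weight_decay': 0.0}, {'params': decay}])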
model_zoo/official/cv/vgg16/src/utils/var_init.py
(new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Initialize.
"""
import math
from functools import reduce
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import initializer as init


def _calculate_gain(nonlinearity, param=None):
    r"""
    Return the recommended gain value for the given nonlinearity function.

    The values are as follows:
    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: the non-linear function
        param: optional parameter for the non-linear function

    Examples:
        >>> gain = calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
    """
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope ** 2))

    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))


def _assignment(arr, num):
    """Assign the value of `num` to `arr`."""
    if arr.shape == ():
        arr = arr.reshape((1))
        arr[:] = num
        arr = arr.reshape(())
    else:
        if isinstance(num, np.ndarray):
            arr[:] = num[:]
        else:
            arr[:] = num
    return arr


def _calculate_in_and_out(arr):
    """
    Calculate n_in and n_out.

    Args:
        arr (Array): Input array.

    Returns:
        Tuple, a tuple with two elements, the first element is `n_in` and the second element is `n_out`.
    """
    dim = len(arr.shape)
    if dim < 2:
        raise ValueError("If initialize data with xavier uniform, the dimension of data must greater than 1.")

    n_in = arr.shape[1]
    n_out = arr.shape[0]

    if dim > 2:
        counter = reduce(lambda x, y: x * y, arr.shape[2:])
        n_in *= counter
        n_out *= counter
    return n_in, n_out


def _select_fan(array, mode):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))

    fan_in, fan_out = _calculate_in_and_out(array)
    return fan_in if mode == 'fan_in' else fan_out


class KaimingInit(init.Initializer):
    r"""
    Base Class. Initialize the array with He kaiming algorithm.

    Args:
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function, recommended to use only with
            ``'relu'`` or ``'leaky_relu'`` (default).
    """
    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu'):
        super(KaimingInit, self).__init__()
        self.mode = mode
        self.gain = _calculate_gain(nonlinearity, a)

    def _initialize(self, arr):
        pass


class KaimingUniform(KaimingInit):
    r"""
    Initialize the array with He kaiming uniform algorithm. The resulting tensor will
    have values sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Input:
        arr (Array): The array to be assigned.

    Returns:
        Array, assigned array.

    Examples:
        >>> w = np.empty(3, 5)
        >>> KaimingUniform(w, mode='fan_in', nonlinearity='relu')
    """
    def _initialize(self, arr):
        fan = _select_fan(arr, self.mode)
        bound = math.sqrt(3.0) * self.gain / math.sqrt(fan)
        np.random.seed(0)
        data = np.random.uniform(-bound, bound, arr.shape)

        _assignment(arr, data)


class KaimingNormal(KaimingInit):
    r"""
    Initialize the array with He kaiming normal algorithm. The resulting tensor will
    have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Input:
        arr (Array): The array to be assigned.

    Returns:
        Array, assigned array.

    Examples:
        >>> w = np.empty(3, 5)
        >>> KaimingNormal(w, mode='fan_out', nonlinearity='relu')
    """
    def _initialize(self, arr):
        fan = _select_fan(arr, self.mode)
        std = self.gain / math.sqrt(fan)
        np.random.seed(0)
        data = np.random.normal(0, std, arr.shape)

        _assignment(arr, data)


def default_recurisive_init(custom_cell):
    """default_recurisive_init"""
    for _, cell in custom_cell.cells_and_names():
        if isinstance(cell, nn.Conv2d):
            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
                                                         cell.weight.default_input.shape,
                                                         cell.weight.default_input.dtype).to_tensor()
            if cell.bias is not None:
                fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
                bound = 1 / math.sqrt(fan_in)
                np.random.seed(0)
                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
                                                 cell.bias.default_input.dtype)
        elif isinstance(cell, nn.Dense):
            cell.weight.default_input = init.initializer(KaimingUniform(a=math.sqrt(5)),
                                                         cell.weight.default_input.shape,
                                                         cell.weight.default_input.dtype).to_tensor()
            if cell.bias is not None:
                fan_in, _ = _calculate_in_and_out(cell.weight.default_input.asnumpy())
                bound = 1 / math.sqrt(fan_in)
                np.random.seed(0)
                cell.bias.default_input = Tensor(np.random.uniform(-bound, bound, cell.bias.default_input.shape),
                                                 cell.bias.default_input.dtype)
        elif isinstance(cell, (nn.BatchNorm2d, nn.BatchNorm1d)):
            pass
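Numeric check of the KaimingUniform bound for a hypothetical 3x3 conv weight with 64 input channels: fan_in = 64 * 3 * 3 = 576, and the leaky_relu gain with a = sqrt(5) is sqrt(2 / (1 + 5)):

import math

fan_in = 64 * 3 * 3                               # as computed by _calculate_in_and_out
gain = math.sqrt(2.0 / (1 + math.sqrt(5) ** 2))   # _calculate_gain('leaky_relu', sqrt(5))
bound = math.sqrt(3.0) * gain / math.sqrt(fan_in)
print(round(bound, 6))  # 0.041667 -- weights drawn from U(-bound, bound)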
model_zoo/official/cv/vgg16/src/vgg.py
@@ -12,12 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""VGG."""
+"""
+Image classification.
+"""
+import math
 import mindspore.nn as nn
-from mindspore.common.initializer import initializer
 import mindspore.common.dtype as mstype
+from mindspore.common import initializer as init
+from mindspore.common.initializer import initializer
+from .utils.var_init import default_recurisive_init, KaimingNormal


-def _make_layer(base, batch_norm):
+def _make_layer(base, args, batch_norm):
     """Make stage network of VGG."""
     layers = []
     in_channels = 3
@@ -27,11 +33,14 @@ def _make_layer(base, batch_norm):
         else:
             weight_shape = (v, in_channels, 3, 3)
             weight = initializer('XavierUniform', shape=weight_shape, dtype=mstype.float32).to_tensor()
+            if args.dataset == "imagenet2012":
+                weight = 'normal'
             conv2d = nn.Conv2d(in_channels=in_channels,
                                out_channels=v,
                                kernel_size=3,
-                               padding=0,
-                               pad_mode='same',
+                               padding=args.padding,
+                               pad_mode=args.pad_mode,
+                               has_bias=args.has_bias,
                                weight_init=weight)
             if batch_norm:
                 layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU()]
@@ -59,17 +68,25 @@ class Vgg(nn.Cell):
     >>> num_classes=1000, batch_norm=False, batch_size=1)
     """

-    def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1):
+    def __init__(self, base, num_classes=1000, batch_norm=False, batch_size=1, args=None, phase="train"):
         super(Vgg, self).__init__()
         _ = batch_size
-        self.layers = _make_layer(base, batch_norm=batch_norm)
+        self.layers = _make_layer(base, args, batch_norm=batch_norm)
         self.flatten = nn.Flatten()
+        dropout_ratio = 0.5
+        if args.dataset == "cifar10" or phase == "test":
+            dropout_ratio = 1.0
         self.classifier = nn.SequentialCell([
             nn.Dense(512 * 7 * 7, 4096),
             nn.ReLU(),
+            nn.Dropout(dropout_ratio),
             nn.Dense(4096, 4096),
             nn.ReLU(),
+            nn.Dropout(dropout_ratio),
             nn.Dense(4096, num_classes)])
+        if args.dataset == "imagenet2012":
+            default_recurisive_init(self)
+            self.custom_init_weight()

     def construct(self, x):
         x = self.layers(x)
@@ -77,6 +94,25 @@ class Vgg(nn.Cell):
         x = self.classifier(x)
         return x

+    def custom_init_weight(self):
+        """
+        Init the weight of Conv2d and Dense in the net.
+        """
+        for _, cell in self.cells_and_names():
+            if isinstance(cell, nn.Conv2d):
+                cell.weight.default_input = init.initializer(
+                    KaimingNormal(a=math.sqrt(5), mode='fan_out', nonlinearity='relu'),
+                    cell.weight.default_input.shape, cell.weight.default_input.dtype).to_tensor()
+                if cell.bias is not None:
+                    cell.bias.default_input = init.initializer(
+                        'zeros', cell.bias.default_input.shape, cell.bias.default_input.dtype).to_tensor()
+            elif isinstance(cell, nn.Dense):
+                cell.weight.default_input = init.initializer(
+                    init.Normal(0.01), cell.weight.default_input.shape,
+                    cell.weight.default_input.dtype).to_tensor()
+                if cell.bias is not None:
+                    cell.bias.default_input = init.initializer(
+                        'zeros', cell.bias.default_input.shape, cell.bias.default_input.dtype).to_tensor()


 cfg = {
     '11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
@@ -86,12 +122,14 @@ cfg = {
 }


-def vgg16(num_classes=1000):
+def vgg16(num_classes=1000, args=None, phase="train"):
     """
     Get Vgg16 neural network with batch normalization.

     Args:
         num_classes (int): Class numbers. Default: 1000.
+        args(dict): param for net init.
+        phase(str): train or test mode.

     Returns:
         Cell, cell instance of Vgg16 neural network with batch normalization.
@@ -100,5 +138,5 @@ def vgg16(num_classes=1000):
     >>> vgg16(num_classes=1000)
     """
-    net = Vgg(cfg['16'], num_classes=num_classes, batch_norm=True)
+    net = Vgg(cfg['16'], num_classes=num_classes, args=args, batch_norm=True, phase=phase)
     return net
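Constructing the network now needs an args object carrying dataset/padding/pad_mode/has_bias (normally filled in from src.config by parse_args). A sketch of the imagenet2012 test-phase call, not a ready-to-run script:

from easydict import EasyDict as edict
from src.vgg import vgg16

# minimal stand-in for the parsed args; values mirror imagenet_cfg above
args = edict({'dataset': 'imagenet2012', 'padding': 1, 'pad_mode': 'pad', 'has_bias': True})
network = vgg16(num_classes=1000, args=args, phase="test")  # dropout_ratio forced to 1.0 in test phase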
model_zoo/official/cv/vgg16/src/warmup_cosine_annealing_lr.py
(new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
warm up cosine annealing learning rate.
"""
import math
import numpy as np

from .linear_warmup import linear_warmup_lr


def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
    """warm up cosine annealing learning rate."""
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)

    lr_each_step = []
    for i in range(total_steps):
        last_epoch = i // steps_per_epoch
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)
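Worked numbers for the post-warmup branch: with base_lr=0.1, eta_min=0 and T_max=150, the per-epoch lr traces half a cosine period from base_lr down to eta_min:

import math

base_lr, eta_min, T_max = 0.1, 0.0, 150
for epoch in (0, 75, 150):
    lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * epoch / T_max)) / 2
    print(epoch, round(lr, 4))
# 0 0.1    75 0.05    150 0.0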
model_zoo/official/cv/vgg16/src/warmup_step_lr.py
(new file, mode 100644)

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
warm up step learning rate.
"""
from collections import Counter
import numpy as np

from .linear_warmup import linear_warmup_lr


def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
    """Set learning rate."""
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    if warmup_steps != 0:
        inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
    else:
        inc_each_step = 0
    for i in range(total_steps):
        if i < warmup_steps:
            lr_value = float(lr_init) + inc_each_step * float(i)
        else:
            base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
            lr_value = float(lr_max) * base * base
            if lr_value < 0.0:
                lr_value = 0.0
        lr_each_step.append(lr_value)

    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]

    return learning_rate


def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
    """warmup_step_lr"""
    base_lr = lr
    warmup_init_lr = 0
    total_steps = int(max_epoch * steps_per_epoch)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    milestones = lr_epochs
    milestones_steps = []
    for milestone in milestones:
        milestones_step = milestone * steps_per_epoch
        milestones_steps.append(milestones_step)

    lr_each_step = []
    lr = base_lr
    milestones_steps_counter = Counter(milestones_steps)
    for i in range(total_steps):
        if i < warmup_steps:
            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
        else:
            lr = lr * gamma ** milestones_steps_counter[i]
        lr_each_step.append(lr)

    return np.array(lr_each_step).astype(np.float32)


def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
    return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)


def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
    lr_epochs = []
    for i in range(1, max_epoch):
        if i % epoch_size == 0:
            lr_epochs.append(i)
    return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
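Worked example of the milestone decay in warmup_step_lr (a sketch; assumes the repo layout above so the import resolves, and uses 1 step per epoch for brevity): with lr_epochs=[30, 60, 90, 120] and gamma=0.1 the lr drops tenfold at each milestone it reaches:

from src.warmup_step_lr import warmup_step_lr

lr = warmup_step_lr(0.01, [30, 60, 90, 120], steps_per_epoch=1,
                    warmup_epochs=0, max_epoch=70, gamma=0.1)
print(lr[0], lr[30], lr[60])  # 0.01, 0.001, 0.0001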
model_zoo/official/cv/vgg16/train.py
@@ -17,6 +17,8 @@
 python train.py --data_path=$DATA_HOME --device_id=$DEVICE_ID
 """
 import argparse
+import datetime
+import time
 import os
 import random
@@ -25,83 +27,264 @@ import numpy as np
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore import context
-from mindspore.communication.management import init
+from mindspore.communication.management import init, get_rank, get_group_size
 from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig
 from mindspore.train.model import Model, ParallelMode
 from mindspore.train.serialization import load_param_into_net, load_checkpoint
-from src.config import cifar_cfg as cfg
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from src.dataset import vgg_create_dataset
+from src.dataset import classification_dataset
+from src.crossentropy import CrossEntropy
+from src.warmup_step_lr import warmup_step_lr
+from src.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
+from src.warmup_step_lr import lr_steps
+from src.utils.logging import get_logger
+from src.utils.util import get_param_groups
 from src.vgg import vgg16

 random.seed(1)
 np.random.seed(1)


-def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
-    """Set learning rate."""
-    lr_each_step = []
-    total_steps = steps_per_epoch * total_epochs
-    warmup_steps = steps_per_epoch * warmup_epochs
-    if warmup_steps != 0:
-        inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
-    else:
-        inc_each_step = 0
-    for i in range(total_steps):
-        if i < warmup_steps:
-            lr_value = float(lr_init) + inc_each_step * float(i)
-        else:
-            base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
-            lr_value = float(lr_max) * base * base
-            if lr_value < 0.0:
-                lr_value = 0.0
-        lr_each_step.append(lr_value)
-    current_step = global_step
-    lr_each_step = np.array(lr_each_step).astype(np.float32)
-    learning_rate = lr_each_step[current_step:]
-    return learning_rate
+class ProgressMonitor(Callback):
+    """monitor loss and time"""
+    def __init__(self, args_param):
+        super(ProgressMonitor, self).__init__()
+        self.me_epoch_start_time = 0
+        self.me_epoch_start_step_num = 0
+        self.args = args_param
+        self.ckpt_history = []
+
+    def begin(self, run_context):
+        self.args.logger.info('start network train...')
+
+    def epoch_begin(self, run_context):
+        pass
+
+    def epoch_end(self, run_context):
+        """
+        Called after each epoch finished.
+
+        Args:
+            run_context (RunContext): Include some information of the model.
+        """
+        cb_params = run_context.original_args()
+        me_step = cb_params.cur_step_num - 1
+
+        real_epoch = me_step // self.args.steps_per_epoch
+        time_used = time.time() - self.me_epoch_start_time
+        fps_mean = self.args.per_batch_size * (me_step - self.me_epoch_start_step_num) * self.args.group_size / time_used
+        self.args.logger.info('epoch[{}], iter[{}], loss:{}, mean_fps:{:.2f}'
+                              'imgs/sec'.format(real_epoch, me_step, cb_params.net_outputs, fps_mean))
+
+        if self.args.rank_save_ckpt_flag:
+            import glob
+            ckpts = glob.glob(os.path.join(self.args.outputs_dir, '*.ckpt'))
+            for ckpt in ckpts:
+                ckpt_fn = os.path.basename(ckpt)
+                if not ckpt_fn.startswith('{}-'.format(self.args.rank)):
+                    continue
+                if ckpt in self.ckpt_history:
+                    continue
+                self.ckpt_history.append(ckpt)
+                self.args.logger.info('epoch[{}], iter[{}], loss:{}, ckpt:{},'
+                                      'ckpt_fn:{}'.format(real_epoch, me_step, cb_params.net_outputs, ckpt, ckpt_fn))

+        self.me_epoch_start_step_num = me_step
+        self.me_epoch_start_time = time.time()
+
+    def step_begin(self, run_context):
+        pass
+
+    def step_end(self, run_context, *me_args):
+        pass
+
+    def end(self, run_context):
+        self.args.logger.info('end network train...')
+
+
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Cifar10 classification')
-    parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
-                        help='device where the code will be implemented. (Default: Ascend)')
-    parser.add_argument('--data_path', type=str, default='./cifar', help='path where the dataset is saved')
-    parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
-    parser.add_argument('--pre_trained', type=str, default=None, help='the pretrained checkpoint file path.')
+def parse_args(cloud_args=None):
+    """parameters"""
+    parser = argparse.ArgumentParser('mindspore classification training')
+    parser.add_argument('--device_target', type=str, default='GPU', choices=['Ascend', 'GPU'],
+                        help='device where the code will be implemented. (Default: Ascend)')
+    parser.add_argument('--device_id', type=int, default=1, help='device id of GPU or Ascend. (Default: None)')
+    # dataset related
+    parser.add_argument('--dataset', type=str, choices=["cifar10", "imagenet2012"], default="cifar10")
+    parser.add_argument('--data_path', type=str, default='', help='train data dir')
+    # network related
+    parser.add_argument('--pre_trained', default='', type=str, help='model_path, local pretrained model to load')
+    parser.add_argument('--lr_gamma', type=float, default=0.1,
+                        help='decrease lr by a factor of exponential lr_scheduler')
+    parser.add_argument('--eta_min', type=float, default=0., help='eta_min in cosine_annealing scheduler')
+    parser.add_argument('--T_max', type=int, default=150, help='T-max in cosine_annealing scheduler')
+    # logging and checkpoint related
+    parser.add_argument('--log_interval', type=int, default=100, help='logging interval')
+    parser.add_argument('--ckpt_path', type=str, default='outputs/', help='checkpoint save location')
+    parser.add_argument('--ckpt_interval', type=int, default=5000, help='ckpt_interval')
+    parser.add_argument('--is_save_on_master', type=int, default=1, help='save ckpt on master or all rank')
+    # distributed related
+    parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
+    parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
+    parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
+
+    args_opt = parser.parse_args()
+    args_opt = merge_args(args_opt, cloud_args)
+
+    if args_opt.dataset == "cifar10":
+        from src.config import cifar_cfg as cfg
+    else:
+        from src.config import imagenet_cfg as cfg
+
+    args_opt.label_smooth = cfg.label_smooth
+    args_opt.label_smooth_factor = cfg.label_smooth_factor
+    args_opt.lr_scheduler = cfg.lr_scheduler
+    args_opt.loss_scale = cfg.loss_scale
+    args_opt.max_epoch = cfg.max_epoch
+    args_opt.warmup_epochs = cfg.warmup_epochs
+    args_opt.lr = cfg.lr
+    args_opt.lr_init = cfg.lr_init
+    args_opt.lr_max = cfg.lr_max
+    args_opt.momentum = cfg.momentum
+    args_opt.weight_decay = cfg.weight_decay
+    args_opt.per_batch_size = cfg.batch_size
+    args_opt.num_classes = cfg.num_classes
+    args_opt.buffer_size = cfg.buffer_size
+    args_opt.ckpt_save_max = cfg.keep_checkpoint_max
+    args_opt.pad_mode = cfg.pad_mode
+    args_opt.padding = cfg.padding
+    args_opt.has_bias = cfg.has_bias
+    args_opt.batch_norm = cfg.batch_norm
+    args_opt.lr_epochs = list(map(int, cfg.lr_epochs.split(',')))
+    args_opt.image_size = list(map(int, cfg.image_size.split(',')))
+
+    return args_opt
+
+
+def merge_args(args_opt, cloud_args):
+    """dictionary"""
+    args_dict = vars(args_opt)
+    if isinstance(cloud_args, dict):
+        for key_arg in cloud_args.keys():
+            val = cloud_args[key_arg]
+            if key_arg in args_dict and val:
+                arg_type = type(args_dict[key_arg])
+                if arg_type is not None:
+                    val = arg_type(val)
+                args_dict[key_arg] = val
+    return args_opt
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
-    args_opt = parser.parse_args()
-    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
-    context.set_context(device_id=args_opt.device_id)
-    device_num = int(os.environ.get("DEVICE_NUM", 1))
-    if device_num > 1:
-        context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                          mirror_mean=True)
-        init()
+    if args.is_distributed:
+        if args.device_target == "Ascend":
+            init()
+        elif args.device_target == "GPU":
+            init("nccl")
+        args.rank = get_rank()
+        args.group_size = get_group_size()
+        device_num = args.group_size
+        context.reset_auto_parallel_context()
+        context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          mirror_mean=True)
+    else:
+        context.set_context(device_id=args.device_id)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
+
+    # select for master rank save ckpt or all rank save, compatible for model parallel
+    args.rank_save_ckpt_flag = 0
+    if args.is_save_on_master:
+        if args.rank == 0:
+            args.rank_save_ckpt_flag = 1
+    else:
+        args.rank_save_ckpt_flag = 1
+
+    # logger
+    args.outputs_dir = os.path.join(args.ckpt_path,
+                                    datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
+    args.logger = get_logger(args.outputs_dir, args.rank)
+
+    if args.dataset == "cifar10":
+        dataset = vgg_create_dataset(args.data_path, args.image_size, args.per_batch_size,
+                                     args.rank, args.group_size)
+    else:
+        dataset = classification_dataset(args.data_path, args.image_size, args.per_batch_size,
+                                         args.rank, args.group_size)
-    dataset = vgg_create_dataset(args_opt.data_path, 1)
     batch_num = dataset.get_dataset_size()
+    args.steps_per_epoch = dataset.get_dataset_size()
+    args.logger.save_args(args)
+
+    # network
+    args.logger.important_info('start create network')
+
+    # get network and init
+    network = vgg16(args.num_classes, args)
-    net = vgg16(num_classes=cfg.num_classes)
+
     # pre_trained
-    if args_opt.pre_trained:
-        load_param_into_net(net, load_checkpoint(args_opt.pre_trained))
-    lr = lr_steps(0, lr_init=cfg.lr_init, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs,
-                  total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
-    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
-                   weight_decay=cfg.weight_decay)
-    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
-    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'}, amp_level="O2", keep_batchnorm_fp32=False,
-                  loss_scale_manager=None)
-    config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
-    time_cb = TimeMonitor(data_size=batch_num)
-    ckpoint_cb = ModelCheckpoint(prefix="train_vgg_cifar10", directory="./", config=config_ck)
-    loss_cb = LossMonitor()
-    model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
-    print("train success")
+    if args.pre_trained:
+        load_param_into_net(network, load_checkpoint(args.pre_trained))
+
+    # lr scheduler
+    if args.lr_scheduler == 'exponential':
+        lr = warmup_step_lr(args.lr,
+                            args.lr_epochs,
+                            args.steps_per_epoch,
+                            args.warmup_epochs,
+                            args.max_epoch,
+                            gamma=args.lr_gamma,
+                            )
+    elif args.lr_scheduler == 'cosine_annealing':
+        lr = warmup_cosine_annealing_lr(args.lr,
+                                        args.steps_per_epoch,
+                                        args.warmup_epochs,
+                                        args.max_epoch,
+                                        args.T_max,
+                                        args.eta_min)
+    elif args.lr_scheduler == 'step':
+        lr = lr_steps(0, lr_init=args.lr_init, lr_max=args.lr_max, warmup_epochs=args.warmup_epochs,
+                      total_epochs=args.max_epoch, steps_per_epoch=batch_num)
+    else:
+        raise NotImplementedError(args.lr_scheduler)
+
+    # optimizer
+    opt = Momentum(params=get_param_groups(network),
+                   learning_rate=Tensor(lr),
+                   momentum=args.momentum,
+                   weight_decay=args.weight_decay,
+                   loss_scale=args.loss_scale)
+
+    if args.dataset == "cifar10":
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
+        model = Model(network, loss_fn=loss, optimizer=opt, metrics={'acc'},
+                      amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
+    else:
+        if not args.label_smooth:
+            args.label_smooth_factor = 0.0
+        loss = CrossEntropy(smooth_factor=args.label_smooth_factor, num_classes=args.num_classes)
+
+        loss_scale_manager = FixedLossScaleManager(args.loss_scale, drop_overflow_update=False)
+        model = Model(network, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")
+
+    # checkpoint save
+    progress_cb = ProgressMonitor(args)
+    callbacks = [progress_cb,]
+    if args.rank_save_ckpt_flag:
+        ckpt_config = CheckpointConfig(save_checkpoint_steps=args.ckpt_interval * args.steps_per_epoch,
+                                       keep_checkpoint_max=args.ckpt_save_max)
+        ckpt_cb = ModelCheckpoint(config=ckpt_config,
+                                  directory=args.outputs_dir,
+                                  prefix='{}'.format(args.rank))
+        callbacks.append(ckpt_cb)
+
+    model.train(args.max_epoch, dataset, callbacks=callbacks)
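For reference, the mean_fps figure ProgressMonitor logs reduces to simple arithmetic: images processed since the epoch started, scaled by group_size, divided by wall-clock time. Toy numbers, not measured results:

per_batch_size, group_size = 32, 8   # hypothetical run
steps_this_epoch = 1000
time_used = 120.0                    # seconds for the epoch
fps_mean = per_batch_size * steps_this_epoch * group_size / time_used
print(fps_mean)                      # 2133.33... imgs/sec across all ranks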