Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e08e4088
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e08e4088
编写于
9月 08, 2020
作者:
C
chujinjin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix model zoo error for pynative
上级
8a71db07
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
45 addition
and
22 deletion
+45
-22
model_zoo/official/cv/mobilenetv2/src/args.py
model_zoo/official/cv/mobilenetv2/src/args.py
+2
-1
model_zoo/official/cv/mobilenetv2/src/config.py
model_zoo/official/cv/mobilenetv2/src/config.py
+1
-0
model_zoo/official/cv/mobilenetv2/src/dataset.py
model_zoo/official/cv/mobilenetv2/src/dataset.py
+6
-3
model_zoo/official/cv/mobilenetv2/src/utils.py
model_zoo/official/cv/mobilenetv2/src/utils.py
+13
-6
model_zoo/official/cv/mobilenetv3/src/dataset.py
model_zoo/official/cv/mobilenetv3/src/dataset.py
+9
-5
model_zoo/official/cv/mobilenetv3/train.py
model_zoo/official/cv/mobilenetv3/train.py
+13
-6
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
...zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
+1
-1
未找到文件。
model_zoo/official/cv/mobilenetv2/src/args.py
浏览文件 @
e08e4088
...
...
@@ -14,7 +14,7 @@
# ============================================================================
import
argparse
import
ast
def
launch_parse_args
():
...
...
@@ -43,6 +43,7 @@ def train_parse_args():
help
=
'run platform, only support CPU, GPU and Ascend'
)
train_parser
.
add_argument
(
'--pretrain_ckpt'
,
type
=
str
,
default
=
None
,
help
=
'Pretrained checkpoint path
\
for fine tune or incremental learning'
)
train_parser
.
add_argument
(
'--run_distribute'
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'Run distribute'
)
train_parser
.
add_argument
(
'--train_method'
,
type
=
str
,
choices
=
(
"train"
,
"fine_tune"
,
"incremental_learn"
),
\
help
=
"
\"
fine_tune
\"
or
\"
incremental_learn
\"
if to fine tune the net after loading the ckpt,
\"
train
\"
to
\
train from initialization model"
)
...
...
model_zoo/official/cv/mobilenetv2/src/config.py
浏览文件 @
e08e4088
...
...
@@ -59,6 +59,7 @@ def set_config(args):
"save_checkpoint_path"
:
"./checkpoint"
,
"platform"
:
args
.
platform
,
"ccl"
:
"nccl"
,
"run_distribute"
:
args
.
run_distribute
})
config_ascend
=
ed
({
"num_classes"
:
1000
,
...
...
model_zoo/official/cv/mobilenetv2/src/dataset.py
浏览文件 @
e08e4088
...
...
@@ -51,9 +51,12 @@ def create_dataset(dataset_path, do_train, config, repeat_num=1):
num_shards
=
rank_size
,
shard_id
=
rank_id
)
elif
config
.
platform
==
"GPU"
:
if
do_train
:
from
mindspore.communication.management
import
get_rank
,
get_group_size
ds
=
de
.
ImageFolderDataset
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
True
,
num_shards
=
get_group_size
(),
shard_id
=
get_rank
())
if
config
.
run_distribute
:
from
mindspore.communication.management
import
get_rank
,
get_group_size
ds
=
de
.
ImageFolderDataset
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
True
,
num_shards
=
get_group_size
(),
shard_id
=
get_rank
())
else
:
ds
=
de
.
ImageFolderDataset
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
True
)
else
:
ds
=
de
.
ImageFolderDataset
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
True
)
elif
config
.
platform
==
"CPU"
:
...
...
model_zoo/official/cv/mobilenetv2/src/utils.py
浏览文件 @
e08e4088
...
...
@@ -22,6 +22,7 @@ from mindspore.communication.management import get_rank, init, get_group_size
from
src.models
import
Monitor
def
switch_precision
(
net
,
data_type
,
config
):
if
config
.
platform
==
"Ascend"
:
net
.
to_float
(
data_type
)
...
...
@@ -29,17 +30,18 @@ def switch_precision(net, data_type, config):
if
isinstance
(
cell
,
nn
.
Dense
):
cell
.
to_float
(
mstype
.
float32
)
def
context_device_init
(
config
):
def
context_device_init
(
config
):
if
config
.
platform
==
"CPU"
:
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
config
.
platform
,
save_graphs
=
False
)
elif
config
.
platform
==
"GPU"
:
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
config
.
platform
,
save_graphs
=
False
)
init
(
"nccl"
)
context
.
set_auto_parallel_context
(
device_num
=
get_group_size
(),
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
gradients_mean
=
True
)
if
config
.
run_distribute
:
init
(
"nccl"
)
context
.
set_auto_parallel_context
(
device_num
=
get_group_size
(),
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
gradients_mean
=
True
)
elif
config
.
platform
==
"Ascend"
:
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
config
.
platform
,
device_id
=
config
.
device_id
,
...
...
@@ -53,6 +55,7 @@ def context_device_init(config):
else
:
raise
ValueError
(
"Only support CPU, GPU and Ascend."
)
def
set_context
(
config
):
if
config
.
platform
==
"CPU"
:
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
config
.
platform
,
...
...
@@ -64,6 +67,7 @@ def set_context(config):
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
config
.
platform
,
save_graphs
=
False
)
def
config_ckpoint
(
config
,
lr
,
step_size
):
cb
=
None
if
config
.
platform
in
(
"CPU"
,
"GPU"
)
or
config
.
rank_id
==
0
:
...
...
@@ -75,7 +79,10 @@ def config_ckpoint(config, lr, step_size):
ckpt_save_dir
=
config
.
save_checkpoint_path
if
config
.
platform
==
"GPU"
:
ckpt_save_dir
+=
"ckpt_"
+
str
(
get_rank
())
+
"/"
if
config
.
run_distribute
:
ckpt_save_dir
+=
"ckpt_"
+
str
(
get_rank
())
+
"/"
else
:
ckpt_save_dir
+=
"ckpt_"
+
"/"
ckpt_cb
=
ModelCheckpoint
(
prefix
=
"mobilenetV2"
,
directory
=
ckpt_save_dir
,
config
=
config_ck
)
cb
+=
[
ckpt_cb
]
...
...
model_zoo/official/cv/mobilenetv3/src/dataset.py
浏览文件 @
e08e4088
...
...
@@ -21,7 +21,7 @@ import mindspore.dataset.vision.c_transforms as C
import
mindspore.dataset.transforms.c_transforms
as
C2
def
create_dataset
(
dataset_path
,
do_train
,
config
,
device_target
,
repeat_num
=
1
,
batch_size
=
32
):
def
create_dataset
(
dataset_path
,
do_train
,
config
,
device_target
,
repeat_num
=
1
,
batch_size
=
32
,
run_distribute
=
False
):
"""
create a train or eval dataset
...
...
@@ -36,9 +36,12 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
"""
if
device_target
==
"GPU"
:
if
do_train
:
from
mindspore.communication.management
import
get_rank
,
get_group_size
ds
=
de
.
ImageFolderDataset
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
True
,
num_shards
=
get_group_size
(),
shard_id
=
get_rank
())
if
run_distribute
:
from
mindspore.communication.management
import
get_rank
,
get_group_size
ds
=
de
.
ImageFolderDataset
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
True
,
num_shards
=
get_group_size
(),
shard_id
=
get_rank
())
else
:
ds
=
de
.
ImageFolderDataset
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
True
)
else
:
ds
=
de
.
ImageFolderDataset
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
True
)
else
:
...
...
@@ -56,7 +59,8 @@ def create_dataset(dataset_path, do_train, config, device_target, repeat_num=1,
resize_op
=
C
.
Resize
(
256
)
center_crop
=
C
.
CenterCrop
(
resize_width
)
rescale_op
=
C
.
RandomColorAdjust
(
brightness
=
0.4
,
contrast
=
0.4
,
saturation
=
0.4
)
normalize_op
=
C
.
Normalize
(
mean
=
[
0.485
*
255
,
0.456
*
255
,
0.406
*
255
],
std
=
[
0.229
*
255
,
0.224
*
255
,
0.225
*
255
])
normalize_op
=
C
.
Normalize
(
mean
=
[
0.485
*
255
,
0.456
*
255
,
0.406
*
255
],
std
=
[
0.229
*
255
,
0.224
*
255
,
0.225
*
255
])
change_swap_op
=
C
.
HWC2CHW
()
if
do_train
:
...
...
model_zoo/official/cv/mobilenetv3/train.py
浏览文件 @
e08e4088
...
...
@@ -16,6 +16,7 @@
import
time
import
argparse
import
ast
import
numpy
as
np
from
mindspore
import
context
...
...
@@ -46,16 +47,18 @@ parser = argparse.ArgumentParser(description='Image classification')
parser
.
add_argument
(
'--dataset_path'
,
type
=
str
,
default
=
None
,
help
=
'Dataset path'
)
parser
.
add_argument
(
'--pre_trained'
,
type
=
str
,
default
=
None
,
help
=
'Pretrained checkpoint path'
)
parser
.
add_argument
(
'--device_target'
,
type
=
str
,
default
=
"GPU"
,
help
=
'run device_target'
)
parser
.
add_argument
(
'--run_distribute'
,
type
=
ast
.
literal_eval
,
default
=
True
,
help
=
'Run distribute'
)
args_opt
=
parser
.
parse_args
()
if
args_opt
.
device_target
==
"GPU"
:
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"GPU"
,
save_graphs
=
False
)
init
()
context
.
set_auto_parallel_context
(
device_num
=
get_group_size
(),
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
gradients_mean
=
True
)
if
args_opt
.
run_distribute
:
init
()
context
.
set_auto_parallel_context
(
device_num
=
get_group_size
(),
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
gradients_mean
=
True
)
else
:
raise
ValueError
(
"Unsupported device_target."
)
...
...
@@ -168,7 +171,8 @@ if __name__ == '__main__':
config
=
config_gpu
,
device_target
=
args_opt
.
device_target
,
repeat_num
=
1
,
batch_size
=
config_gpu
.
batch_size
)
batch_size
=
config_gpu
.
batch_size
,
run_distribute
=
args_opt
.
run_distribute
)
step_size
=
dataset
.
get_dataset_size
()
# resume
if
args_opt
.
pre_trained
:
...
...
@@ -191,7 +195,10 @@ if __name__ == '__main__':
loss_scale_manager
=
loss_scale
)
cb
=
[
Monitor
(
lr_init
=
lr
.
asnumpy
())]
ckpt_save_dir
=
config_gpu
.
save_checkpoint_path
+
"ckpt_"
+
str
(
get_rank
())
+
"/"
if
args_opt
.
run_distribute
:
ckpt_save_dir
=
config_gpu
.
save_checkpoint_path
+
"ckpt_"
+
str
(
get_rank
())
+
"/"
else
:
ckpt_save_dir
=
config_gpu
.
save_checkpoint_path
+
"ckpt_"
+
"/"
if
config_gpu
.
save_checkpoint
:
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
config_gpu
.
save_checkpoint_epochs
*
step_size
,
keep_checkpoint_max
=
config_gpu
.
keep_checkpoint_max
)
...
...
model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
浏览文件 @
e08e4088
...
...
@@ -399,6 +399,6 @@ class PredictWithSigmoid(nn.Cell):
self
.
sigmoid
=
P
.
Sigmoid
()
def
construct
(
self
,
batch_ids
,
batch_wts
,
labels
):
logits
,
_
,
_
,
=
self
.
network
(
batch_ids
,
batch_wts
)
logits
,
_
,
=
self
.
network
(
batch_ids
,
batch_wts
)
pred_probs
=
self
.
sigmoid
(
logits
)
return
logits
,
pred_probs
,
labels
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录