Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
book
提交
17f96e95
B
book
项目概览
MindSpore
/
book
通知
3
Star
1
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
B
book
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
17f96e95
编写于
8月 19, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 19, 2020
浏览文件
操作
浏览文件
下载
差异文件
!17 change bool argument in argparse
Merge pull request !17 from panbingao/master
上级
1e07ba05
6bbe8821
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
9 addition
and
7 deletion
+9
-7
chapter05/resnet/resnet_cifar.py
chapter05/resnet/resnet_cifar.py
+9
-7
未找到文件。
chapter05/resnet/resnet_cifar.py
浏览文件 @
17f96e95
...
@@ -21,6 +21,7 @@ import os
...
@@ -21,6 +21,7 @@ import os
import
random
import
random
import
argparse
import
argparse
import
numpy
as
np
import
numpy
as
np
from
resnet
import
resnet50
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
import
mindspore.common.dtype
as
mstype
import
mindspore.common.dtype
as
mstype
import
mindspore.ops.functional
as
F
import
mindspore.ops.functional
as
F
...
@@ -37,7 +38,6 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
...
@@ -37,7 +38,6 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from
mindspore.communication.management
import
init
from
mindspore.communication.management
import
init
from
mindspore.parallel._auto_parallel_context
import
auto_parallel_context
from
mindspore.parallel._auto_parallel_context
import
auto_parallel_context
from
mindspore.nn.loss
import
SoftmaxCrossEntropyWithLogits
from
mindspore.nn.loss
import
SoftmaxCrossEntropyWithLogits
from
resnet
import
resnet50
random
.
seed
(
1
)
random
.
seed
(
1
)
np
.
random
.
seed
(
1
)
np
.
random
.
seed
(
1
)
de
.
config
.
set_seed
(
1
)
de
.
config
.
set_seed
(
1
)
...
@@ -45,10 +45,10 @@ de.config.set_seed(1)
...
@@ -45,10 +45,10 @@ de.config.set_seed(1)
parser
=
argparse
.
ArgumentParser
(
description
=
'Image classification.'
)
parser
=
argparse
.
ArgumentParser
(
description
=
'Image classification.'
)
parser
.
add_argument
(
'--device_target'
,
type
=
str
,
default
=
"Ascend"
,
choices
=
[
'Ascend'
,
'GPU'
],
parser
.
add_argument
(
'--device_target'
,
type
=
str
,
default
=
"Ascend"
,
choices
=
[
'Ascend'
,
'GPU'
],
help
=
'device where the code will be implemented (default: Ascend)'
)
help
=
'device where the code will be implemented (default: Ascend)'
)
parser
.
add_argument
(
'--run_distribute'
,
type
=
bool
,
default
=
False
,
help
=
'Run distribute
i
.'
)
parser
.
add_argument
(
'--run_distribute'
,
type
=
bool
,
default
=
False
,
help
=
'Run distribute.'
)
parser
.
add_argument
(
'--device_num'
,
type
=
int
,
default
=
1
,
help
=
'Device num.'
)
parser
.
add_argument
(
'--device_num'
,
type
=
int
,
default
=
1
,
help
=
'Device num.'
)
parser
.
add_argument
(
'--
do_train'
,
type
=
bool
,
default
=
True
,
help
=
'Do train or not.'
)
parser
.
add_argument
(
'--
mode'
,
type
=
str
,
default
=
"train"
,
choices
=
[
'train'
,
'test'
],
parser
.
add_argument
(
'--do_eval'
,
type
=
bool
,
default
=
False
,
help
=
'Do eval or not.
'
)
help
=
'implement phase, set to train or test
'
)
parser
.
add_argument
(
'--epoch_size'
,
type
=
int
,
default
=
1
,
help
=
'Epoch size.'
)
parser
.
add_argument
(
'--epoch_size'
,
type
=
int
,
default
=
1
,
help
=
'Epoch size.'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
32
,
help
=
'Batch size.'
)
parser
.
add_argument
(
'--batch_size'
,
type
=
int
,
default
=
32
,
help
=
'Batch size.'
)
parser
.
add_argument
(
'--num_classes'
,
type
=
int
,
default
=
10
,
help
=
'Num classes.'
)
parser
.
add_argument
(
'--num_classes'
,
type
=
int
,
default
=
10
,
help
=
'Num classes.'
)
...
@@ -112,7 +112,7 @@ def create_dataset(repeat_num=1, training=True):
...
@@ -112,7 +112,7 @@ def create_dataset(repeat_num=1, training=True):
return
ds
return
ds
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
if
not
args_opt
.
do_eval
and
args_opt
.
run_distribute
:
if
args_opt
.
mode
==
'train'
and
args_opt
.
run_distribute
:
context
.
set_auto_parallel_context
(
device_num
=
args_opt
.
device_num
,
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
context
.
set_auto_parallel_context
(
device_num
=
args_opt
.
device_num
,
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
auto_parallel_context
().
set_all_reduce_fusion_split_indices
([
140
])
auto_parallel_context
().
set_all_reduce_fusion_split_indices
([
140
])
init
()
init
()
...
@@ -124,7 +124,8 @@ if __name__ == '__main__':
...
@@ -124,7 +124,8 @@ if __name__ == '__main__':
model
=
Model
(
net
,
loss_fn
=
ls
,
optimizer
=
opt
,
metrics
=
{
'acc'
})
model
=
Model
(
net
,
loss_fn
=
ls
,
optimizer
=
opt
,
metrics
=
{
'acc'
})
if
args_opt
.
do_train
:
if
args_opt
.
mode
==
'train'
:
# train
print
(
"============== Starting Training =============="
)
dataset
=
create_dataset
()
dataset
=
create_dataset
()
batch_num
=
dataset
.
get_dataset_size
()
batch_num
=
dataset
.
get_dataset_size
()
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
batch_num
,
keep_checkpoint_max
=
10
)
config_ck
=
CheckpointConfig
(
save_checkpoint_steps
=
batch_num
,
keep_checkpoint_max
=
10
)
...
@@ -133,7 +134,8 @@ if __name__ == '__main__':
...
@@ -133,7 +134,8 @@ if __name__ == '__main__':
loss_cb
=
LossMonitor
()
loss_cb
=
LossMonitor
()
model
.
train
(
epoch_size
,
dataset
,
callbacks
=
[
ckpoint_cb
,
loss_cb
])
model
.
train
(
epoch_size
,
dataset
,
callbacks
=
[
ckpoint_cb
,
loss_cb
])
if
args_opt
.
do_eval
:
if
args_opt
.
mode
==
'test'
:
# test
print
(
"============== Starting Testing =============="
)
if
args_opt
.
checkpoint_path
:
if
args_opt
.
checkpoint_path
:
param_dict
=
load_checkpoint
(
args_opt
.
checkpoint_path
)
param_dict
=
load_checkpoint
(
args_opt
.
checkpoint_path
)
load_param_into_net
(
net
,
param_dict
)
load_param_into_net
(
net
,
param_dict
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录