Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_44025039
mindspore
提交
73806e0c
M
mindspore
项目概览
weixin_44025039
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
73806e0c
编写于
8月 01, 2020
作者:
V
VectorSL
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
gpu update resnet in modelzoo
上级
6c4ee3f3
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
18 addition
and
8 deletion
+18
-8
model_zoo/official/cv/resnet/README.md
model_zoo/official/cv/resnet/README.md
+1
-1
model_zoo/official/cv/resnet/eval.py
model_zoo/official/cv/resnet/eval.py
+4
-2
model_zoo/official/cv/resnet/train.py
model_zoo/official/cv/resnet/train.py
+13
-5
未找到文件。
model_zoo/official/cv/resnet/README.md
浏览文件 @
73806e0c
...
...
@@ -241,7 +241,7 @@ result: {'top_5_accuracy': 0.9429417413572343, 'top_1_accuracy': 0.7853513124199
### Running on GPU
```
# distributed training example
mpirun -n 8 python train.py --
-net=resnet50 --dataset=cifar10
-dataset_path=~/cifar-10-batches-bin --device_target="GPU" --run_distribute=True
mpirun -n 8 python train.py --
net=resnet50 --dataset=cifar10 -
-dataset_path=~/cifar-10-batches-bin --device_target="GPU" --run_distribute=True
# standalone training example
python train.py --net=resnet50 --dataset=cifar10 --dataset_path=~/cifar-10-batches-bin --device_target="GPU"
...
...
model_zoo/official/cv/resnet/eval.py
浏览文件 @
73806e0c
...
...
@@ -54,8 +54,10 @@ if __name__ == '__main__':
target
=
args_opt
.
device_target
# init context
device_id
=
int
(
os
.
getenv
(
'DEVICE_ID'
))
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
target
,
save_graphs
=
False
,
device_id
=
device_id
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
target
,
save_graphs
=
False
)
if
target
!=
"GPU"
:
device_id
=
int
(
os
.
getenv
(
'DEVICE_ID'
))
context
.
set_context
(
device_id
=
device_id
)
# create dataset
dataset
=
create_dataset
(
dataset_path
=
args_opt
.
dataset_path
,
do_train
=
False
,
batch_size
=
config
.
batch_size
,
...
...
model_zoo/official/cv/resnet/train.py
浏览文件 @
73806e0c
...
...
@@ -143,13 +143,21 @@ if __name__ == '__main__':
amp_level
=
"O2"
,
keep_batchnorm_fp32
=
False
)
else
:
# GPU target
loss
=
SoftmaxCrossEntropyWithLogits
(
sparse
=
True
,
reduction
=
"mean"
,
is_grad
=
False
,
smooth_factor
=
config
.
label_smooth_factor
,
num_classes
=
config
.
class_num
)
if
args_opt
.
dataset
==
"imagenet2012"
:
if
not
config
.
use_label_smooth
:
config
.
label_smooth_factor
=
0.0
loss
=
SoftmaxCrossEntropyWithLogits
(
sparse
=
True
,
reduction
=
"mean"
,
is_grad
=
False
,
smooth_factor
=
config
.
label_smooth_factor
,
num_classes
=
config
.
class_num
)
else
:
loss
=
SoftmaxCrossEntropyWithLogits
(
sparse
=
True
,
reduction
=
"mean"
,
is_grad
=
False
,
num_classes
=
config
.
class_num
)
## fp32 training
opt
=
Momentum
(
filter
(
lambda
x
:
x
.
requires_grad
,
net
.
get_parameters
()),
lr
,
config
.
momentum
,
config
.
weight_decay
)
model
=
Model
(
net
,
loss_fn
=
loss
,
optimizer
=
opt
,
metrics
=
{
'acc'
})
##Mixed precision
#model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
# amp_level="O2", keep_batchnorm_fp32=True)
# # Mixed precision
# loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
# opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, config.weight_decay, config.loss_scale)
# model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'}, amp_level="O2")
# define callbacks
time_cb
=
TimeMonitor
(
data_size
=
step_size
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录