Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
docs
提交
ab5c5ae6
D
docs
项目概览
MindSpore
/
docs
通知
5
Star
3
Fork
2
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
docs
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ab5c5ae6
编写于
5月 12, 2020
作者:
W
wangnan39@huawei.com
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix import error in docs that use on the cloud
上级
af9f6931
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
15 addition
and
10 deletion
+15
-10
tutorials/source_zh_cn/advanced_use/use_on_the_cloud.md
tutorials/source_zh_cn/advanced_use/use_on_the_cloud.md
+15
-10
未找到文件。
tutorials/source_zh_cn/advanced_use/use_on_the_cloud.md
浏览文件 @
ab5c5ae6
...
...
@@ -89,7 +89,7 @@ ModelArts使用对象存储服务(Object Storage Service,简称OBS)进行
1.
在ModelArts运行的脚本必须配置
`data_url`
和
`train_url`
,分别对应数据存储路径(OBS路径)和训练输出路径(OBS路径)。
``` python
import
parser
import
argparse
parser = argparse.ArgumentParser(description='ResNet-50 train.')
parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
...
...
@@ -160,6 +160,8 @@ MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing
```python
import os
from mindspore import context
from mindspore.train.model import ParallelMode
device_num = int(os.getenv('RANK_SIZE'))
if device_num > 1:
...
...
@@ -176,6 +178,7 @@ MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing
```
python
import
os
import
argparse
from
mindspore
import
context
from
mindspore.train.model
import
ParallelMode
import
mindspore.dataset.engine
as
de
...
...
@@ -194,8 +197,8 @@ def create_dataset(dataset_path):
def
resnet50_train
(
args_opt
):
if
device_num
>
1
:
context
.
set_auto_parallel_context
(
device_num
=
device_num
,
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
train_dataset
=
create_dataset
(
local_data_path
)
if
__name__
==
'__main__'
:
...
...
@@ -212,10 +215,12 @@ if __name__ == '__main__':
```
python
import
os
import
argparse
from
mindspore
import
context
from
mindspore.train.model
import
ParallelMode
import
mindspore.dataset.engine
as
de
# adapt to cloud: used for downloading data
import
moxing
as
mox
device_id
=
int
(
os
.
getenv
(
'DEVICE_ID'
))
...
...
@@ -230,19 +235,17 @@ def create_dataset(dataset_path):
return
ds
def
resnet50_train
(
args_opt
):
epoch_size
=
args_opt
.
epoch_size
# define local data path
# adapt to cloud: define local data path
local_data_path
=
'/cache/data'
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
if
device_num
>
1
:
context
.
set_auto_parallel_context
(
device_num
=
device_num
,
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
# define distributed local data path
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
#
adapt to cloud:
define distributed local data path
local_data_path
=
os
.
path
.
join
(
local_data_path
,
str
(
device_id
))
#
data download
#
adapt to cloud: download data from obs to local location
print
(
'Download data.'
)
mox
.
file
.
copy_parallel
(
src_url
=
args_opt
.
data_url
,
dst_url
=
local_data_path
)
...
...
@@ -250,7 +253,9 @@ def resnet50_train(args_opt):
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'ResNet-50 train.'
)
# adapt to cloud: get obs data path
parser
.
add_argument
(
'--data_url'
,
required
=
True
,
default
=
None
,
help
=
'Location of data.'
)
# adapt to cloud: get obs output path
parser
.
add_argument
(
'--train_url'
,
required
=
True
,
default
=
None
,
help
=
'Location of training outputs.'
)
parser
.
add_argument
(
'--epoch_size'
,
type
=
int
,
default
=
90
,
help
=
'Train epoch size.'
)
args_opt
,
unknown
=
parser
.
parse_known_args
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录