提交 ab5c5ae6 编写于 作者: W wangnan39@huawei.com

fix import error in docs that use on the cloud

上级 af9f6931
......@@ -89,7 +89,7 @@ ModelArts使用对象存储服务(Object Storage Service,简称OBS)进行
1. 在ModelArts运行的脚本必须配置`data_url``train_url`,分别对应数据存储路径(OBS路径)和训练输出路径(OBS路径)。
``` python
import parser
import argparse
parser = argparse.ArgumentParser(description='ResNet-50 train.')
parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
......@@ -160,6 +160,8 @@ MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing
```python
import os
from mindspore import context
from mindspore.train.model import ParallelMode
device_num = int(os.getenv('RANK_SIZE'))
if device_num > 1:
......@@ -176,6 +178,7 @@ MindSpore暂时没有提供直接访问OBS数据的接口,需要通过MoXing
``` python
import os
import argparse
from mindspore import context
from mindspore.train.model import ParallelMode
import mindspore.dataset.engine as de
......@@ -194,8 +197,8 @@ def create_dataset(dataset_path):
def resnet50_train(args_opt):
if device_num > 1:
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
train_dataset = create_dataset(local_data_path)
if __name__ == '__main__':
......@@ -212,10 +215,12 @@ if __name__ == '__main__':
``` python
import os
import argparse
from mindspore import context
from mindspore.train.model import ParallelMode
import mindspore.dataset.engine as de
# adapt to cloud: used for downloading data
import moxing as mox
device_id = int(os.getenv('DEVICE_ID'))
......@@ -230,19 +235,17 @@ def create_dataset(dataset_path):
return ds
def resnet50_train(args_opt):
epoch_size = args_opt.epoch_size
# define local data path
# adapt to cloud: define local data path
local_data_path = '/cache/data'
context.set_context(mode=context.GRAPH_MODE)
if device_num > 1:
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
# define distributed local data path
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
# adapt to cloud: define distributed local data path
local_data_path = os.path.join(local_data_path, str(device_id))
# data download
# adapt to cloud: download data from obs to local location
print('Download data.')
mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=local_data_path)
......@@ -250,7 +253,9 @@ def resnet50_train(args_opt):
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='ResNet-50 train.')
# adapt to cloud: get obs data path
parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
# adapt to cloud: get obs output path
parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
parser.add_argument('--epoch_size', type=int, default=90, help='Train epoch size.')
args_opt, unknown = parser.parse_known_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册