Commit 219c1314 authored by mindspore-ci-bot, committed by Gitee

!16 add tutorial docs for using MindSpore on the cloud

Merge pull request !16 from WangNan39/add_resnet50_docs
# Using MindSpore on the Cloud

<!-- TOC -->

- [Using MindSpore on the Cloud](#using-mindspore-on-the-cloud)
    - [Overview](#overview)
    - [Preparation](#preparation)
        - [Preparing to Use ModelArts](#preparing-to-use-modelarts)
        - [Obtaining Ascend AI Processor Resources on the Cloud](#obtaining-ascend-ai-processor-resources-on-the-cloud)
        - [Preparing Data](#preparing-data)
        - [Preparing the Training Script](#preparing-the-training-script)
    - [Running a MindSpore Script on ModelArts with Simple Adaptation](#running-a-mindspore-script-on-modelarts-with-simple-adaptation)
        - [Script Arguments](#script-arguments)
        - [Adapting to OBS Data](#adapting-to-obs-data)
        - [Obtaining Environment Variables](#obtaining-environment-variables)
        - [Example Code](#example-code)
    - [Creating a Training Job](#creating-a-training-job)
        - [Entering the ModelArts Console](#entering-the-modelarts-console)
        - [Creating a Training Job with a Frequently-used Framework](#creating-a-training-job-with-a-frequently-used-framework)
        - [Creating a Training Job Using MindSpore as the Framework](#creating-a-training-job-using-mindspore-as-the-framework)
    - [Viewing Run Results](#viewing-run-results)

<!-- /TOC -->
## Overview

ModelArts is a one-stop AI development platform for developers provided by HUAWEI CLOUD. It integrates a pool of Ascend AI Processor resources, on which users can experience MindSpore. This tutorial uses ResNet-50 as an example to briefly show how to complete a training task with MindSpore on ModelArts.
## Preparation

### Preparing to Use ModelArts

Follow the "Preparation" part of the ModelArts tutorials to register an account, configure ModelArts, and create an OBS bucket.

> ModelArts tutorials: <https://support.huaweicloud.com/wtsnew-modelarts/index.html>. The page provides a rich set of ModelArts tutorials; refer to the "Preparation" part to complete the ModelArts setup.

### Obtaining Ascend AI Processor Resources on the Cloud

Make sure you are qualified for the ModelArts Ascend AI Processor trial and have an approved cloud trial account. If you do not have the trial qualification yet, you can apply for it by following the guide at <https://www.mindspore.cn/install>.
### Preparing Data

ModelArts uses the Object Storage Service (OBS) to store data, so the dataset must be uploaded to OBS before the training task starts. This example uses the CIFAR-10 dataset in binary format.

1. Download the CIFAR-10 dataset, decompress it, and upload it to an OBS bucket with the following directory structure (a programmatic upload sketch using MoXing follows this list).
```
└─Object storage/ms-dataset/cifar-10
├─train
│ data_batch_1.bin
│ data_batch_2.bin
│ data_batch_3.bin
│ data_batch_4.bin
│ data_batch_5.bin
└─eval
test_batch.bin
```
> CIFAR-10 dataset download page: <http://www.cs.toronto.edu/~kriz/cifar.html>. The page provides three download links; this example uses the CIFAR-10 binary version.

2. To help users experience MindSpore quickly, a public OBS directory (Object storage/ms-dataset/cifar-10) is preloaded with the CIFAR-10 data and can be used directly.
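If you prefer to upload the extracted dataset programmatically instead of through the OBS console, a minimal sketch using the MoXing `copy_parallel` interface (introduced in the adaptation section below) might look like the following; the local path and the bucket name `your-bucket` are placeholders:

```python
# Hypothetical sketch: upload a locally extracted CIFAR-10 directory to your own OBS bucket.
# MoXing is preinstalled in ModelArts environments; "your-bucket" is a placeholder name.
import moxing as mox

mox.file.copy_parallel(src_url='/local/path/cifar-10',
                       dst_url='s3://your-bucket/cifar-10')
```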
### Preparing the Training Script

Create your own OBS bucket, create a code directory in the bucket, and upload all scripts in the following directory to the code directory:

> <https://gitee.com/mindspore/docs/tree/master/tutorials/tutorial_code/sample_for_cloud/>. The scripts train a ResNet-50 network on the CIFAR-10 dataset and validate the accuracy after training. They can run a training task on ModelArts with either the `1*Ascend` or the `8*Ascend` resource flavor.

To simplify creating the training job later, also create a training output directory and a log output directory in advance. The directory structure used in this example is as follows:
```
└─Object storage/resnet50-train
├─resnet50_cifar10_train
│ dataset.py
│ resnet50_train.py
├─output
└─log
```
## Running a MindSpore Script on ModelArts with Simple Adaptation

To run your own MindSpore script or other MindSpore sample code on ModelArts, you can adapt the MindSpore code as described in this section. If you only want to quickly experience training ResNet-50 on CIFAR-10, you can skip this section.
### Script Arguments

1. Two fixed arguments:
``` python
import argparse

parser = argparse.ArgumentParser(description='ResNet-50 train.')
parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
```
    `data_url` and `train_url` are two arguments that must be passed when a training task runs on ModelArts. They correspond to the data storage path (an OBS path) and the training output path (an OBS path), respectively.

2. The ModelArts console also supports passing values to the other arguments of the script, as described in detail in the "Creating a Training Job" section below (a combined parsing sketch follows this list).
``` python
parser.add_argument('--epoch_size', type=int, default=90, help='Train epoch size.')
```
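For illustration, the two pieces above can be combined into one small parsing helper. `parse_args` is a hypothetical helper name introduced here, and `parse_known_args` is used (as in the sample scripts) so that any extra arguments injected by the platform do not abort parsing:

```python
import argparse

def parse_args():
    """Parse the arguments that ModelArts passes to the training script (illustrative helper)."""
    parser = argparse.ArgumentParser(description='ResNet-50 train.')
    # Required OBS paths injected by ModelArts.
    parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
    parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
    # Optional argument configured on the "Running Parameters" panel.
    parser.add_argument('--epoch_size', type=int, default=90, help='Train epoch size.')
    # parse_known_args silently ignores any additional platform-injected arguments.
    args, _unknown = parser.parse_known_args()
    return args

if __name__ == '__main__':
    args_opt = parse_args()
    print(args_opt.data_url, args_opt.train_url, args_opt.epoch_size)
```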
### Adapting to OBS Data

MindSpore does not yet provide an interface for accessing OBS data directly, so the script needs to interact with OBS through the APIs provided by MoXing. ModelArts training scripts run inside a container, and the `/cache` directory is usually chosen as the container-local data path.

> HUAWEI CLOUD MoXing provides a rich set of APIs (<https://github.com/huaweicloud/ModelArts-Lab/tree/master/docs/moxing_api_doc>); this example only needs the `copy_parallel` interface.
1. Download the data stored in OBS to the execution container.
```python
import moxing as mox
mox.file.copy_parallel(src_url='s3://dataset_url/', dst_url='/cache/data_path')
```
2. Upload the training output from the container to OBS (a combined sketch follows this list).
```python
import moxing as mox
mox.file.copy_parallel(src_url='/cache/output_path', dst_url='s3://output_url/')
```
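Putting the two directions together, a hypothetical skeleton of a training entry point downloads the data before training and uploads the results afterwards; `/cache/data`, `/cache/output`, and the `train()` call are placeholders for your own paths and training logic:

```python
import moxing as mox

def main(args_opt):
    local_data_path = '/cache/data'      # container-local input path (assumed)
    local_output_path = '/cache/output'  # container-local output path (assumed)

    # 1. Download the dataset from OBS into the container.
    mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=local_data_path)

    # 2. Run training, writing checkpoints/logs under local_output_path.
    train(local_data_path, local_output_path)  # placeholder for your training logic

    # 3. Upload the results back to the OBS training output path.
    mox.file.copy_parallel(src_url=local_output_path, dst_url=args_opt.train_url)
```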
### Obtaining Environment Variables

Dataset creation and the distributed strategy configuration in MindSpore depend on the runtime environment. By reading the two environment variables `DEVICE_ID` and `RANK_SIZE`, the same training script can be made to work with both the `1*Ascend` and the `8*Ascend` flavors. (A fallback for running the script outside ModelArts is sketched at the end of this subsection.)
1. Create the dataset.
```python
import os
import mindspore.dataset.engine as de

device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))

if device_num == 1:
    ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True)
else:
    # split train data for the 8 Ascend situation
    ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True,
                           num_shards=device_num, shard_id=device_id)
```
2. Configure the distributed strategy.
```python
import os
from mindspore import context
from mindspore.train.model import ParallelMode

device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))

context.set_context(mode=context.GRAPH_MODE)
if device_num > 1:
    context.set_auto_parallel_context(device_num=device_num,
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      mirror_mean=True)
```
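When the same script is run outside a ModelArts job (for example, for local debugging), `DEVICE_ID` and `RANK_SIZE` may not be set and `os.getenv` returns `None`. A small fallback, assuming single-device defaults, keeps the snippets above usable in both environments:

```python
import os

# Fall back to single-device defaults when the variables are not provided,
# e.g. when debugging the script outside a ModelArts job.
device_id = int(os.getenv('DEVICE_ID', '0'))
device_num = int(os.getenv('RANK_SIZE', '1'))
```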
### Example Code

Combining the three points above, a MindSpore script can be adapted with only minor changes. Take the following pseudocode as an example.

The original MindSpore script:
``` python
import os
import argparse

from mindspore import context
from mindspore.train.model import ParallelMode
import mindspore.dataset.engine as de

device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))

def create_dataset(dataset_path):
    if device_num == 1:
        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True)
    else:
        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True,
                               num_shards=device_num, shard_id=device_id)
    return ds

def resnet50_train(args_opt):
    epoch_size = args_opt.epoch_size
    local_data_path = args_opt.local_data_path

    context.set_context(mode=context.GRAPH_MODE)
    if device_num > 1:
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          mirror_mean=True)

    train_dataset = create_dataset(local_data_path)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ResNet-50 train.')
    parser.add_argument('--local_data_path', required=True, default=None, help='Location of data.')
    parser.add_argument('--epoch_size', type=int, default=90, help='Train epoch size.')
    args_opt, unknown = parser.parse_known_args()

    resnet50_train(args_opt)
```
The adapted MindSpore script:
``` python
import os
import argparse

from mindspore import context
from mindspore.train.model import ParallelMode
import mindspore.dataset.engine as de
import moxing as mox

device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))

def create_dataset(dataset_path):
    if device_num == 1:
        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True)
    else:
        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True,
                               num_shards=device_num, shard_id=device_id)
    return ds

def resnet50_train(args_opt):
    epoch_size = args_opt.epoch_size
    # define local data path
    local_data_path = '/cache/data'

    context.set_context(mode=context.GRAPH_MODE)
    if device_num > 1:
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          mirror_mean=True)
        # define distributed local data path
        local_data_path = os.path.join(local_data_path, str(device_id))

    # data download
    print('Download data.')
    mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=local_data_path)

    train_dataset = create_dataset(local_data_path)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ResNet-50 train.')
    parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
    parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
    parser.add_argument('--epoch_size', type=int, default=90, help='Train epoch size.')
    args_opt, unknown = parser.parse_known_args()

    resnet50_train(args_opt)
```
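Note that the adapted pseudocode declares `--train_url` but does not yet use it. One hypothetical way to produce and return training outputs is to save checkpoints to a local directory during training and copy that directory back to the OBS output path at the end of `resnet50_train`; the path `/cache/output` and the checkpoint settings below are illustrative, not part of the sample scripts:

```python
import moxing as mox
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

local_output_path = '/cache/output'  # assumed container-local output directory

# Save checkpoints locally during training (append ckpt_cb to the callback list
# passed to model.train).
ckpt_config = CheckpointConfig(save_checkpoint_steps=1000, keep_checkpoint_max=5)
ckpt_cb = ModelCheckpoint(prefix='resnet50', directory=local_output_path, config=ckpt_config)

# After training finishes, copy the local outputs back to the OBS output path.
mox.file.copy_parallel(src_url=local_output_path, dst_url=args_opt.train_url)
```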
## Creating a Training Job

With the data and the execution scripts prepared, you need to create a training job to actually run the MindSpore script. Users who are new to ModelArts can learn the job creation workflow from this section.

### Entering the ModelArts Console

Open the HUAWEI CLOUD ModelArts home page <https://www.huaweicloud.com/product/modelarts.html> and click "Enter Console" on the page.

### Creating a Training Job with a Frequently-used Framework

The ModelArts tutorial <https://support.huaweicloud.com/engineers-modelarts/modelarts_23_0238.html> shows how to create a training job with a frequently-used framework.

### Creating a Training Job Using MindSpore as the Framework

Using the training scripts and data of this tutorial as an example, configure the job creation page as follows:
1. For `Algorithm Source`, select `Frequently-used > Ascend-Powered-Engine > MindSpore-0.1-arrch64-cp37`.
2. For `Code Directory`, select the code directory created in your OBS bucket in advance; for `Boot File`, select the startup script in that directory.
3. For `Data Source`, select `Data Storage Location` and enter the OBS location of the CIFAR-10 dataset.
4. Under `Running Parameters`, `Data Storage Location` and `Training Output Location` correspond to the script arguments `data_url` and `train_url`, respectively. Click `Add Running Parameter` to pass values to the other arguments of the script, such as `epoch_size`.
5. For `Resource Pool`, select `Public Resource Pool > Ascend`.
6. For `Resource Pool > Specifications`, select `Ascend: 1 * Ascend 910 CPU: 24 cores 96GiB` or `Ascend: 8 * Ascend 910 CPU: 192 cores 768GiB`, which correspond to the single-device and the single-node 8-device flavors, respectively.
The following figures show a training job created with MindSpore as the frequently-used framework:

![Training job parameters](./images/cloud_train_job1.png)

![Training job specifications](./images/cloud_train_job2.png)
## Viewing Run Results

1. The run logs can be viewed on the training job page.

    The following figure shows the log of ResNet-50 training with the `8*Ascend` flavor. With 90 epochs in total, the training task took about 9 minutes, reached an accuracy of about 91.5%, and processed about 12,300 images per second.

    ![8*Ascend training result](./images/train_log_8_Ascend.png)

    The following figure shows the log of ResNet-50 training with the `1*Ascend` flavor. With 90 epochs in total, the training task took about 50 minutes, reached an accuracy of about 90.8%, and processed about 1,600 images per second.

    ![1*Ascend training result](./images/train_log_1_Ascend.png)

2. If a log path was specified when the training job was created, the log files can also be downloaded from OBS and viewed locally.
```diff
@@ -33,6 +33,7 @@ MindSpore教程
    advanced_use/computer_vision_application
    advanced_use/nlp_application
    advanced_use/customized_debugging_information
+   advanced_use/use_on_the_cloud
    advanced_use/on_device_inference
    advanced_use/model_security
    advanced_use/community
@@ -43,4 +44,4 @@ MindSpore教程
    :caption: 声明

    statement/legal_statement
-   statement/privacy_policy
\ No newline at end of file
+   statement/privacy_policy
```
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create train or eval dataset."""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
device_id = int(os.getenv('DEVICE_ID'))
device_num = int(os.getenv('RANK_SIZE'))
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
    """
    Create a train or eval dataset.

    Args:
        dataset_path (str): The path of dataset.
        do_train (bool): Whether dataset is used for train or eval.
        repeat_num (int): The repeat times of dataset. Default: 1.
        batch_size (int): The batch size of dataset. Default: 32.

    Returns:
        Dataset.
    """
    if do_train:
        dataset_path = os.path.join(dataset_path, 'train')
        do_shuffle = True
    else:
        dataset_path = os.path.join(dataset_path, 'eval')
        do_shuffle = False

    if device_num == 1 or not do_train:
        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=do_shuffle)
    else:
        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=do_shuffle,
                               num_shards=device_num, shard_id=device_id)

    resize_height = 224
    resize_width = 224
    buffer_size = 100
    rescale = 1.0 / 255.0
    shift = 0.0

    # define map operations
    random_crop_op = C.RandomCrop((32, 32), (4, 4, 4, 4))
    random_horizontal_flip_op = C.RandomHorizontalFlip(device_id / (device_id + 1))
    resize_op = C.Resize((resize_height, resize_width))
    rescale_op = C.Rescale(rescale, shift)
    normalize_op = C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    change_swap_op = C.HWC2CHW()

    trans = []
    if do_train:
        trans += [random_crop_op, random_horizontal_flip_op]
    trans += [resize_op, rescale_op, normalize_op, change_swap_op]

    type_cast_op = C2.TypeCast(mstype.int32)

    ds = ds.map(input_columns="label", operations=type_cast_op)
    ds = ds.map(input_columns="image", operations=trans)

    # apply shuffle operations
    ds = ds.shuffle(buffer_size=buffer_size)

    # apply batch operations
    ds = ds.batch(batch_size, drop_remainder=True)

    # apply dataset repeat operation
    ds = ds.repeat(repeat_num)

    return ds
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ResNet50 model train with MindSpore"""
import os
import argparse
import random
import time
import numpy as np
import moxing as mox
from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model, ParallelMode
from mindspore.train.callback import Callback, LossMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from dataset import create_dataset, device_id, device_num
from mindspore.model_zoo.resnet import resnet50
random.seed(1)
np.random.seed(1)
class PerformanceCallback(Callback):
    """
    Training performance callback.

    Args:
        batch_size (int): Batch number for one step.
    """
    def __init__(self, batch_size):
        super(PerformanceCallback, self).__init__()
        self.batch_size = batch_size
        self.last_step = 0
        self.epoch_begin_time = 0

    def step_begin(self, run_context):
        self.epoch_begin_time = time.time()

    def step_end(self, run_context):
        params = run_context.original_args()
        cost_time = time.time() - self.epoch_begin_time
        train_steps = params.cur_step_num - self.last_step
        print(f'epoch {params.cur_epoch_num} cost time = {cost_time}, train step num: {train_steps}, '
              f'one step time: {1000*cost_time/train_steps} ms, '
              f'train samples per second of cluster: {device_num*train_steps*self.batch_size/cost_time:.0f}')
        self.last_step = run_context.original_args().cur_step_num
def get_lr(global_step,
           total_epochs,
           steps_per_epoch,
           lr_init=0.01,
           lr_max=0.1,
           warmup_epochs=5):
    """
    Generate learning rate array.

    Args:
        global_step (int): Initial step of training.
        total_epochs (int): Total epoch of training.
        steps_per_epoch (float): Steps of one epoch.
        lr_init (float): Initial learning rate. Default: 0.01.
        lr_max (float): Maximum learning rate. Default: 0.1.
        warmup_epochs (int): The number of warming up epochs. Default: 5.

    Returns:
        np.array, learning rate array.
    """
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    if warmup_steps != 0:
        inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
    else:
        inc_each_step = 0
    for i in range(int(total_steps)):
        if i < warmup_steps:
            lr = float(lr_init) + inc_each_step * float(i)
        else:
            base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
            lr = float(lr_max) * base * base
            if lr < 0.0:
                lr = 0.0
        lr_each_step.append(lr)

    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
    learning_rate = lr_each_step[current_step:]

    return learning_rate
def resnet50_train(args_opt):
    epoch_size = args_opt.epoch_size
    batch_size = 32
    class_num = 10
    loss_scale_num = 1024
    local_data_path = '/cache/data'

    # set graph mode and parallel mode
    context.set_context(mode=context.GRAPH_MODE)
    if device_num > 1:
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          mirror_mean=True)
        local_data_path = os.path.join(local_data_path, str(device_id))

    # data download
    print('Download data.')
    mox.file.copy_parallel(src_url=args_opt.data_url, dst_url=local_data_path)

    # create dataset
    print('Create train and evaluate dataset.')
    train_dataset = create_dataset(dataset_path=local_data_path, do_train=True,
                                   repeat_num=epoch_size, batch_size=batch_size)
    eval_dataset = create_dataset(dataset_path=local_data_path, do_train=False,
                                  repeat_num=1, batch_size=batch_size)
    train_step_size = train_dataset.get_dataset_size()
    print('Create dataset success.')

    # create model
    net = resnet50(class_num=class_num)
    loss = SoftmaxCrossEntropyWithLogits(sparse=True)
    lr = Tensor(get_lr(global_step=0, total_epochs=epoch_size, steps_per_epoch=train_step_size))
    opt = Momentum(net.trainable_params(), lr, momentum=0.9, weight_decay=1e-4, loss_scale=loss_scale_num)
    loss_scale = FixedLossScaleManager(loss_scale_num, False)

    model = Model(net, loss_fn=loss, optimizer=opt, loss_scale_manager=loss_scale, metrics={'acc'})

    # define performance callback to show ips and loss callback to show loss for every epoch
    performance_cb = PerformanceCallback(batch_size)
    loss_cb = LossMonitor()
    cb = [performance_cb, loss_cb]

    print(f'Start run training, total epoch: {epoch_size}.')
    model.train(epoch_size, train_dataset, callbacks=cb)
    if device_num == 1 or device_id == 0:
        print('Start run evaluation.')
        output = model.eval(eval_dataset)
        print(f'Evaluation result: {output}.')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ResNet50 train.')
    parser.add_argument('--data_url', required=True, default=None, help='Location of data.')
    parser.add_argument('--train_url', required=True, default=None, help='Location of training outputs.')
    parser.add_argument('--epoch_size', type=int, default=90, help='Train epoch size.')
    args_opt, unknown = parser.parse_known_args()

    resnet50_train(args_opt)
    print('ResNet50 training success!')