Commit f8d61850 authored by mindspore-ci-bot, committed by Gitee

!125 optimize distributed training tutorial

Merge pull request !125 from gziyan/optimize_distributed_training
@@ -41,13 +41,13 @@
> `CIFAR-10` dataset download link: <http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz>.
-Download the dataset and extract it to a local path. In this example the dataset is extracted to the `./dataset` path at the same level as the sample code.
+Download the dataset and extract it to a local path. The extracted folder is `cifar-10-batches-bin`.
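For reference, a minimal sketch of downloading and extracting the dataset (assuming `wget` and `tar` are available on the machine):

```bash
# Hedged sketch: fetch the CIFAR-10 binary release and unpack it locally.
# Extraction produces the cifar-10-batches-bin folder mentioned above.
wget http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
tar -zxvf cifar-10-binary.tar.gz
```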
### Configuring Distributed Environment Variables
When running distributed training in a bare-metal environment (that is, with Ascend 910 AI processors on a local machine, as opposed to a cloud environment), you need to configure the networking information file for the current multi-device setup. If you use the Huawei Cloud environment, you can skip this section, because the cloud service is already configured.
-Taking the Ascend 910 AI processor as an example, a sample json configuration file for one 8-device environment is shown below. In this sample the configuration file is named rank_table.json.
+Taking the Ascend 910 AI processor as an example, a sample json configuration file for one 8-device environment is shown below. In this sample the configuration file is named `rank_table_8pcs.json`. For a 2-device environment, refer to the `rank_table_2pcs.json` file in the sample code.
```json
{
@@ -118,7 +118,7 @@ if __name__ == "__main__":
## Loading the Dataset in Data Parallel Mode
-During distributed training, data is imported in data parallel mode. Taking the CIFAR-10 dataset as an example, we introduce below how to import it in data parallel mode. `data_path` refers to the path of the dataset; in the sample code it is the path of the `dataset/cifar-10-batches-bin` folder at the same level as the sample code.
+During distributed training, data is imported in data parallel mode. Taking the CIFAR-10 dataset as an example, we introduce below how to import it in data parallel mode. `data_path` refers to the path of the dataset, i.e., the path of the `cifar-10-batches-bin` folder.
```python
@@ -274,11 +274,27 @@ def test_train_cifar(num_classes=10, epoch_size=10):
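As a hedged sketch of what data parallel loading involves (assuming MindSpore's `mindspore.dataset` and `mindspore.communication.management` APIs; the tutorial's actual `create_dataset` appears further down in this diff):

```python
# Hedged sketch, not the tutorial's exact code: shard CIFAR-10 across devices.
import mindspore.dataset as ds
from mindspore.communication.management import init, get_rank, get_group_size

init()                        # initialize HCCL communication
rank_id = get_rank()          # index of the current device
rank_size = get_group_size()  # total number of devices

# Each device reads only its own 1/rank_size shard of the dataset,
# so the shards together cover the full dataset without overlap.
data_set = ds.Cifar10Dataset("./cifar-10-batches-bin",
                             num_shards=rank_size, shard_id=rank_id)
```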
```bash
#!/bin/bash
+DATA_PATH=$1
+export DATA_PATH=${DATA_PATH}
+RANK_SIZE=$2
EXEC_PATH=$(pwd)
-export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table.json
-export RANK_SIZE=8
-for((i=0;i<$RANK_SIZE;i++))
+test_dist_8pcs()
+{
+    export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table_8pcs.json
+    export RANK_SIZE=8
+}
+test_dist_2pcs()
+{
+    export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table_2pcs.json
+    export RANK_SIZE=2
+}
+test_dist_${RANK_SIZE}pcs
+for((i=0;i<${RANK_SIZE};i++))
do
    rm -rf device$i
    mkdir device$i
@@ -293,6 +309,8 @@ do
done
```
The script takes two arguments, `DATA_PATH` and `RANK_SIZE`, which specify the path of the dataset and the number of devices, respectively.
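For example, a hedged invocation for an 8-device run (the script name and dataset path here are assumptions for illustration):

```bash
# Hypothetical example: 8-device distributed training.
bash run_distribute_train.sh ./cifar-10-batches-bin 8
```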
The required environment variables are as follows:

- `MINDSPORE_HCCL_CONFIG_PATH`: path of the networking information file.
- `DEVICE_ID`: the physical sequence number of the current device on the machine.
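The per-device launch loop is elided in the diff above; the following is a hedged sketch of what such a loop body typically does (the `pytest` invocation is inferred from the log sample below, and `RANK_ID` is an assumption):

```bash
# Hedged sketch of the per-device launch loop; details are assumptions.
for((i=0;i<${RANK_SIZE};i++))
do
    rm -rf device$i
    mkdir device$i
    cp ./resnet50_distributed_training.py ./device$i
    cd ./device$i
    export DEVICE_ID=$i   # physical index of the device used by this process
    export RANK_ID=$i     # logical rank of this process (assumption)
    env > env.log         # record the environment variables for debugging
    pytest -s -v ./resnet50_distributed_training.py > train.log 2>&1 &
    cd ../
done
```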
@@ -304,7 +322,7 @@ done
The log files are saved under the `device` directories. `env.log` records information about the environment variables, and the Loss results are saved in `train.log`. An example is shown below:
```
-test_resnet50_expand_loss_8p.py::test_train_feed ===============ds_num 195
+resnet50_distributed_training.py::test_train_feed ===============ds_num 195
global_step: 194, loss: 1.997
global_step: 389, loss: 1.655
global_step: 584, loss: 1.723
......
```
{
"board_id": "0x0000",
"chip_info": "910",
"deploy_mode": "lab",
"group_count": "1",
"group_list": [
{
"device_num": "2",
"server_num": "1",
"group_name": "",
"instance_count": "2",
"instance_list": [
{
"devices": [
{
"device_id": "0",
"device_ip": "192.1.27.6"
}
],
"rank_id": "0",
"server_id": "10.155.111.140"
},
{
"devices": [
{
"device_id": "1",
"device_ip": "192.2.27.6"
}
],
"rank_id": "1",
"server_id": "10.155.111.140"
}
]
}
],
"para_plane_nic_location": "device",
"para_plane_nic_name": [
"eth0",
"eth1"
],
"para_plane_nic_num": "2",
"status": "completed"
}
@@ -43,7 +43,6 @@ init()
rank_id = get_rank()
rank_size = get_group_size()
-EXEC_PATH=os.getcwd()
def create_dataset(data_path, repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
    resize_height = 224
@@ -119,11 +118,14 @@ class SoftmaxCrossEntropyExpand(nn.Cell):
        return loss
-def test_train_cifar(num_classes=10, epoch_size=10):
+def test_train_cifar(epoch_size=10):
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
    loss_cb = LossMonitor()
-    dataset = create_dataset(os.path.join(EXEC_PATH, '../dataset/cifar-10-batches-bin/'), epoch_size)
-    net = resnet50(32, num_classes)
+    data_path = os.getenv('DATA_PATH')
+    dataset = create_dataset(data_path, epoch_size)
+    batch_size = 32
+    num_classes = 10
+    net = resnet50(batch_size, num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
......
#!/bin/bash
+DATA_PATH=$1
+export DATA_PATH=${DATA_PATH}
+RANK_SIZE=$2
EXEC_PATH=$(pwd)
-export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table.json
-export RANK_SIZE=8
-for((i=0;i<$RANK_SIZE;i++))
+test_dist_8p()
+{
+    export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table_8p.json
+    export RANK_SIZE=8
+}
+test_dist_2p()
+{
+    export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table_2p.json
+    export RANK_SIZE=2
+}
+test_dist_${RANK_SIZE}p
+for((i=0;i<${RANK_SIZE};i++))
do
    rm -rf device$i
    mkdir device$i
......