Commit ad9b0f74 authored by yao_yf

parallel interface change

Parent 9fab1755
......@@ -345,7 +345,8 @@
"outputs": [],
"source": [
"from mindspore.communication.management import init\n",
"from mindspore.train.model import Model, ParallelMode\n",
"from mindspore.train.model import Model\n",
"from mindspore.context import ParallelMode\n",
"from resnet import resnet50\n",
"from mindspore.parallel._auto_parallel_context import auto_parallel_context\n",
"\n",
......
......@@ -215,7 +215,7 @@ The `Momentum` optimizer is used as the parameter update tool. The definition is
`context.set_auto_parallel_context` is an API for users to set parallel training parameters and must be called before the initialization of networks. The related parameters are as follows:
- `parallel_mode`: distributed parallel mode. The default value is `ParallelMode.STAND_ALONE`. The options are `ParallelMode.DATA_PARALLEL` and `ParallelMode.AUTO_PARALLEL`.
- `mirror_mean`: During backward computation, the framework collects gradients of parameters in data parallel mode across multiple hosts, obtains the global gradient value, and transfers the global gradient value to the optimizer for update. The default value is `False`, which indicates that the `allreduce_sum` operation is applied. The value `True` indicates that the `allreduce_mean` operation is applied.
- `gradients_mean`: During backward computation, the framework collects gradients of parameters in data parallel mode across multiple hosts, obtains the global gradient value, and transfers the global gradient value to the optimizer for update. The default value is `False`, which indicates that the `allreduce_sum` operation is applied. The value `True` indicates that the `allreduce_mean` operation is applied.
- `enable_parallel_optimizer`: a feature under development. Determines whether to enable optimizer parallelism, which improves performance by distributing the parameters to be updated across the workers and then broadcasting the updated parameters among the workers so that they stay consistent. This feature can be used only in data parallel mode and only when the number of parameters is larger than the number of devices.
> You are advised to set `device_num` and `global_rank` to their default values. The framework calls the HCCL API to obtain the values.
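A minimal sketch of the renamed interface, assuming an Ascend environment in which HCCL has already been configured through a rank table file and `DEVICE_ID` is set; `ParallelMode` now comes from `mindspore.context`, and `gradients_mean` takes the place of the former `mirror_mean` argument:

```python
# Minimal sketch (illustrative only). Assumes DEVICE_ID is set and HCCL is
# configured via a rank table file for the participating Ascend devices.
import os
from mindspore import context
from mindspore.communication.management import init
from mindspore.context import ParallelMode  # formerly exported from mindspore.train.model

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                    device_id=int(os.getenv('DEVICE_ID')))
init()
# gradients_mean=True (formerly mirror_mean=True) applies allreduce_mean to the
# data-parallel gradients; the default False applies allreduce_sum.
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True)
```

`device_num` and `global_rank` are left at their defaults here, as the note above recommends.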
......@@ -228,7 +228,8 @@ In the following sample code, the automatic parallel mode is specified. To switc
from mindspore import context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import LossMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from resnet import resnet50
device_id = int(os.getenv('DEVICE_ID'))
......@@ -236,7 +237,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id) # set device_id
def test_train_cifar(epoch_size=10):
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
loss_cb = LossMonitor()
dataset = create_dataset(data_path)
batch_size = 32
......
......@@ -218,7 +218,7 @@ class SoftmaxCrossEntropyExpand(nn.Cell):
`context.set_auto_parallel_context` is the API for configuring parallel training parameters and must be called before the network is initialized. The main parameters include:
- `parallel_mode`: distributed parallel mode. The default is the standalone mode `ParallelMode.STAND_ALONE`. Data parallel `ParallelMode.DATA_PARALLEL` and auto parallel `ParallelMode.AUTO_PARALLEL` are available.
- `mirror_mean`: during backward computation, the framework collects the gradients of the data-parallel parameters scattered across multiple machines to obtain the global gradient, which is then passed to the optimizer for the update. The default value is `False`; setting it to `True` corresponds to the `allreduce_mean` operation, and `False` corresponds to the `allreduce_sum` operation.
- `gradients_mean`: during backward computation, the framework collects the gradients of the data-parallel parameters scattered across multiple machines to obtain the global gradient, which is then passed to the optimizer for the update. The default value is `False`; setting it to `True` corresponds to the `allreduce_mean` operation, and `False` corresponds to the `allreduce_sum` operation.
- `enable_parallel_optimizer`: a feature under development. Turns on optimizer parallelism, which improves performance by splitting the weights across the devices so that each device updates its own shard and then synchronizes the result. This parameter currently takes effect only in data parallel mode and when the number of parameters is larger than the number of devices; the `Lamb` and `Adam` optimizers are supported (see the sketch after this note).
> You are advised to keep `device_num` and `global_rank` at their default values; the framework obtains them internally through the HCCL API.
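A minimal sketch of turning on `enable_parallel_optimizer` in data parallel mode; the tiny `nn.Dense` network and the `Adam` hyperparameters are placeholders chosen for illustration, and the flag is a feature under development as noted above:

```python
# Illustrative sketch of the developing optimizer-parallel feature.
# In a real job the network would be something like resnet50, and init()
# would already have been called for the distributed environment.
import mindspore.nn as nn
from mindspore import context
from mindspore.context import ParallelMode

context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True,
                                  # shard the weight update across devices,
                                  # then synchronize the updated parameters
                                  enable_parallel_optimizer=True)

net = nn.Dense(16, 10)  # placeholder network
# Per the note above, the feature supports the Lamb and Adam optimizers.
optimizer = nn.Adam(net.trainable_params(), learning_rate=0.001)
```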
......@@ -231,7 +231,8 @@ class SoftmaxCrossEntropyExpand(nn.Cell):
from mindspore import context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import LossMonitor
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from resnet import resnet50
device_id = int(os.getenv('DEVICE_ID'))
......@@ -239,7 +240,7 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id) # set device_id
def test_train_cifar(epoch_size=10):
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
loss_cb = LossMonitor()
dataset = create_dataset(data_path)
batch_size = 32
......
......@@ -164,13 +164,13 @@ MindSpore does not currently provide an API for directly accessing OBS data; it is necessary to go through MoXing
```python
import os
from mindspore import context
from mindspore.train.model import ParallelMode
from mindspore.context import ParallelMode
device_num = int(os.getenv('RANK_SIZE'))
if device_num > 1:
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
gradients_mean=True)
```
### Sample Code
......@@ -183,7 +183,7 @@ MindSpore does not currently provide an API for directly accessing OBS data; it is necessary to go through MoXing
import os
import argparse
from mindspore import context
from mindspore.train.model import ParallelMode
from mindspore.context import ParallelMode
import mindspore.dataset.engine as de
device_id = int(os.getenv('DEVICE_ID'))
......@@ -201,7 +201,7 @@ def resnet50_train(args_opt):
if device_num > 1:
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
gradients_mean=True)
train_dataset = create_dataset(local_data_path)
if __name__ == '__main__':
......@@ -220,7 +220,7 @@ if __name__ == '__main__':
import os
import argparse
from mindspore import context
from mindspore.train.model import ParallelMode
from mindspore.context import ParallelMode
import mindspore.dataset.engine as de
# adapt to cloud: used for downloading data
......@@ -244,7 +244,7 @@ def resnet50_train(args_opt):
if device_num > 1:
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
gradients_mean=True)
# adapt to cloud: define distributed local data path
local_data_path = os.path.join(local_data_path, str(device_id))
......
......@@ -28,7 +28,8 @@ from mindspore.communication.management import init, get_rank, get_group_size
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.train.callback import LossMonitor
from resnet import resnet50
......@@ -117,7 +118,7 @@ class SoftmaxCrossEntropyExpand(nn.Cell):
def test_train_cifar(epoch_size=10):
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, gradients_mean=True)
loss_cb = LossMonitor()
data_path = os.getenv('DATA_PATH')
dataset = create_dataset(data_path)
......
......@@ -29,7 +29,8 @@ from mindspore.communication.management import init
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
......
......@@ -24,7 +24,8 @@ from mindspore import context
from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model, ParallelMode
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.train.callback import Callback, LossMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.communication.management import init
......@@ -121,7 +122,7 @@ def resnet50_train(args_opt):
if device_num > 1:
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
gradients_mean=True)
init()
local_data_path = os.path.join(local_data_path, str(device_id))
......