提交 2aa4235f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!96 Remove enable_hccl api

Merge pull request !96 from zjun/remove_enbale_hccl
......@@ -84,7 +84,7 @@ export DEVICE_ID=0
### Invoking the Collective Communication Library
You need to enable the distributed API `enable_hccl` in the `context.set_context()` API, set the `device_id` parameter, and invoke `init()` to complete the initialization operation.
You need to set the `device_id` parameter in the `context.set_context()` API and invoke `init()` to complete the initialization.
In the sample, the graph mode is used during runtime. On the Ascend AI processor, Huawei Collective Communication Library (HCCL) is used.
......@@ -94,7 +94,7 @@ from mindspore import context
from mindspore.communication.management import init
if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", enable_hccl=True, device_id=int(os.environ["DEVICE_ID"]))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=int(os.environ["DEVICE_ID"]))
init()
...
```
......
......@@ -97,16 +97,15 @@ from mindspore import context
from mindspore.communication.management import init
if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", enable_hccl=True, device_id=int(os.environ["DEVICE_ID"]))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=int(os.environ["DEVICE_ID"]))
init()
...
```
其中,
- `mode=context.GRAPH_MODE`:使用分布式训练需要指定运行模式为图模式(PyNative模式不支持并行)。
- `enable_hccl=True`:使能HCCL通信。
- `device_id`:卡的物理序号,即卡所在机器中的实际序号。
- `init()`:完成分布式训练初始化操作。
- `init()`:使能HCCL通信,并完成分布式训练初始化操作。
## 数据并行模式加载数据集
......@@ -240,7 +239,6 @@ from resnet import resnet50
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, device_id=device_id) # set task_sink and device_id
context.set_context(enable_hccl=True) # set enable_hccl
context.set_context(enable_loop_sink=True)
def test_train_cifar(num_classes=10, epoch_size=10):
......
......@@ -36,7 +36,6 @@ from resnet import resnet50
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, device_id=device_id) # set task_sink and device_id
context.set_context(enable_hccl=True) # set enable_hccl
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=False)
init()
......
......@@ -106,16 +106,10 @@ def create_dataset(repeat_num=1, training=True):
if __name__ == '__main__':
# in this way by judging the mark of args, users will decide which function to use
if args_opt.do_eval:
context.set_context(enable_hccl=False)
else:
if args_opt.run_distribute:
context.set_context(enable_hccl=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
init()
else:
context.set_context(enable_hccl=False)
if not args_opt.do_eval and args_opt.run_distribute:
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL)
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
init()
epoch_size = args_opt.epoch_size
net = resnet50(args_opt.batch_size, args_opt.num_classes)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册