diff --git a/tutorials/source_zh_cn/advanced_use/distributed_training.md b/tutorials/source_zh_cn/advanced_use/distributed_training.md index 1582d6ea252d73d8919ae3676e00e09f584fc295..45fd0488d9cdc763e4d8217a6d295f9cfe9ac3a0 100644 --- a/tutorials/source_zh_cn/advanced_use/distributed_training.md +++ b/tutorials/source_zh_cn/advanced_use/distributed_training.md @@ -249,7 +249,7 @@ from resnet import resnet50 device_id = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") -context.set_context(enable_task_sink=True, device_id=device_id) # set task_sink and device_id +context.set_context(device_id=device_id) # set device_id context.set_context(enable_loop_sink=True) def test_train_cifar(num_classes=10, epoch_size=10): @@ -263,7 +263,7 @@ def test_train_cifar(num_classes=10, epoch_size=10): model.train(epoch_size, dataset, callbacks=[loss_cb], dataset_sink_mode=True) ``` 其中, -- `dataset_sink_mode=True`,`enable_task_sink=True`,`enable_loop_sink=True`:表示采用数据集和任务的下沉模式,即训练的计算下沉到硬件平台中执行。 +- `dataset_sink_mode=True`,`enable_loop_sink=True`:表示采用数据集的下沉模式,即训练的计算下沉到硬件平台中执行。 - `LossMonitor`:能够通过回调函数返回Loss值,用于监控损失函数。 ## 运行脚本 diff --git a/tutorials/tutorial_code/distributed_training/resnet50_distributed_training.py b/tutorials/tutorial_code/distributed_training/resnet50_distributed_training.py index 5d84a7afbb5f64941460b538080bfd5b8bc42da3..6937cb1e9b72271ec18de159d22a3130b124be29 100644 --- a/tutorials/tutorial_code/distributed_training/resnet50_distributed_training.py +++ b/tutorials/tutorial_code/distributed_training/resnet50_distributed_training.py @@ -35,7 +35,7 @@ from resnet import resnet50 device_id = int(os.getenv('DEVICE_ID')) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") -context.set_context(enable_task_sink=True, device_id=device_id) # set task_sink and device_id +context.set_context(device_id=device_id) # set device_id context.set_context(enable_loop_sink=True) context.set_context(enable_mem_reuse=False) init() diff --git a/tutorials/tutorial_code/resnet/cifar_resnet50.py b/tutorials/tutorial_code/resnet/cifar_resnet50.py index df890186b2f843e27e54609bb67d850192953cd3..40e66165ce640239113f2285fdaa49a08855de50 100644 --- a/tutorials/tutorial_code/resnet/cifar_resnet50.py +++ b/tutorials/tutorial_code/resnet/cifar_resnet50.py @@ -54,7 +54,7 @@ device_id = int(os.getenv('DEVICE_ID')) data_home = args_opt.dataset_path context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") -context.set_context(enable_task_sink=True, device_id=device_id) +context.set_context(device_id=device_id) context.set_context(enable_loop_sink=False) context.set_context(enable_mem_reuse=False) diff --git a/tutorials/tutorial_code/sample_for_cloud/resnet50_train.py b/tutorials/tutorial_code/sample_for_cloud/resnet50_train.py index 0b422a250b4b33588ec396b048e445e838d7e90c..af5ee4fa09eae8255b319936e181f9d9d875436d 100644 --- a/tutorials/tutorial_code/sample_for_cloud/resnet50_train.py +++ b/tutorials/tutorial_code/sample_for_cloud/resnet50_train.py @@ -116,7 +116,7 @@ def resnet50_train(args_opt): # set graph mode and parallel mode context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) - context.set_context(enable_task_sink=True, device_id=device_id) + context.set_context(device_id=device_id) context.set_context(enable_loop_sink=True) context.set_context(enable_mem_reuse=True) if device_num > 1: