提交 5331a61e 编写于 作者: G gengdongjie

remove enable_task-sink param

上级 69fed9c3
...@@ -249,7 +249,7 @@ from resnet import resnet50 ...@@ -249,7 +249,7 @@ from resnet import resnet50
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, device_id=device_id) # set task_sink and device_id context.set_context(device_id=device_id) # set device_id
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
def test_train_cifar(num_classes=10, epoch_size=10): def test_train_cifar(num_classes=10, epoch_size=10):
...@@ -263,7 +263,7 @@ def test_train_cifar(num_classes=10, epoch_size=10): ...@@ -263,7 +263,7 @@ def test_train_cifar(num_classes=10, epoch_size=10):
model.train(epoch_size, dataset, callbacks=[loss_cb], dataset_sink_mode=True) model.train(epoch_size, dataset, callbacks=[loss_cb], dataset_sink_mode=True)
``` ```
其中, 其中,
- `dataset_sink_mode=True``enable_task_sink=True`,`enable_loop_sink=True`:表示采用数据集和任务的下沉模式,即训练的计算下沉到硬件平台中执行。 - `dataset_sink_mode=True``enable_loop_sink=True`:表示采用数据集的下沉模式,即训练的计算下沉到硬件平台中执行。
- `LossMonitor`:能够通过回调函数返回Loss值,用于监控损失函数。 - `LossMonitor`:能够通过回调函数返回Loss值,用于监控损失函数。
## 运行脚本 ## 运行脚本
......
...@@ -35,7 +35,7 @@ from resnet import resnet50 ...@@ -35,7 +35,7 @@ from resnet import resnet50
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, device_id=device_id) # set task_sink and device_id context.set_context(device_id=device_id) # set device_id
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=False) context.set_context(enable_mem_reuse=False)
init() init()
......
...@@ -54,7 +54,7 @@ device_id = int(os.getenv('DEVICE_ID')) ...@@ -54,7 +54,7 @@ device_id = int(os.getenv('DEVICE_ID'))
data_home = args_opt.dataset_path data_home = args_opt.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=False) context.set_context(enable_loop_sink=False)
context.set_context(enable_mem_reuse=False) context.set_context(enable_mem_reuse=False)
......
...@@ -116,7 +116,7 @@ def resnet50_train(args_opt): ...@@ -116,7 +116,7 @@ def resnet50_train(args_opt):
# set graph mode and parallel mode # set graph mode and parallel mode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True) context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True) context.set_context(enable_mem_reuse=True)
if device_num > 1: if device_num > 1:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册