提交 2cebcf25 编写于 作者: J jinyaohui

remove some context params

上级 5b316eb7
......@@ -99,8 +99,7 @@ if __name__ == "__main__":
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: CPU)')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
enable_mem_reuse=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
...
```
......
......@@ -250,7 +250,6 @@ from resnet import resnet50
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id) # set device_id
context.set_context(enable_loop_sink=True)
def test_train_cifar(num_classes=10, epoch_size=10):
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
......@@ -263,7 +262,7 @@ def test_train_cifar(num_classes=10, epoch_size=10):
model.train(epoch_size, dataset, callbacks=[loss_cb], dataset_sink_mode=True)
```
其中,
- `dataset_sink_mode=True``enable_loop_sink=True`:表示采用数据集的下沉模式,即训练的计算下沉到硬件平台中执行。
- `dataset_sink_mode=True`:表示采用数据集的下沉模式,即训练的计算下沉到硬件平台中执行。
- `LossMonitor`:能够通过回调函数返回Loss值,用于监控损失函数。
## 运行脚本
......
......@@ -101,8 +101,7 @@ if __name__ == "__main__":
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: CPU)')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
enable_mem_reuse=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
...
```
......
......@@ -36,8 +36,6 @@ from resnet import resnet50
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id) # set device_id
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=False)
init()
rank_id = get_rank()
......
......@@ -195,8 +195,7 @@ if __name__ == "__main__":
parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: CPU)')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
enable_mem_reuse=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
# download mnist dataset
download_dataset()
# learning rate setting
......
......@@ -55,8 +55,6 @@ data_home = args_opt.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=False)
context.set_context(enable_mem_reuse=False)
def create_dataset(repeat_num=1, training=True):
"""
......
......@@ -117,8 +117,6 @@ def resnet50_train(args_opt):
# set graph mode and parallel mode
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if device_num > 1:
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册