Commit 2cebcf25 authored by jinyaohui

remove some context params
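This commit drops the `enable_mem_reuse` and `enable_loop_sink` arguments from the tutorial snippets, leaving only the execution mode and device settings. Below is a minimal sketch of the simplified setup, assuming the `mindspore.context` module used throughout these docs; reading `DEVICE_ID` with a fallback of `0` is illustrative, not part of the original scripts.

```python
# Minimal sketch of the context setup after this change (assumes MindSpore's
# `context` module as used in these tutorials).
import os

from mindspore import context

# Only the execution mode and target device are configured here;
# enable_mem_reuse / enable_loop_sink are no longer passed.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# The tutorial scripts read the device id from the DEVICE_ID environment
# variable; '0' is an illustrative fallback.
context.set_context(device_id=int(os.getenv('DEVICE_ID', '0')))
```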

Parent 5b316eb7
@@ -99,8 +99,7 @@ if __name__ == "__main__":
     parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
                         help='device where the code will be implemented (default: CPU)')
     args = parser.parse_args()
-    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
-                        enable_mem_reuse=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
     ...
 ```
......
@@ -250,7 +250,6 @@ from resnet import resnet50
 device_id = int(os.getenv('DEVICE_ID'))
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=device_id) # set device_id
-context.set_context(enable_loop_sink=True)
 def test_train_cifar(num_classes=10, epoch_size=10):
     context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
@@ -263,7 +262,7 @@ def test_train_cifar(num_classes=10, epoch_size=10):
     model.train(epoch_size, dataset, callbacks=[loss_cb], dataset_sink_mode=True)
 ```
 Where:
-- `dataset_sink_mode=True` and `enable_loop_sink=True`: use the dataset sink mode, i.e., the training computation is sunk to the hardware platform for execution.
+- `dataset_sink_mode=True`: use the dataset sink mode, i.e., the training computation is sunk to the hardware platform for execution.
 - `LossMonitor`: returns the loss value via a callback function and is used to monitor the loss.
 ## Running the Script
......
@@ -101,8 +101,7 @@ if __name__ == "__main__":
     parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
                         help='device where the code will be implemented (default: CPU)')
     args = parser.parse_args()
-    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
-                        enable_mem_reuse=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
     ...
 ```
......
@@ -36,8 +36,6 @@ from resnet import resnet50
 device_id = int(os.getenv('DEVICE_ID'))
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=device_id) # set device_id
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=False)
 init()
 rank_id = get_rank()
......
@@ -195,8 +195,7 @@ if __name__ == "__main__":
     parser.add_argument('--device_target', type=str, default="CPU", choices=['Ascend', 'GPU', 'CPU'],
                         help='device where the code will be implemented (default: CPU)')
     args = parser.parse_args()
-    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
-                        enable_mem_reuse=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
     # download mnist dataset
     download_dataset()
     # learning rate setting
......
@@ -55,8 +55,6 @@ data_home = args_opt.dataset_path
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=device_id)
-context.set_context(enable_loop_sink=False)
-context.set_context(enable_mem_reuse=False)
 def create_dataset(repeat_num=1, training=True):
     """
......
@@ -117,8 +117,6 @@ def resnet50_train(args_opt):
     # set graph mode and parallel mode
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
    context.set_context(device_id=device_id)
-   context.set_context(enable_loop_sink=True)
-   context.set_context(enable_mem_reuse=True)
    if device_num > 1:
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
......