diff --git a/chapter03/lenet/main.py b/chapter03/lenet/main.py index e39a37d3979111520c672131a82a463c1e2ccae5..87813265d1fe9f4c2eb698c6f4034379bdf29a4f 100644 --- a/chapter03/lenet/main.py +++ b/chapter03/lenet/main.py @@ -88,7 +88,7 @@ if __name__ == "__main__": args = parser.parse_args() - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False) + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) network = LeNet5(cfg.num_classes) net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") diff --git a/chapter04/alexnet/main.py b/chapter04/alexnet/main.py index 8ac57657adc52d1488506a7e6b7c2c4b4978cbdc..769fa2cf2fa9c623249a619b2d566d71c8120cd5 100644 --- a/chapter04/alexnet/main.py +++ b/chapter04/alexnet/main.py @@ -78,7 +78,7 @@ if __name__ == "__main__": parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True') args = parser.parse_args() - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False) + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) network = AlexNet(cfg.num_classes) loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") diff --git a/chapter05/resnet/resnet_cifar.py b/chapter05/resnet/resnet_cifar.py index 0e616c71d3a84d6b8264a17b9b156e5fcaac3a12..4bf56be66fea681ed25a8ab20add4046ae960219 100644 --- a/chapter05/resnet/resnet_cifar.py +++ b/chapter05/resnet/resnet_cifar.py @@ -66,8 +66,6 @@ if args_opt.device_target == "Ascend": #Choose one availabe Device to use on users' env. device_id = int(os.getenv('DEVICE_ID')) context.set_context(device_id=device_id) - context.set_context(enable_loop_sink=True) - context.set_context(enable_mem_reuse=False) def create_dataset(repeat_num=1, training=True): """create the dataset of cifar10""" diff --git a/chapter07/Bert_NEZHA_cnwiki/train.py b/chapter07/Bert_NEZHA_cnwiki/train.py index 76e481ba204758a667d89060c4a84430462aa339..ee278a488c774044622e9a8f7c75e6ee0f720409 100644 --- a/chapter07/Bert_NEZHA_cnwiki/train.py +++ b/chapter07/Bert_NEZHA_cnwiki/train.py @@ -74,8 +74,6 @@ def train_bert(): """train bert""" context.set_context(mode=context.GRAPH_MODE) context.set_context(device_target="Ascend") - context.set_context(enable_loop_sink=True) - context.set_context(enable_mem_reuse=True) ds = create_train_dataset(bert_net_cfg.batch_size) netwithloss = BertNetworkWithLoss(bert_net_cfg, True) optimizer = Lamb(netwithloss.trainable_params(), decay_steps=bert_train_cfg.decay_steps,