diff --git a/example/alexnet_cifar10/eval.py b/example/alexnet_cifar10/eval.py
index be71e339950c68e34d622475fadc7f5d263e3b4c..2efc6d15f69b693297872fbbc84ce0f568499ff8 100644
--- a/example/alexnet_cifar10/eval.py
+++ b/example/alexnet_cifar10/eval.py
@@ -39,7 +39,7 @@ if __name__ == "__main__":
     parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
     args = parser.parse_args()
 
-    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
 
     network = AlexNet(cfg.num_classes)
     loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
diff --git a/example/alexnet_cifar10/train.py b/example/alexnet_cifar10/train.py
index b97843902dd6afaf8cf24a128449efe429cf1c65..622df2d40420d209022a88afa20585b7bcba49e3 100644
--- a/example/alexnet_cifar10/train.py
+++ b/example/alexnet_cifar10/train.py
@@ -39,7 +39,7 @@ if __name__ == "__main__":
     parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
     args = parser.parse_args()
 
-    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
 
     network = AlexNet(cfg.num_classes)
     loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
diff --git a/example/bert_clue/README.md b/example/bert_clue/README.md
index f61cb5ddc550f02b2adb3155aa88ad40fcf0443c..01e0913411babf58343cd3a13312792f8dd961c5 100644
--- a/example/bert_clue/README.md
+++ b/example/bert_clue/README.md
@@ -46,8 +46,7 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
 ### Pre-Training
 ``` 
-usage: run_pretrain.py  [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N] 
-                        [--enable_task_sink ENABLE_TASK_SINK] [--enable_loop_sink ENABLE_LOOP_SINK]
-                        [--enable_mem_reuse ENABLE_MEM_REUSE] [--enable_save_ckpt ENABLE_SAVE_CKPT]
+usage: run_pretrain.py  [--distribute DISTRIBUTE] [--epoch_size N] [--device_num N] [--device_id N]
+                        [--enable_save_ckpt ENABLE_SAVE_CKPT]
                         [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
                         [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH]
                         [--save_checkpoint_steps N] [--save_checkpoint_num N] 
@@ -58,8 +57,6 @@ options:
     --epoch_size               epoch size: N, default is 1
     --device_num               number of used devices: N, default is 1
     --device_id                device id: N, default is 0
-    --enable_loop_sink         enable loop sink: "true" | "false", default is "true"
-    --enable_mem_reuse         enable memory reuse: "true" | "false", default is "true"
     --enable_save_ckpt         enable save checkpoint: "true" | "false", default is "true"
     --enable_lossscale         enable lossscale: "true" | "false", default is "true"
     --do_shuffle               enable shuffle: "true" | "false", default is "true"
diff --git a/example/bert_clue/finetune.py b/example/bert_clue/finetune.py
index d3cd22a3bd6816dc57f0e27d9eab8b2edeef224a..ee62d940b57dfb48224f7046cb8e3f61d0297ec2 100644
--- a/example/bert_clue/finetune.py
+++ b/example/bert_clue/finetune.py
@@ -83,8 +83,7 @@ def test_train():
     pytest -s finetune.py::test_train
     '''
     devid = int(os.getenv('DEVICE_ID'))
-    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid,
-                        enable_mem_reuse=True, enable_task_sink=True)
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
     #BertCLSTrain for classification
     #BertNERTrain for sequence labeling
     if cfg.task == 'NER':
diff --git a/example/bert_clue/run_distribute_pretrain.sh b/example/bert_clue/run_distribute_pretrain.sh
index aeef7b04d67578a1f0f3e8a73db3362bc901e44d..6c726027d7910f5de96d53586f4650e620ef22b7 100644
--- a/example/bert_clue/run_distribute_pretrain.sh
+++ b/example/bert_clue/run_distribute_pretrain.sh
@@ -50,8 +50,6 @@ do
     --epoch_size=$EPOCH_SIZE \
     --device_id=$DEVICE_ID \
     --device_num=$RANK_SIZE \
-    --enable_loop_sink="true" \
-    --enable_mem_reuse="true" \
     --enable_save_ckpt="true" \
     --enable_lossscale="true" \
     --do_shuffle="true" \
diff --git a/example/bert_clue/run_pretrain.py b/example/bert_clue/run_pretrain.py
index 4fa09347f9bfedba00db80eceedad0708831720a..2209176d6b568297001ad0c66c0c5034378c42df 100644
--- a/example/bert_clue/run_pretrain.py
+++ b/example/bert_clue/run_pretrain.py
@@ -59,8 +59,6 @@ def run_pretrain():
     parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
     parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
     parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
-    parser.add_argument("--enable_loop_sink", type=str, default="true", help="Enable loop sink, default is true.")
-    parser.add_argument("--enable_mem_reuse", type=str, default="true", help="Enable mem reuse, default is true.")
     parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
     parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is not.")
     parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
@@ -75,8 +73,6 @@ def run_pretrain():
 
     args_opt = parser.parse_args()
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
-    context.set_context(enable_loop_sink=(args_opt.enable_loop_sink == "true"),
-                        enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
     context.set_context(reserve_class_name_in_scope=False)
 
     if args_opt.distribute == "true":
diff --git a/example/bert_clue/run_standalone_pretrain.sh b/example/bert_clue/run_standalone_pretrain.sh
index 94d769fc1100911b46c963d2a7b32da1f7868d5d..7795a4e46df63e0b04236e525e9678788c2b961c 100644
--- a/example/bert_clue/run_standalone_pretrain.sh
+++ b/example/bert_clue/run_standalone_pretrain.sh
@@ -29,8 +29,6 @@ python run_pretrain.py  \
     --distribute="false" \
     --epoch_size=$EPOCH_SIZE \
     --device_id=$DEVICE_ID \
-    --enable_loop_sink="true" \
-    --enable_mem_reuse="true" \
     --enable_save_ckpt="true" \
     --enable_lossscale="true" \
     --do_shuffle="true" \
diff --git a/example/googlenet_cifar10/eval.py b/example/googlenet_cifar10/eval.py
index 8674e97e1e9c7903d3b283a8530df3b1246af3c7..cc09539aa782311a3679a7bda0953b9ffa2de3b2 100644
--- a/example/googlenet_cifar10/eval.py
+++ b/example/googlenet_cifar10/eval.py
@@ -40,7 +40,6 @@ if __name__ == '__main__':
 
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
     context.set_context(device_id=args_opt.device_id)
-    context.set_context(enable_mem_reuse=True)
 
     net = GooGLeNet(num_classes=cfg.num_classes)
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
diff --git a/example/googlenet_cifar10/train.py b/example/googlenet_cifar10/train.py
index 6f98013251588cc155cff1edc4ebfffa39f6c703..bee0297bb3d584b3b5166fa64fbb0c938d1a0603 100644
--- a/example/googlenet_cifar10/train.py
+++ b/example/googlenet_cifar10/train.py
@@ -70,8 +70,6 @@ if __name__ == '__main__':
 
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
     context.set_context(device_id=args_opt.device_id)
-    context.set_context(enable_loop_sink=True)
-    context.set_context(enable_mem_reuse=True)
 
     device_num = int(os.environ.get("DEVICE_NUM", 1))
     if device_num > 1:
diff --git a/example/lenet_mnist/eval.py b/example/lenet_mnist/eval.py
index 3473a995328519245a0832590fd5d281c075cd8a..8317785a66138059995912fa2ddb3d5a05f2411c 100644
--- a/example/lenet_mnist/eval.py
+++ b/example/lenet_mnist/eval.py
@@ -43,7 +43,7 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
 
     network = LeNet5(cfg.num_classes)
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
diff --git a/example/lenet_mnist/train.py b/example/lenet_mnist/train.py
index 2fa8d3c27f4c944f179a235324dbf8d27f8b7373..d58d1a101b415ab99ccc83c9f3309322c7f29f6a 100644
--- a/example/lenet_mnist/train.py
+++ b/example/lenet_mnist/train.py
@@ -40,7 +40,7 @@ if __name__ == "__main__":
 
     args = parser.parse_args()
 
-    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
 
     network = LeNet5(cfg.num_classes)
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
diff --git a/example/mobilenetv2_imagenet2012/eval.py b/example/mobilenetv2_imagenet2012/eval.py
index 71986c1b22bbbb7c81522f5665b2b51892fc446f..79df8ea8f2e4058ff01b17bedc4dced5be1374ca 100644
--- a/example/mobilenetv2_imagenet2012/eval.py
+++ b/example/mobilenetv2_imagenet2012/eval.py
@@ -34,8 +34,6 @@ args_opt = parser.parse_args()
 device_id = int(os.getenv('DEVICE_ID'))
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=True)
 
 if __name__ == '__main__':
     loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
diff --git a/example/mobilenetv2_imagenet2012/train.py b/example/mobilenetv2_imagenet2012/train.py
index 1bdd2992cb5ca213306555baef84ce6ff99a8a8b..72dfe788578f94ab3975fe57fe1142f11d1c6b60 100644
--- a/example/mobilenetv2_imagenet2012/train.py
+++ b/example/mobilenetv2_imagenet2012/train.py
@@ -56,8 +56,6 @@ rank_size = int(os.getenv('RANK_SIZE'))
 run_distribute = rank_size > 1
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=True)
 
 class CrossEntropyWithLabelSmooth(_Loss):
     """
diff --git a/example/resnet101_imagenet2012/eval.py b/example/resnet101_imagenet2012/eval.py
index 5bc651a969bfaa641fa66d1bc878debf6772d497..88d942866be7c5a5d488107fc819f225446ca4cb 100755
--- a/example/resnet101_imagenet2012/eval.py
+++ b/example/resnet101_imagenet2012/eval.py
@@ -46,8 +46,6 @@ args_opt = parser.parse_args()
 device_id = int(os.getenv('DEVICE_ID'))
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=True)
 
 if __name__ == '__main__':
     if not args_opt.do_eval and args_opt.run_distribute:
diff --git a/example/resnet101_imagenet2012/train.py b/example/resnet101_imagenet2012/train.py
index 2a049db425c7f96e3a58642653c0f69c1948f784..cfe87d16a6d1209938888442c1bad134d6435dd8 100755
--- a/example/resnet101_imagenet2012/train.py
+++ b/example/resnet101_imagenet2012/train.py
@@ -49,8 +49,6 @@ args_opt = parser.parse_args()
 device_id = int(os.getenv('DEVICE_ID'))
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=True)
 
 if __name__ == '__main__':
     if not args_opt.do_eval and args_opt.run_distribute:
diff --git a/example/resnet50_cifar10/eval.py b/example/resnet50_cifar10/eval.py
index 872f27d7285ba04b39b28439a756391df8377510..e6f02360d81ff96c88b6974e347e5f50b6403a26 100755
--- a/example/resnet50_cifar10/eval.py
+++ b/example/resnet50_cifar10/eval.py
@@ -40,8 +40,6 @@ device_id = int(os.getenv('DEVICE_ID'))
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
 context.set_context(device_id=device_id)
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=True)
 
 if __name__ == '__main__':
     if not args_opt.do_eval and args_opt.run_distribute:
diff --git a/example/resnet50_cifar10/train.py b/example/resnet50_cifar10/train.py
index 448ea9b05906e10bbc24c92a1954ce4c43be0be3..86a373c2dc8dabc0709c6df041b51a8e9e2f50ce 100755
--- a/example/resnet50_cifar10/train.py
+++ b/example/resnet50_cifar10/train.py
@@ -43,8 +43,6 @@ device_id = int(os.getenv('DEVICE_ID'))
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
 context.set_context(device_id=device_id)
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=True)
 
 if __name__ == '__main__':
     if not args_opt.do_eval and args_opt.run_distribute:
diff --git a/example/resnet50_imagenet2012/eval.py b/example/resnet50_imagenet2012/eval.py
index 1db83a4715b5807b9c48d725ba85df9ab50f5035..a19807ee9cf0a680a196fdf41a452dc3f1b3f5a3 100755
--- a/example/resnet50_imagenet2012/eval.py
+++ b/example/resnet50_imagenet2012/eval.py
@@ -37,9 +37,7 @@ args_opt = parser.parse_args()
 device_id = int(os.getenv('DEVICE_ID'))
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
-context.set_context(enable_task_sink=True, device_id=device_id)
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=True)
+context.set_context(device_id=device_id)
 
 if __name__ == '__main__':
 
diff --git a/example/resnet50_imagenet2012/train.py b/example/resnet50_imagenet2012/train.py
index d87189f8a7f4994cdb5e8f54a3f44a69b35693c4..d050d96ec999e2cdf0f105a937cea3111fcee905 100755
--- a/example/resnet50_imagenet2012/train.py
+++ b/example/resnet50_imagenet2012/train.py
@@ -44,9 +44,7 @@ args_opt = parser.parse_args()
 device_id = int(os.getenv('DEVICE_ID'))
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
-context.set_context(enable_task_sink=True, device_id=device_id)
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=True)
+context.set_context(device_id=device_id)
 
 if __name__ == '__main__':
     if not args_opt.do_eval and args_opt.run_distribute:
diff --git a/example/ssd_coco2017/eval.py b/example/ssd_coco2017/eval.py
index 8612a3779821d66c7a6efdcb2ccb1b5bd10e88a4..c0af504de22d0d430e4670796420dc10af318ab7 100644
--- a/example/ssd_coco2017/eval.py
+++ b/example/ssd_coco2017/eval.py
@@ -71,7 +71,6 @@ if __name__ == '__main__':
     args_opt = parser.parse_args()
 
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
-    context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
 
     config = ConfigSSD()
     prefix = "ssd_eval.mindrecord"
diff --git a/example/ssd_coco2017/train.py b/example/ssd_coco2017/train.py
index 9a6a4ece7075db9e960d0a32675859706a727445..a89d558c65084e4c5e609289474c0fbaba8b9adb 100644
--- a/example/ssd_coco2017/train.py
+++ b/example/ssd_coco2017/train.py
@@ -93,7 +93,6 @@ def main():
     args_opt = parser.parse_args()
 
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
-    context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
 
     if args_opt.distribute:
         device_num = args_opt.device_num
diff --git a/example/vgg16_cifar10/eval.py b/example/vgg16_cifar10/eval.py
index 68c23d250ff4063ba19ec875a3ff21310b209524..ec9fc607c2d1590976af3ab97db81c12e940ac2b 100644
--- a/example/vgg16_cifar10/eval.py
+++ b/example/vgg16_cifar10/eval.py
@@ -37,7 +37,6 @@ if __name__ == '__main__':
 
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
     context.set_context(device_id=args_opt.device_id)
-    context.set_context(enable_mem_reuse=True)
 
     net = vgg16(num_classes=cfg.num_classes)
     opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
diff --git a/example/vgg16_cifar10/train.py b/example/vgg16_cifar10/train.py
index fcf5ea701037b602485ef114e1450c844f2fd922..9993db706a42493679e82f233a98bc05c7b080a3 100644
--- a/example/vgg16_cifar10/train.py
+++ b/example/vgg16_cifar10/train.py
@@ -64,8 +64,6 @@ if __name__ == '__main__':
 
     context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
     context.set_context(device_id=args_opt.device_id)
-    context.set_context(enable_loop_sink=True)
-    context.set_context(enable_mem_reuse=True)
 
     device_num = int(os.environ.get("DEVICE_NUM", 1))
     if device_num > 1:
diff --git a/example/yolov3_coco2017/eval.py b/example/yolov3_coco2017/eval.py
index 3bc3027260f0dc23fa066b93cc3bc751a7755af2..6e6d35824820362773cf9f69417f38e996372630 100644
--- a/example/yolov3_coco2017/eval.py
+++ b/example/yolov3_coco2017/eval.py
@@ -82,7 +82,6 @@ if __name__ == '__main__':
     args_opt = parser.parse_args()
 
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
-    context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
 
     # It will generate mindrecord file in args_opt.mindrecord_dir,
     # and the file name is yolo.mindrecord0, 1, ... file_num.
diff --git a/example/yolov3_coco2017/train.py b/example/yolov3_coco2017/train.py
index 1aa72c4de18f21f87c8076d92431483fef62d912..cfa3580b86e8a0456388c7b801651b0cbe1ec45f 100644
--- a/example/yolov3_coco2017/train.py
+++ b/example/yolov3_coco2017/train.py
@@ -84,7 +84,6 @@ def main():
     args_opt = parser.parse_args()
 
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
-    context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
     if args_opt.distribute:
         device_num = args_opt.device_num
         context.reset_auto_parallel_context()
diff --git a/mindspore/ccsrc/debug/e2e_dump.cc b/mindspore/ccsrc/debug/e2e_dump.cc
index 34d401191a2909a7f2062bd84a240da724ffd1fb..78a331fc278d4988f67021efacff3513aa640ae7 100644
--- a/mindspore/ccsrc/debug/e2e_dump.cc
+++ b/mindspore/ccsrc/debug/e2e_dump.cc
@@ -107,6 +107,10 @@ bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
   }
 
   dump_enable_ = enable;
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  // dump_enable_ is true, close mem reuse
+  context_ptr->set_enable_mem_reuse(!dump_enable_);
   trans_flag_ = trans_flag;
   dump_mode_ = mode;
   dump_path_ = path;
diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc
index 90bae0a3893dbd51242909669d30d627ddc6ff8c..7c663291c043dafe8bd9a4d11fc63a5b5404304d 100644
--- a/mindspore/ccsrc/pipeline/init.cc
+++ b/mindspore/ccsrc/pipeline/init.cc
@@ -117,20 +117,12 @@ PYBIND11_MODULE(_c_expression, m) {
     .def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.")
     .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
     .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
-    .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
-         "Get whether to enable auto mixed precision.")
-    .def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag,
-         "Set whether to enable auto mixed precision.")
     .def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision,
          "Get whether to enable reduce precision.")
     .def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision,
          "Set whether to enable reduce precision.")
     .def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.")
     .def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.")
-    .def("get_loop_sink_flag", &mindspore::MsContext::loop_sink_flag, "Get whether to enable loop sink.")
-    .def("set_loop_sink_flag", &mindspore::MsContext::set_loop_sink_flag, "Set whether to enable loop sink.")
-    .def("get_enable_mem_reuse", &mindspore::MsContext::enable_mem_reuse, "Get whether to enable mem reuse.")
-    .def("set_enable_mem_reuse", &mindspore::MsContext::set_enable_mem_reuse, "Set whether to enable mem reuse.")
     .def("get_save_ms_model_flag", &mindspore::MsContext::save_ms_model_flag, "Get whether to save ms model.")
     .def("set_save_ms_model_flag", &mindspore::MsContext::set_save_ms_model_flag, "Set whether to save ms model.")
     .def("get_save_ms_model_path", &mindspore::MsContext::save_ms_model_path, "Get path to save ms model.")
diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h
index ef13d2a1bd2ee9633b3c55a9632592f926845fde..f4e52950cd6cc7401cfb971e35c8c7f8506eccf4 100644
--- a/mindspore/ccsrc/utils/context/ms_context.h
+++ b/mindspore/ccsrc/utils/context/ms_context.h
@@ -91,7 +91,6 @@ class MsContext {
 
   bool ir_fusion_flag() const { return ir_fusion_flag_; }
 
-  void set_loop_sink_flag(bool loop_sink_flag) { enable_loop_sink_ = loop_sink_flag; }
   bool loop_sink_flag() const { return enable_loop_sink_; }
 
   void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; }
@@ -106,11 +105,6 @@ class MsContext {
   void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; }
   bool enable_gpu_summary() const { return enable_gpu_summary_; }
 
-  void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) {
-    auto_mixed_precision_flag_ = auto_mixed_precision_flag;
-  }
-  bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; }
-
   void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; }
   bool enable_reduce_precision() const { return enable_reduce_precision_; }
 
diff --git a/mindspore/context.py b/mindspore/context.py
index 99307a7ac26c7d8173c42fdae2167e86c8f6f5dd..0c56a069416ec437292384b20c75fb8072a7a14e 100644
--- a/mindspore/context.py
+++ b/mindspore/context.py
@@ -31,6 +31,8 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut
 
 GRAPH_MODE = 0
 PYNATIVE_MODE = 1
+# The max memory size of graph plus variable.
+_DEVICE_APP_MEMORY_SIZE = 31
 
 
 def _make_directory(path):
@@ -215,22 +217,6 @@ class _Context:
         if not success:
             raise RuntimeError("Device id set failed!!!")
 
-    @property
-    def enable_loop_sink(self):
-        return self._context_handle.get_loop_sink_flag()
-
-    @enable_loop_sink.setter
-    def enable_loop_sink(self, enable_loop_sink):
-        self._context_handle.set_loop_sink_flag(enable_loop_sink)
-
-    @property
-    def enable_mem_reuse(self):
-        return self._context_handle.get_enable_mem_reuse()
-
-    @enable_mem_reuse.setter
-    def enable_mem_reuse(self, enable_mem_reuse):
-        self._context_handle.set_enable_mem_reuse(enable_mem_reuse)
-
     @property
     def save_ms_model(self):
         return self._context_handle.get_save_ms_model_flag()
@@ -247,14 +233,6 @@ class _Context:
     def save_ms_model_path(self, save_ms_model_path):
         self._context_handle.set_save_ms_model_path(save_ms_model_path)
 
-    @property
-    def enable_auto_mixed_precision(self):
-        return self._context_handle.get_auto_mixed_precision_flag()
-
-    @enable_auto_mixed_precision.setter
-    def enable_auto_mixed_precision(self, enable_auto_mixed_precision):
-        self._context_handle.set_auto_mixed_precision_flag(enable_auto_mixed_precision)
-
     @property
     def enable_reduce_precision(self):
         return self._context_handle.get_enable_reduce_precision_flag()
@@ -309,29 +287,21 @@ class _Context:
         """Sets whether to save the network class name in the scope."""
         self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
 
-    @property
-    def graph_memory_max_size(self):
-        return None
-
-    @graph_memory_max_size.setter
-    def graph_memory_max_size(self, graph_memory_max_size):
-        if check_input_format(graph_memory_max_size):
-            graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
-            self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
-        else:
-            raise ValueError("Context param graph_memory_max_size should be in correct format! Such as \"26GB\"")
-
     @property
     def variable_memory_max_size(self):
         return None
 
     @variable_memory_max_size.setter
     def variable_memory_max_size(self, variable_memory_max_size):
-        if check_input_format(variable_memory_max_size):
-            variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
-            self._context_handle.set_variable_memory_max_size(variable_memory_max_size_)
-        else:
+        if not check_input_format(variable_memory_max_size):
             raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
+        if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
+            raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
+        variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
+        graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
+        graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
+        self._context_handle.set_variable_memory_max_size(variable_memory_max_size_)
+        self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
 
     @property
     def enable_ge(self):
@@ -469,10 +439,9 @@ def reset_auto_parallel_context():
 
 
 @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
-                 save_graphs_path=str, enable_loop_sink=bool, enable_mem_reuse=bool, save_ms_model=bool,
-                 save_ms_model_path=str, enable_auto_mixed_precision=bool, enable_dump=bool, save_dump_path=str,
-                 enable_reduce_precision=bool, graph_memory_max_size=str,
-                 variable_memory_max_size=str, enable_profiling=bool, profiling_options=str)
+                 save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
+                 save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
+                 enable_profiling=bool, profiling_options=str)
 def set_context(**kwargs):
     """
     Sets context for running environment.
@@ -490,8 +459,5 @@ def set_context(**kwargs):
     Note:
         Attribute name is required for setting attributes.
-        If need to config graph max memory size and variable max memory size, one must make sure:
-            The sum of graph_memory_max_size and variable_memory_max_size should be less than total memory size of
-            a device, while the total memory is supposed to be no more than 256GB.
 
     Args:
         mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Default: PYNATIVE_MODE.
@@ -499,19 +466,15 @@ def set_context(**kwargs):
         device_id (int): Id of target device, the value must be in [0, device_num_per_host-1],
                     while device_num_per_host should no more than 4096. Default: 0.
         save_graphs (bool): Whether to save graphs. Default: False.
-        enable_loop_sink (bool): Whether to enable loop sink. Default: True.
-        enable_mem_reuse (bool): Whether to enable memory reuse. Default: True.
         save_ms_model (bool): Whether to save lite model converted by graph. Default: False.
         save_ms_model_path (str): Path to save converted lite model. Default: "."
         save_graphs_path (str): Path to save graphs. Default: "."
-        enable_auto_mixed_precision (bool): Whether to enable auto mixed precision. Default: True.
         reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
         enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
         enable_dump (bool): Whether to enable dump. Default: False.
         save_dump_path (str): When the program is executed on Ascend, operators can dump data here.
             The root dump path is configured in /home/HwHiAiUser/ide_daemon/ide_daemon.cfg.
             So the real dump path is "{configured root dump path}/{`save_dump_path`}". Default: ".".
-        graph_memory_max_size (str): Sets graph memory max size. Default: "26GB".
         variable_memory_max_size (str): Sets variable memory max size. Default: "5GB".
         enable_profiling (bool): Whether to open profiling. Default: False.
         profiling_options (str): Sets profiling collection options, operators can profiling data here.
@@ -538,12 +501,10 @@ def set_context(**kwargs):
         >>> context.set_context(device_target="Ascend")
         >>> context.set_context(device_id=0)
         >>> context.set_context(save_graphs=True, save_graphs_path="./model.ms")
-        >>> context.set_context(enable_mem_reuse=True)
         >>> context.set_context(enable_reduce_precision=True)
         >>> context.set_context(save_ms_model=True, save_ms_model_path=".")
         >>> context.set_context(enable_dump=True, save_dump_path=".")
         >>> context.set_context(reserve_class_name_in_scope=True)
-        >>> context.set_context(graph_memory_max_size="25GB")
         >>> context.set_context(variable_memory_max_size="6GB")
         >>> context.set_context(mode=context.GRAPH_MODE,
         >>>                     device_target="Ascend",device_id=0, save_graphs=True,
diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py
index 0a327fd0e8fe6bd328d66c1395627aef8b1e38f7..52797b631c091e93610ff40a469286b026037851 100644
--- a/mindspore/train/dataset_helper.py
+++ b/mindspore/train/dataset_helper.py
@@ -44,15 +44,18 @@ class DatasetHelper:
     def __init__(self, dataset, dataset_sink_mode=True):
         check_bool(dataset_sink_mode)
 
-        iterclass = _DatasetIterGE
-        if not dataset_sink_mode:
-            iterclass = _DatasetIterFeed
-        elif not context.get_context("enable_ge"):
-            if context.get_context("enable_loop_sink"):
-                iterclass = _DatasetIterMSLoopSink
+        if dataset_sink_mode:
+            if context.get_context("enable_ge"):
+                iterclass = _DatasetIterGE
             else:
-                iterclass = _DatasetIterMS
-
+                if context.get_context("device_target") == "Ascend":
+                    iterclass = _DatasetIterMSLoopSink
+                elif context.get_context("device_target") == "GPU":
+                    iterclass = _DatasetIterMS
+                elif context.get_context("device_target") == "CPU":
+                    raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
+        else:
+            iterclass = _DatasetIterFeed
         self.iter = iterclass(dataset)
 
     def __iter__(self):
@@ -104,12 +107,12 @@ class _DatasetIter:
             if dataset.get_dataset_size() % loop_size != 0:
                 raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
                                  f'loop_size {loop_size} are not matched.')
-            loop_count = int(dataset.get_dataset_size()/loop_size)
+            loop_count = int(dataset.get_dataset_size() / loop_size)
         return loop_count
 
 
 class _DatasetIterMSLoopSink(_DatasetIter):
-    """Iter for context (enable_loop_sink=True)"""
+    """Iter for context (device_target=Ascend)"""
     def __init__(self, dataset):
         super(_DatasetIterMSLoopSink, self).__init__(dataset)
         self.loop_count = self.get_loop_count(dataset)
@@ -122,11 +125,12 @@ class _DatasetIterMSLoopSink(_DatasetIter):
 
         def op():
             return tuple()
+
         self.op = op
 
 
 class _DatasetIterMS(_DatasetIter):
-    """Iter for context (enable_loop_sink=False)"""
+    """Iter for context (device_target=GPU)"""
     def __init__(self, dataset):
         super(_DatasetIterMS, self).__init__(dataset)
         self.loop_count = dataset.get_dataset_size()
@@ -149,11 +153,12 @@ class _DatasetIterGE(_DatasetIter):
 
         def op():
             return tensor_list_run
+
         self.op = op
 
 
 class _DatasetIterFeed:
-    """Iter for feed data"""
+    """Iter for normal(non sink) mode, feed the data from host."""
     def __init__(self, dataset):
         self.dataset = dataset
         self.device_num = _get_device_num()
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 1017b1daa18e945a58e6a16f9f911621ae6796c3..b4faecbe46e0f1f82a32fd8af5a4ddab9accfde7 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -279,7 +279,7 @@ class Model:
         """
         # remove later to deal with loop sink
         need_wrap = False
-        if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
+        if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
                 and not context.get_context("enable_ge"):
             need_wrap = True
 
@@ -420,9 +420,6 @@ class Model:
         _device_number_check(self._parallel_mode, self._device_number)
         _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
 
-        if context.get_context("device_target") in ["CPU", "GPU"] and context.get_context("enable_loop_sink"):
-            raise ValueError("CPU and GPU can't support loop sink, please set enable_loop_sink=False.")
-
         self._train(epoch,
                     train_dataset,
                     callbacks=callbacks,
@@ -446,7 +443,7 @@ class Model:
 
         # remove later to deal with loop sink
         need_wrap = False
-        if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
+        if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
                 and not context.get_context("enable_ge"):
             need_wrap = True
 
diff --git a/tests/st/auto_parallel/onehot_model_parallel.py b/tests/st/auto_parallel/onehot_model_parallel.py
index e0ec25bd29106a4740a9bd163a9fc7bb15eb919d..2931d6f0f89956fe9855826a2d82896188bc9a81 100644
--- a/tests/st/auto_parallel/onehot_model_parallel.py
+++ b/tests/st/auto_parallel/onehot_model_parallel.py
@@ -34,7 +34,6 @@ def setup_module():
     np.random.seed(0)
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     context.set_context(device_id=device_id)
-    context.set_context(enable_loop_sink=False)
     distributedTool.init()
     device_num = distributedTool.get_group_size()
     rank_id = distributedTool.get_rank()
diff --git a/tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py b/tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py
index 73a681d912d62f65bb69fd87b086eb044b4058b6..ee37fc921d6181824c1abf76720f46ed2f01924a 100644
--- a/tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py
+++ b/tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py
@@ -47,7 +47,6 @@ def setup_module():
     np.random.seed(0)
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     context.set_context(device_id=device_id)
-    context.set_context(enable_loop_sink=False)
     distributedTool.init()
     rank_id = distributedTool.get_rank()
     device_num = distributedTool.get_group_size()
diff --git a/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py b/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py
index c168ea5b6d9c2a82e930bf95fba1b185cd029a66..c5c1f8c1fb96577c93892714c0a25abe37e4bf9e 100644
--- a/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py
+++ b/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py
@@ -32,7 +32,6 @@ from mindspore.parallel import set_algo_parameters
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=int(os.getenv('DEVICE_ID')))
-context.set_context(enable_loop_sink=False)
 init()
 context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL)
 
diff --git a/tests/st/mem_reuse/resnet_cifar_memreuse.py b/tests/st/mem_reuse/resnet_cifar_memreuse.py
index bfa03524bde4cfeac903baeaa21b46a24fc26c98..c25fb957e2cd1ed95ea7d4d1daaeae0c0ec57222 100644
--- a/tests/st/mem_reuse/resnet_cifar_memreuse.py
+++ b/tests/st/mem_reuse/resnet_cifar_memreuse.py
@@ -54,8 +54,6 @@ data_home = args_opt.dataset_path
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=device_id)
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=True)
 
 
 def create_dataset(repeat_num=1, training=True):
diff --git a/tests/st/mem_reuse/resnet_cifar_normal.py b/tests/st/mem_reuse/resnet_cifar_normal.py
index 1bdef4c59bf1420199da03d18c072854688711e2..96af53f22531baa16cc9479e9ab450b073006122 100644
--- a/tests/st/mem_reuse/resnet_cifar_normal.py
+++ b/tests/st/mem_reuse/resnet_cifar_normal.py
@@ -54,8 +54,6 @@ data_home = args_opt.dataset_path
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=device_id)
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=False)
 
 
 def create_dataset(repeat_num=1, training=True):
diff --git a/tests/st/networks/models/bert/bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_tdt_lossscale.py
index f0d79846110eb0d8bd2dd37163706e0b695cd855..1bd72d0221746514561bbdacc4e1abf7614841b3 100644
--- a/tests/st/networks/models/bert/bert_tdt_lossscale.py
+++ b/tests/st/networks/models/bert/bert_tdt_lossscale.py
@@ -127,8 +127,6 @@ class ModelCallback(Callback):
 def test_bert_tdt():
     """test bert tdt"""
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
-    context.set_context(enable_loop_sink=True)
-    context.set_context(enable_mem_reuse=True)
     ds = me_de_train_dataset()
     version = os.getenv('VERSION', 'large')
     batch_size = int(os.getenv('BATCH_SIZE', '16'))
diff --git a/tests/st/networks/test_gpu_lenet.py b/tests/st/networks/test_gpu_lenet.py
index aaba8e6f93e7343ad1896c382df3d08d9460ec9d..723a45cbff385c029c092b75b20a283867c06308 100644
--- a/tests/st/networks/test_gpu_lenet.py
+++ b/tests/st/networks/test_gpu_lenet.py
@@ -141,7 +141,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_train_and_eval_lenet():
-    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_mem_reuse=False)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
     network = LeNet5(10)
     net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
     net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
diff --git a/tests/st/tbe_networks/export_geir.py b/tests/st/tbe_networks/export_geir.py
index 531c6a95371e9b6dca82713bcc544db6b0a3db02..e305476e390ff6d76367246343bf65df4b4209a4 100644
--- a/tests/st/tbe_networks/export_geir.py
+++ b/tests/st/tbe_networks/export_geir.py
@@ -20,7 +20,7 @@ from mindspore import Tensor
 from mindspore.train.serialization import save, load, _check_filedir_or_create, _chg_model_file_name_if_same_exist, \
     _read_file_last_line, context, export
 
-context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", enable_loop_sink=True)
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 
 
 def test_resnet50_export(batch_size=1, num_classes=5):
diff --git a/tests/st/tbe_networks/resnet_cifar.py b/tests/st/tbe_networks/resnet_cifar.py
index 38fcf42e9e6dee72238d9ee0380b426b85ba6b90..be9e01bd2da8eab4cde1647d57d24dea08beba9c 100644
--- a/tests/st/tbe_networks/resnet_cifar.py
+++ b/tests/st/tbe_networks/resnet_cifar.py
@@ -55,8 +55,6 @@ data_home = args_opt.dataset_path
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=device_id)
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=True)
 
 
 def create_dataset(repeat_num=1, training=True):
diff --git a/tests/st/tbe_networks/test_resnet_cifar_1p.py b/tests/st/tbe_networks/test_resnet_cifar_1p.py
index b8c86932f6ce11bdb0f2d4cdc509da2ef373daaf..af9f866209ffccebbc76905b0923add7d0167c09 100644
--- a/tests/st/tbe_networks/test_resnet_cifar_1p.py
+++ b/tests/st/tbe_networks/test_resnet_cifar_1p.py
@@ -138,8 +138,6 @@ def train_process(device_id, epoch_size, num_classes, device_num, batch_size):
     os.chdir(str(device_id))
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     context.set_context(device_id=device_id)
-    context.set_context(enable_loop_sink=True)
-    context.set_context(enable_mem_reuse=True)
     context.set_context(mode=context.GRAPH_MODE)
     net = resnet50(batch_size, num_classes)
     loss = CrossEntropyLoss()
@@ -160,8 +158,6 @@ def train_process(device_id, epoch_size, num_classes, device_num, batch_size):
 def eval(batch_size, num_classes):
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     context.set_context(device_id=0)
-    context.set_context(enable_loop_sink=True)
-    context.set_context(enable_mem_reuse=True)
 
     net = resnet50(batch_size, num_classes)
     loss = CrossEntropyLoss()
diff --git a/tests/st/tbe_networks/test_resnet_cifar_8p.py b/tests/st/tbe_networks/test_resnet_cifar_8p.py
index 2bfe7863190f0c848c32e2966f95ef0cd3ec0cea..0033c944061925a84714867dcb32b5deb9a3280a 100644
--- a/tests/st/tbe_networks/test_resnet_cifar_8p.py
+++ b/tests/st/tbe_networks/test_resnet_cifar_8p.py
@@ -148,8 +148,6 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size,
     context.set_context(mode=context.GRAPH_MODE,
                         device_target="Ascend", save_graphs=False)
     context.set_context(device_id=device_id)
-    context.set_context(enable_loop_sink=True)
-    context.set_context(enable_mem_reuse=True)
     os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
     os.environ['RANK_ID'] = str(device_id)
     os.environ['RANK_SIZE'] = str(device_num)
diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet.py b/tests/ut/python/parallel/test_auto_parallel_resnet.py
index 0a05dc6dee78194c0fc9a8230cc174a9ae09fea1..07ad2016a6de8647c4c611718db9c3e0443a80f5 100644
--- a/tests/ut/python/parallel/test_auto_parallel_resnet.py
+++ b/tests/ut/python/parallel/test_auto_parallel_resnet.py
@@ -32,7 +32,6 @@ from mindspore.parallel import _cost_model_context as cost_model_context
 
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
 context.set_context(device_id=0)
-context.set_context(enable_loop_sink=False)
 init()
 
 
diff --git a/tests/ut/python/pynative_mode/test_context.py b/tests/ut/python/pynative_mode/test_context.py
index cece0c14514604d376313cfc448ba1a21b2aa533..f02b017e84905c1c3364c6422695e1c50f0bad58 100644
--- a/tests/ut/python/pynative_mode/test_context.py
+++ b/tests/ut/python/pynative_mode/test_context.py
@@ -102,6 +102,21 @@ def test_profiling_options():
     assert context.get_context("profiling_options") == "training_trace:task_trace"
 
 
+def test_variable_memory_max_size():
+    """test_variable_memory_max_size"""
+    with pytest.raises(TypeError):
+        context.set_context(variable_memory_max_size=True)
+    with pytest.raises(TypeError):
+        context.set_context(variable_memory_max_size=1)
+    with pytest.raises(ValueError):
+        context.set_context(variable_memory_max_size="")
+    with pytest.raises(ValueError):
+        context.set_context(variable_memory_max_size="1G")
+    with pytest.raises(ValueError):
+        context.set_context(variable_memory_max_size="31GB")
+    context.set_context(variable_memory_max_size="3GB")
+
+
 def test_set_context():
     """ test_set_context """
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",