提交 2907cf44 编写于 作者: J jinyaohui

remove some context param

上级 dbac31e7
...@@ -39,7 +39,7 @@ if __name__ == "__main__": ...@@ -39,7 +39,7 @@ if __name__ == "__main__":
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True') parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
network = AlexNet(cfg.num_classes) network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
......
...@@ -39,7 +39,7 @@ if __name__ == "__main__": ...@@ -39,7 +39,7 @@ if __name__ == "__main__":
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True') parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
network = AlexNet(cfg.num_classes) network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
......
...@@ -46,8 +46,7 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base]( ...@@ -46,8 +46,7 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
### Pre-Training ### Pre-Training
``` ```
usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N] usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N]
[--enable_task_sink ENABLE_TASK_SINK] [--enable_loop_sink ENABLE_LOOP_SINK] [--enable_save_ckpt ENABLE_SAVE_CKPT]
[--enable_mem_reuse ENABLE_MEM_REUSE] [--enable_save_ckpt ENABLE_SAVE_CKPT]
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE] [--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH] [--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH]
[--save_checkpoint_steps N] [--save_checkpoint_num N] [--save_checkpoint_steps N] [--save_checkpoint_num N]
...@@ -58,8 +57,6 @@ options: ...@@ -58,8 +57,6 @@ options:
--epoch_size epoch size: N, default is 1 --epoch_size epoch size: N, default is 1
--device_num number of used devices: N, default is 1 --device_num number of used devices: N, default is 1
--device_id device id: N, default is 0 --device_id device id: N, default is 0
--enable_loop_sink enable loop sink: "true" | "false", default is "true"
--enable_mem_reuse enable memory reuse: "true" | "false", default is "true"
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true" --enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
--enable_lossscale enable lossscale: "true" | "false", default is "true" --enable_lossscale enable lossscale: "true" | "false", default is "true"
--do_shuffle enable shuffle: "true" | "false", default is "true" --do_shuffle enable shuffle: "true" | "false", default is "true"
......
...@@ -83,8 +83,7 @@ def test_train(): ...@@ -83,8 +83,7 @@ def test_train():
pytest -s finetune.py::test_train pytest -s finetune.py::test_train
''' '''
devid = int(os.getenv('DEVICE_ID')) devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid, context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
enable_mem_reuse=True, enable_task_sink=True)
#BertCLSTrain for classification #BertCLSTrain for classification
#BertNERTrain for sequence labeling #BertNERTrain for sequence labeling
if cfg.task == 'NER': if cfg.task == 'NER':
......
...@@ -50,8 +50,6 @@ do ...@@ -50,8 +50,6 @@ do
--epoch_size=$EPOCH_SIZE \ --epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \ --device_id=$DEVICE_ID \
--device_num=$RANK_SIZE \ --device_num=$RANK_SIZE \
--enable_loop_sink="true" \
--enable_mem_reuse="true" \
--enable_save_ckpt="true" \ --enable_save_ckpt="true" \
--enable_lossscale="true" \ --enable_lossscale="true" \
--do_shuffle="true" \ --do_shuffle="true" \
......
...@@ -59,8 +59,6 @@ def run_pretrain(): ...@@ -59,8 +59,6 @@ def run_pretrain():
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.") parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.") parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--enable_loop_sink", type=str, default="true", help="Enable loop sink, default is true.")
parser.add_argument("--enable_mem_reuse", type=str, default="true", help="Enable mem reuse, default is true.")
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.") parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is not.") parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is not.")
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.") parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
...@@ -75,8 +73,6 @@ def run_pretrain(): ...@@ -75,8 +73,6 @@ def run_pretrain():
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_loop_sink=(args_opt.enable_loop_sink == "true"),
enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
context.set_context(reserve_class_name_in_scope=False) context.set_context(reserve_class_name_in_scope=False)
if args_opt.distribute == "true": if args_opt.distribute == "true":
......
...@@ -29,8 +29,6 @@ python run_pretrain.py \ ...@@ -29,8 +29,6 @@ python run_pretrain.py \
--distribute="false" \ --distribute="false" \
--epoch_size=$EPOCH_SIZE \ --epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \ --device_id=$DEVICE_ID \
--enable_loop_sink="true" \
--enable_mem_reuse="true" \
--enable_save_ckpt="true" \ --enable_save_ckpt="true" \
--enable_lossscale="true" \ --enable_lossscale="true" \
--do_shuffle="true" \ --do_shuffle="true" \
......
...@@ -40,7 +40,6 @@ if __name__ == '__main__': ...@@ -40,7 +40,6 @@ if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id) context.set_context(device_id=args_opt.device_id)
context.set_context(enable_mem_reuse=True)
net = GooGLeNet(num_classes=cfg.num_classes) net = GooGLeNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
......
...@@ -70,8 +70,6 @@ if __name__ == '__main__': ...@@ -70,8 +70,6 @@ if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id) context.set_context(device_id=args_opt.device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
device_num = int(os.environ.get("DEVICE_NUM", 1)) device_num = int(os.environ.get("DEVICE_NUM", 1))
if device_num > 1: if device_num > 1:
......
...@@ -43,7 +43,7 @@ if __name__ == "__main__": ...@@ -43,7 +43,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
network = LeNet5(cfg.num_classes) network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
......
...@@ -40,7 +40,7 @@ if __name__ == "__main__": ...@@ -40,7 +40,7 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
network = LeNet5(cfg.num_classes) network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
......
...@@ -34,8 +34,6 @@ args_opt = parser.parse_args() ...@@ -34,8 +34,6 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__': if __name__ == '__main__':
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean') loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
......
...@@ -56,8 +56,6 @@ rank_size = int(os.getenv('RANK_SIZE')) ...@@ -56,8 +56,6 @@ rank_size = int(os.getenv('RANK_SIZE'))
run_distribute = rank_size > 1 run_distribute = rank_size > 1
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
class CrossEntropyWithLabelSmooth(_Loss): class CrossEntropyWithLabelSmooth(_Loss):
""" """
......
...@@ -46,8 +46,6 @@ args_opt = parser.parse_args() ...@@ -46,8 +46,6 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__': if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute: if not args_opt.do_eval and args_opt.run_distribute:
......
...@@ -49,8 +49,6 @@ args_opt = parser.parse_args() ...@@ -49,8 +49,6 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__': if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute: if not args_opt.do_eval and args_opt.run_distribute:
......
...@@ -40,8 +40,6 @@ device_id = int(os.getenv('DEVICE_ID')) ...@@ -40,8 +40,6 @@ device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__': if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute: if not args_opt.do_eval and args_opt.run_distribute:
......
...@@ -43,8 +43,6 @@ device_id = int(os.getenv('DEVICE_ID')) ...@@ -43,8 +43,6 @@ device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__': if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute: if not args_opt.do_eval and args_opt.run_distribute:
......
...@@ -37,9 +37,7 @@ args_opt = parser.parse_args() ...@@ -37,9 +37,7 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -44,9 +44,7 @@ args_opt = parser.parse_args() ...@@ -44,9 +44,7 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID')) device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__': if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute: if not args_opt.do_eval and args_opt.run_distribute:
......
...@@ -71,7 +71,6 @@ if __name__ == '__main__': ...@@ -71,7 +71,6 @@ if __name__ == '__main__':
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
config = ConfigSSD() config = ConfigSSD()
prefix = "ssd_eval.mindrecord" prefix = "ssd_eval.mindrecord"
......
...@@ -93,7 +93,6 @@ def main(): ...@@ -93,7 +93,6 @@ def main():
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
if args_opt.distribute: if args_opt.distribute:
device_num = args_opt.device_num device_num = args_opt.device_num
......
...@@ -37,7 +37,6 @@ if __name__ == '__main__': ...@@ -37,7 +37,6 @@ if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id) context.set_context(device_id=args_opt.device_id)
context.set_context(enable_mem_reuse=True)
net = vgg16(num_classes=cfg.num_classes) net = vgg16(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum, opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
......
...@@ -64,8 +64,6 @@ if __name__ == '__main__': ...@@ -64,8 +64,6 @@ if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id) context.set_context(device_id=args_opt.device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
device_num = int(os.environ.get("DEVICE_NUM", 1)) device_num = int(os.environ.get("DEVICE_NUM", 1))
if device_num > 1: if device_num > 1:
......
...@@ -82,7 +82,6 @@ if __name__ == '__main__': ...@@ -82,7 +82,6 @@ if __name__ == '__main__':
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
# It will generate mindrecord file in args_opt.mindrecord_dir, # It will generate mindrecord file in args_opt.mindrecord_dir,
# and the file name is yolo.mindrecord0, 1, ... file_num. # and the file name is yolo.mindrecord0, 1, ... file_num.
......
...@@ -84,7 +84,6 @@ def main(): ...@@ -84,7 +84,6 @@ def main():
args_opt = parser.parse_args() args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
if args_opt.distribute: if args_opt.distribute:
device_num = args_opt.device_num device_num = args_opt.device_num
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
......
...@@ -107,6 +107,10 @@ bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) { ...@@ -107,6 +107,10 @@ bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
} }
dump_enable_ = enable; dump_enable_ = enable;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
// dump_enable_ is true, close mem reuse
context_ptr->set_enable_mem_reuse(!dump_enable_);
trans_flag_ = trans_flag; trans_flag_ = trans_flag;
dump_mode_ = mode; dump_mode_ = mode;
dump_path_ = path; dump_path_ = path;
......
...@@ -117,20 +117,12 @@ PYBIND11_MODULE(_c_expression, m) { ...@@ -117,20 +117,12 @@ PYBIND11_MODULE(_c_expression, m) {
.def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.") .def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.")
.def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
.def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
.def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
"Get whether to enable auto mixed precision.")
.def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag,
"Set whether to enable auto mixed precision.")
.def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision, .def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision,
"Get whether to enable reduce precision.") "Get whether to enable reduce precision.")
.def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision, .def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision,
"Set whether to enable reduce precision.") "Set whether to enable reduce precision.")
.def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.") .def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.")
.def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.") .def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.")
.def("get_loop_sink_flag", &mindspore::MsContext::loop_sink_flag, "Get whether to enable loop sink.")
.def("set_loop_sink_flag", &mindspore::MsContext::set_loop_sink_flag, "Set whether to enable loop sink.")
.def("get_enable_mem_reuse", &mindspore::MsContext::enable_mem_reuse, "Get whether to enable mem reuse.")
.def("set_enable_mem_reuse", &mindspore::MsContext::set_enable_mem_reuse, "Set whether to enable mem reuse.")
.def("get_save_ms_model_flag", &mindspore::MsContext::save_ms_model_flag, "Get whether to save ms model.") .def("get_save_ms_model_flag", &mindspore::MsContext::save_ms_model_flag, "Get whether to save ms model.")
.def("set_save_ms_model_flag", &mindspore::MsContext::set_save_ms_model_flag, "Set whether to save ms model.") .def("set_save_ms_model_flag", &mindspore::MsContext::set_save_ms_model_flag, "Set whether to save ms model.")
.def("get_save_ms_model_path", &mindspore::MsContext::save_ms_model_path, "Get path to save ms model.") .def("get_save_ms_model_path", &mindspore::MsContext::save_ms_model_path, "Get path to save ms model.")
......
...@@ -91,7 +91,6 @@ class MsContext { ...@@ -91,7 +91,6 @@ class MsContext {
bool ir_fusion_flag() const { return ir_fusion_flag_; } bool ir_fusion_flag() const { return ir_fusion_flag_; }
void set_loop_sink_flag(bool loop_sink_flag) { enable_loop_sink_ = loop_sink_flag; }
bool loop_sink_flag() const { return enable_loop_sink_; } bool loop_sink_flag() const { return enable_loop_sink_; }
void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; } void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; }
...@@ -106,11 +105,6 @@ class MsContext { ...@@ -106,11 +105,6 @@ class MsContext {
void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; } void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; }
bool enable_gpu_summary() const { return enable_gpu_summary_; } bool enable_gpu_summary() const { return enable_gpu_summary_; }
void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) {
auto_mixed_precision_flag_ = auto_mixed_precision_flag;
}
bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; }
void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; } void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; }
bool enable_reduce_precision() const { return enable_reduce_precision_; } bool enable_reduce_precision() const { return enable_reduce_precision_; }
......
...@@ -31,6 +31,8 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut ...@@ -31,6 +31,8 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut
GRAPH_MODE = 0 GRAPH_MODE = 0
PYNATIVE_MODE = 1 PYNATIVE_MODE = 1
# The max memory size of graph plus variable.
_DEVICE_APP_MEMORY_SIZE = 31
def _make_directory(path): def _make_directory(path):
...@@ -215,22 +217,6 @@ class _Context: ...@@ -215,22 +217,6 @@ class _Context:
if not success: if not success:
raise RuntimeError("Device id set failed!!!") raise RuntimeError("Device id set failed!!!")
@property
def enable_loop_sink(self):
return self._context_handle.get_loop_sink_flag()
@enable_loop_sink.setter
def enable_loop_sink(self, enable_loop_sink):
self._context_handle.set_loop_sink_flag(enable_loop_sink)
@property
def enable_mem_reuse(self):
return self._context_handle.get_enable_mem_reuse()
@enable_mem_reuse.setter
def enable_mem_reuse(self, enable_mem_reuse):
self._context_handle.set_enable_mem_reuse(enable_mem_reuse)
@property @property
def save_ms_model(self): def save_ms_model(self):
return self._context_handle.get_save_ms_model_flag() return self._context_handle.get_save_ms_model_flag()
...@@ -247,14 +233,6 @@ class _Context: ...@@ -247,14 +233,6 @@ class _Context:
def save_ms_model_path(self, save_ms_model_path): def save_ms_model_path(self, save_ms_model_path):
self._context_handle.set_save_ms_model_path(save_ms_model_path) self._context_handle.set_save_ms_model_path(save_ms_model_path)
@property
def enable_auto_mixed_precision(self):
return self._context_handle.get_auto_mixed_precision_flag()
@enable_auto_mixed_precision.setter
def enable_auto_mixed_precision(self, enable_auto_mixed_precision):
self._context_handle.set_auto_mixed_precision_flag(enable_auto_mixed_precision)
@property @property
def enable_reduce_precision(self): def enable_reduce_precision(self):
return self._context_handle.get_enable_reduce_precision_flag() return self._context_handle.get_enable_reduce_precision_flag()
...@@ -309,29 +287,21 @@ class _Context: ...@@ -309,29 +287,21 @@ class _Context:
"""Sets whether to save the network class name in the scope.""" """Sets whether to save the network class name in the scope."""
self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
@property
def graph_memory_max_size(self):
return None
@graph_memory_max_size.setter
def graph_memory_max_size(self, graph_memory_max_size):
if check_input_format(graph_memory_max_size):
graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
else:
raise ValueError("Context param graph_memory_max_size should be in correct format! Such as \"26GB\"")
@property @property
def variable_memory_max_size(self): def variable_memory_max_size(self):
return None return None
@variable_memory_max_size.setter @variable_memory_max_size.setter
def variable_memory_max_size(self, variable_memory_max_size): def variable_memory_max_size(self, variable_memory_max_size):
if check_input_format(variable_memory_max_size): if not check_input_format(variable_memory_max_size):
variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
self._context_handle.set_variable_memory_max_size(variable_memory_max_size_)
else:
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
self._context_handle.set_variable_memory_max_size(variable_memory_max_size_)
self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
@property @property
def enable_ge(self): def enable_ge(self):
...@@ -469,10 +439,9 @@ def reset_auto_parallel_context(): ...@@ -469,10 +439,9 @@ def reset_auto_parallel_context():
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool, @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
save_graphs_path=str, enable_loop_sink=bool, enable_mem_reuse=bool, save_ms_model=bool, save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
save_ms_model_path=str, enable_auto_mixed_precision=bool, enable_dump=bool, save_dump_path=str, save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_reduce_precision=bool, graph_memory_max_size=str, enable_profiling=bool, profiling_options=str)
variable_memory_max_size=str, enable_profiling=bool, profiling_options=str)
def set_context(**kwargs): def set_context(**kwargs):
""" """
Sets context for running environment. Sets context for running environment.
...@@ -490,8 +459,6 @@ def set_context(**kwargs): ...@@ -490,8 +459,6 @@ def set_context(**kwargs):
Note: Note:
Attribute name is required for setting attributes. Attribute name is required for setting attributes.
If need to config graph max memory size and variable max memory size, one must make sure: If need to config graph max memory size and variable max memory size, one must make sure:
The sum of graph_memory_max_size and variable_memory_max_size should be less than total memory size of
a device, while the total memory is supposed to be no more than 256GB.
Args: Args:
mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Default: PYNATIVE_MODE. mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Default: PYNATIVE_MODE.
...@@ -499,19 +466,15 @@ def set_context(**kwargs): ...@@ -499,19 +466,15 @@ def set_context(**kwargs):
device_id (int): Id of target device, the value must be in [0, device_num_per_host-1], device_id (int): Id of target device, the value must be in [0, device_num_per_host-1],
while device_num_per_host should no more than 4096. Default: 0. while device_num_per_host should no more than 4096. Default: 0.
save_graphs (bool): Whether to save graphs. Default: False. save_graphs (bool): Whether to save graphs. Default: False.
enable_loop_sink (bool): Whether to enable loop sink. Default: True.
enable_mem_reuse (bool): Whether to enable memory reuse. Default: True.
save_ms_model (bool): Whether to save lite model converted by graph. Default: False. save_ms_model (bool): Whether to save lite model converted by graph. Default: False.
save_ms_model_path (str): Path to save converted lite model. Default: "." save_ms_model_path (str): Path to save converted lite model. Default: "."
save_graphs_path (str): Path to save graphs. Default: "." save_graphs_path (str): Path to save graphs. Default: "."
enable_auto_mixed_precision (bool): Whether to enable auto mixed precision. Default: True.
reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True. reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
enable_reduce_precision (bool): Whether to enable precision reduction. Default: True. enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
enable_dump (bool): Whether to enable dump. Default: False. enable_dump (bool): Whether to enable dump. Default: False.
save_dump_path (str): When the program is executed on Ascend, operators can dump data here. save_dump_path (str): When the program is executed on Ascend, operators can dump data here.
The root dump path is configured in /home/HwHiAiUser/ide_daemon/ide_daemon.cfg. The root dump path is configured in /home/HwHiAiUser/ide_daemon/ide_daemon.cfg.
So the real dump path is "{configured root dump path}/{`save_dump_path`}". Default: ".". So the real dump path is "{configured root dump path}/{`save_dump_path`}". Default: ".".
graph_memory_max_size (str): Sets graph memory max size. Default: "26GB".
variable_memory_max_size (str): Sets variable memory max size. Default: "5GB". variable_memory_max_size (str): Sets variable memory max size. Default: "5GB".
enable_profiling (bool): Whether to open profiling. Default: False. enable_profiling (bool): Whether to open profiling. Default: False.
profiling_options (str): Sets profiling collection options, operators can profiling data here. profiling_options (str): Sets profiling collection options, operators can profiling data here.
...@@ -538,12 +501,10 @@ def set_context(**kwargs): ...@@ -538,12 +501,10 @@ def set_context(**kwargs):
>>> context.set_context(device_target="Ascend") >>> context.set_context(device_target="Ascend")
>>> context.set_context(device_id=0) >>> context.set_context(device_id=0)
>>> context.set_context(save_graphs=True, save_graphs_path="./model.ms") >>> context.set_context(save_graphs=True, save_graphs_path="./model.ms")
>>> context.set_context(enable_mem_reuse=True)
>>> context.set_context(enable_reduce_precision=True) >>> context.set_context(enable_reduce_precision=True)
>>> context.set_context(save_ms_model=True, save_ms_model_path=".") >>> context.set_context(save_ms_model=True, save_ms_model_path=".")
>>> context.set_context(enable_dump=True, save_dump_path=".") >>> context.set_context(enable_dump=True, save_dump_path=".")
>>> context.set_context(reserve_class_name_in_scope=True) >>> context.set_context(reserve_class_name_in_scope=True)
>>> context.set_context(graph_memory_max_size="25GB")
>>> context.set_context(variable_memory_max_size="6GB") >>> context.set_context(variable_memory_max_size="6GB")
>>> context.set_context(mode=context.GRAPH_MODE, >>> context.set_context(mode=context.GRAPH_MODE,
>>> device_target="Ascend",device_id=0, save_graphs=True, >>> device_target="Ascend",device_id=0, save_graphs=True,
......
...@@ -44,15 +44,18 @@ class DatasetHelper: ...@@ -44,15 +44,18 @@ class DatasetHelper:
def __init__(self, dataset, dataset_sink_mode=True): def __init__(self, dataset, dataset_sink_mode=True):
check_bool(dataset_sink_mode) check_bool(dataset_sink_mode)
iterclass = _DatasetIterGE if dataset_sink_mode:
if not dataset_sink_mode: if context.get_context("enable_ge"):
iterclass = _DatasetIterFeed iterclass = _DatasetIterGE
elif not context.get_context("enable_ge"):
if context.get_context("enable_loop_sink"):
iterclass = _DatasetIterMSLoopSink
else: else:
iterclass = _DatasetIterMS if context.get_context("device_target") == "Ascend":
iterclass = _DatasetIterMSLoopSink
elif context.get_context("device_target") == "GPU":
iterclass = _DatasetIterMS
elif context.get_context("device_target") == "CPU":
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
else:
iterclass = _DatasetIterFeed
self.iter = iterclass(dataset) self.iter = iterclass(dataset)
def __iter__(self): def __iter__(self):
...@@ -104,12 +107,12 @@ class _DatasetIter: ...@@ -104,12 +107,12 @@ class _DatasetIter:
if dataset.get_dataset_size() % loop_size != 0: if dataset.get_dataset_size() % loop_size != 0:
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and ' raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
f'loop_size {loop_size} are not matched.') f'loop_size {loop_size} are not matched.')
loop_count = int(dataset.get_dataset_size()/loop_size) loop_count = int(dataset.get_dataset_size() / loop_size)
return loop_count return loop_count
class _DatasetIterMSLoopSink(_DatasetIter): class _DatasetIterMSLoopSink(_DatasetIter):
"""Iter for context (enable_loop_sink=True)""" """Iter for context (device_target=Ascend)"""
def __init__(self, dataset): def __init__(self, dataset):
super(_DatasetIterMSLoopSink, self).__init__(dataset) super(_DatasetIterMSLoopSink, self).__init__(dataset)
self.loop_count = self.get_loop_count(dataset) self.loop_count = self.get_loop_count(dataset)
...@@ -122,11 +125,12 @@ class _DatasetIterMSLoopSink(_DatasetIter): ...@@ -122,11 +125,12 @@ class _DatasetIterMSLoopSink(_DatasetIter):
def op(): def op():
return tuple() return tuple()
self.op = op self.op = op
class _DatasetIterMS(_DatasetIter): class _DatasetIterMS(_DatasetIter):
"""Iter for context (enable_loop_sink=False)""" """Iter for context (device_target=GPU)"""
def __init__(self, dataset): def __init__(self, dataset):
super(_DatasetIterMS, self).__init__(dataset) super(_DatasetIterMS, self).__init__(dataset)
self.loop_count = dataset.get_dataset_size() self.loop_count = dataset.get_dataset_size()
...@@ -149,11 +153,12 @@ class _DatasetIterGE(_DatasetIter): ...@@ -149,11 +153,12 @@ class _DatasetIterGE(_DatasetIter):
def op(): def op():
return tensor_list_run return tensor_list_run
self.op = op self.op = op
class _DatasetIterFeed: class _DatasetIterFeed:
"""Iter for feed data""" """Iter for normal(non sink) mode, feed the data from host."""
def __init__(self, dataset): def __init__(self, dataset):
self.dataset = dataset self.dataset = dataset
self.device_num = _get_device_num() self.device_num = _get_device_num()
......
...@@ -279,7 +279,7 @@ class Model: ...@@ -279,7 +279,7 @@ class Model:
""" """
# remove later to deal with loop sink # remove later to deal with loop sink
need_wrap = False need_wrap = False
if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \ if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
and not context.get_context("enable_ge"): and not context.get_context("enable_ge"):
need_wrap = True need_wrap = True
...@@ -420,9 +420,6 @@ class Model: ...@@ -420,9 +420,6 @@ class Model:
_device_number_check(self._parallel_mode, self._device_number) _device_number_check(self._parallel_mode, self._device_number)
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast) _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
if context.get_context("device_target") in ["CPU", "GPU"] and context.get_context("enable_loop_sink"):
raise ValueError("CPU and GPU can't support loop sink, please set enable_loop_sink=False.")
self._train(epoch, self._train(epoch,
train_dataset, train_dataset,
callbacks=callbacks, callbacks=callbacks,
...@@ -446,7 +443,7 @@ class Model: ...@@ -446,7 +443,7 @@ class Model:
# remove later to deal with loop sink # remove later to deal with loop sink
need_wrap = False need_wrap = False
if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \ if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
and not context.get_context("enable_ge"): and not context.get_context("enable_ge"):
need_wrap = True need_wrap = True
......
...@@ -34,7 +34,6 @@ def setup_module(): ...@@ -34,7 +34,6 @@ def setup_module():
np.random.seed(0) np.random.seed(0)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=False)
distributedTool.init() distributedTool.init()
device_num = distributedTool.get_group_size() device_num = distributedTool.get_group_size()
rank_id = distributedTool.get_rank() rank_id = distributedTool.get_rank()
......
...@@ -47,7 +47,6 @@ def setup_module(): ...@@ -47,7 +47,6 @@ def setup_module():
np.random.seed(0) np.random.seed(0)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=False)
distributedTool.init() distributedTool.init()
rank_id = distributedTool.get_rank() rank_id = distributedTool.get_rank()
device_num = distributedTool.get_group_size() device_num = distributedTool.get_group_size()
......
...@@ -32,7 +32,6 @@ from mindspore.parallel import set_algo_parameters ...@@ -32,7 +32,6 @@ from mindspore.parallel import set_algo_parameters
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=int(os.getenv('DEVICE_ID'))) context.set_context(device_id=int(os.getenv('DEVICE_ID')))
context.set_context(enable_loop_sink=False)
init() init()
context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL) context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL)
......
...@@ -54,8 +54,6 @@ data_home = args_opt.dataset_path ...@@ -54,8 +54,6 @@ data_home = args_opt.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
def create_dataset(repeat_num=1, training=True): def create_dataset(repeat_num=1, training=True):
......
...@@ -54,8 +54,6 @@ data_home = args_opt.dataset_path ...@@ -54,8 +54,6 @@ data_home = args_opt.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=False)
def create_dataset(repeat_num=1, training=True): def create_dataset(repeat_num=1, training=True):
......
...@@ -127,8 +127,6 @@ class ModelCallback(Callback): ...@@ -127,8 +127,6 @@ class ModelCallback(Callback):
def test_bert_tdt(): def test_bert_tdt():
"""test bert tdt""" """test bert tdt"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
ds = me_de_train_dataset() ds = me_de_train_dataset()
version = os.getenv('VERSION', 'large') version = os.getenv('VERSION', 'large')
batch_size = int(os.getenv('BATCH_SIZE', '16')) batch_size = int(os.getenv('BATCH_SIZE', '16'))
......
...@@ -141,7 +141,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1, ...@@ -141,7 +141,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
@pytest.mark.platform_x86_gpu_training @pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard @pytest.mark.env_onecard
def test_train_and_eval_lenet(): def test_train_and_eval_lenet():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_mem_reuse=False) context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
network = LeNet5(10) network = LeNet5(10)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
......
...@@ -20,7 +20,7 @@ from mindspore import Tensor ...@@ -20,7 +20,7 @@ from mindspore import Tensor
from mindspore.train.serialization import save, load, _check_filedir_or_create, _chg_model_file_name_if_same_exist, \ from mindspore.train.serialization import save, load, _check_filedir_or_create, _chg_model_file_name_if_same_exist, \
_read_file_last_line, context, export _read_file_last_line, context, export
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", enable_loop_sink=True) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def test_resnet50_export(batch_size=1, num_classes=5): def test_resnet50_export(batch_size=1, num_classes=5):
......
...@@ -55,8 +55,6 @@ data_home = args_opt.dataset_path ...@@ -55,8 +55,6 @@ data_home = args_opt.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
def create_dataset(repeat_num=1, training=True): def create_dataset(repeat_num=1, training=True):
......
...@@ -138,8 +138,6 @@ def train_process(device_id, epoch_size, num_classes, device_num, batch_size): ...@@ -138,8 +138,6 @@ def train_process(device_id, epoch_size, num_classes, device_num, batch_size):
os.chdir(str(device_id)) os.chdir(str(device_id))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
context.set_context(mode=context.GRAPH_MODE) context.set_context(mode=context.GRAPH_MODE)
net = resnet50(batch_size, num_classes) net = resnet50(batch_size, num_classes)
loss = CrossEntropyLoss() loss = CrossEntropyLoss()
...@@ -160,8 +158,6 @@ def train_process(device_id, epoch_size, num_classes, device_num, batch_size): ...@@ -160,8 +158,6 @@ def train_process(device_id, epoch_size, num_classes, device_num, batch_size):
def eval(batch_size, num_classes): def eval(batch_size, num_classes):
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0) context.set_context(device_id=0)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
net = resnet50(batch_size, num_classes) net = resnet50(batch_size, num_classes)
loss = CrossEntropyLoss() loss = CrossEntropyLoss()
......
...@@ -148,8 +148,6 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size, ...@@ -148,8 +148,6 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size,
context.set_context(mode=context.GRAPH_MODE, context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", save_graphs=False) device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id) context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
os.environ['RANK_ID'] = str(device_id) os.environ['RANK_ID'] = str(device_id)
os.environ['RANK_SIZE'] = str(device_num) os.environ['RANK_SIZE'] = str(device_num)
......
...@@ -32,7 +32,6 @@ from mindspore.parallel import _cost_model_context as cost_model_context ...@@ -32,7 +32,6 @@ from mindspore.parallel import _cost_model_context as cost_model_context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0) context.set_context(device_id=0)
context.set_context(enable_loop_sink=False)
init() init()
......
...@@ -102,6 +102,21 @@ def test_profiling_options(): ...@@ -102,6 +102,21 @@ def test_profiling_options():
assert context.get_context("profiling_options") == "training_trace:task_trace" assert context.get_context("profiling_options") == "training_trace:task_trace"
def test_variable_memory_max_size():
"""test_variable_memory_max_size"""
with pytest.raises(TypeError):
context.set_context(variable_memory_max_size=True)
with pytest.raises(TypeError):
context.set_context(variable_memory_max_size=1)
with pytest.raises(ValueError):
context.set_context(variable_memory_max_size="")
with pytest.raises(ValueError):
context.set_context(variable_memory_max_size="1G")
with pytest.raises(ValueError):
context.set_context(variable_memory_max_size="31GB")
context.set_context(variable_memory_max_size="3GB")
def test_set_context(): def test_set_context():
""" test_set_context """ """ test_set_context """
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册