提交 2907cf44 编写于 作者: J jinyaohui

remove some context param

上级 dbac31e7
......@@ -39,7 +39,7 @@ if __name__ == "__main__":
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
......
......@@ -39,7 +39,7 @@ if __name__ == "__main__":
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
......
......@@ -46,8 +46,7 @@ This example implements pre-training, fine-tuning and evaluation of [BERT-base](
### Pre-Training
```
usage: run_pretrain.py [--distribute DISTRIBUTE] [--epoch_size N] [----device_num N] [--device_id N]
[--enable_task_sink ENABLE_TASK_SINK] [--enable_loop_sink ENABLE_LOOP_SINK]
[--enable_mem_reuse ENABLE_MEM_REUSE] [--enable_save_ckpt ENABLE_SAVE_CKPT]
[--enable_save_ckpt ENABLE_SAVE_CKPT]
[--enable_lossscale ENABLE_LOSSSCALE] [--do_shuffle DO_SHUFFLE]
[--enable_data_sink ENABLE_DATA_SINK] [--data_sink_steps N] [--checkpoint_path CHECKPOINT_PATH]
[--save_checkpoint_steps N] [--save_checkpoint_num N]
......@@ -58,8 +57,6 @@ options:
--epoch_size epoch size: N, default is 1
--device_num number of used devices: N, default is 1
--device_id device id: N, default is 0
--enable_loop_sink enable loop sink: "true" | "false", default is "true"
--enable_mem_reuse enable memory reuse: "true" | "false", default is "true"
--enable_save_ckpt enable save checkpoint: "true" | "false", default is "true"
--enable_lossscale enable lossscale: "true" | "false", default is "true"
--do_shuffle enable shuffle: "true" | "false", default is "true"
......
......@@ -83,8 +83,7 @@ def test_train():
pytest -s finetune.py::test_train
'''
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid,
enable_mem_reuse=True, enable_task_sink=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
#BertCLSTrain for classification
#BertNERTrain for sequence labeling
if cfg.task == 'NER':
......
......@@ -50,8 +50,6 @@ do
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--device_num=$RANK_SIZE \
--enable_loop_sink="true" \
--enable_mem_reuse="true" \
--enable_save_ckpt="true" \
--enable_lossscale="true" \
--do_shuffle="true" \
......
......@@ -59,8 +59,6 @@ def run_pretrain():
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
parser.add_argument("--enable_loop_sink", type=str, default="true", help="Enable loop sink, default is true.")
parser.add_argument("--enable_mem_reuse", type=str, default="true", help="Enable mem reuse, default is true.")
parser.add_argument("--enable_save_ckpt", type=str, default="true", help="Enable save checkpoint, default is true.")
parser.add_argument("--enable_lossscale", type=str, default="true", help="Use lossscale or not, default is not.")
parser.add_argument("--do_shuffle", type=str, default="true", help="Enable shuffle for dataset, default is true.")
......@@ -75,8 +73,6 @@ def run_pretrain():
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_loop_sink=(args_opt.enable_loop_sink == "true"),
enable_mem_reuse=(args_opt.enable_mem_reuse == "true"))
context.set_context(reserve_class_name_in_scope=False)
if args_opt.distribute == "true":
......
......@@ -29,8 +29,6 @@ python run_pretrain.py \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--device_id=$DEVICE_ID \
--enable_loop_sink="true" \
--enable_mem_reuse="true" \
--enable_save_ckpt="true" \
--enable_lossscale="true" \
--do_shuffle="true" \
......
......@@ -40,7 +40,6 @@ if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id)
context.set_context(enable_mem_reuse=True)
net = GooGLeNet(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
......
......@@ -70,8 +70,6 @@ if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
device_num = int(os.environ.get("DEVICE_NUM", 1))
if device_num > 1:
......
......@@ -43,7 +43,7 @@ if __name__ == "__main__":
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
......
......@@ -40,7 +40,7 @@ if __name__ == "__main__":
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
......
......@@ -34,8 +34,6 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__':
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction='mean')
......
......@@ -56,8 +56,6 @@ rank_size = int(os.getenv('RANK_SIZE'))
run_distribute = rank_size > 1
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=device_id, save_graphs=False)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
class CrossEntropyWithLabelSmooth(_Loss):
"""
......
......@@ -46,8 +46,6 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute:
......
......@@ -49,8 +49,6 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute:
......
......@@ -40,8 +40,6 @@ device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute:
......
......@@ -43,8 +43,6 @@ device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute:
......
......@@ -37,9 +37,7 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
context.set_context(device_id=device_id)
if __name__ == '__main__':
......
......@@ -44,9 +44,7 @@ args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
context.set_context(enable_task_sink=True, device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
context.set_context(device_id=device_id)
if __name__ == '__main__':
if not args_opt.do_eval and args_opt.run_distribute:
......
......@@ -71,7 +71,6 @@ if __name__ == '__main__':
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
config = ConfigSSD()
prefix = "ssd_eval.mindrecord"
......
......@@ -93,7 +93,6 @@ def main():
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
if args_opt.distribute:
device_num = args_opt.device_num
......
......@@ -37,7 +37,6 @@ if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id)
context.set_context(enable_mem_reuse=True)
net = vgg16(num_classes=cfg.num_classes)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
......
......@@ -64,8 +64,6 @@ if __name__ == '__main__':
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
context.set_context(device_id=args_opt.device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
device_num = int(os.environ.get("DEVICE_NUM", 1))
if device_num > 1:
......
......@@ -82,7 +82,6 @@ if __name__ == '__main__':
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
# It will generate mindrecord file in args_opt.mindrecord_dir,
# and the file name is yolo.mindrecord0, 1, ... file_num.
......
......@@ -84,7 +84,6 @@ def main():
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(enable_loop_sink=True, enable_mem_reuse=True)
if args_opt.distribute:
device_num = args_opt.device_num
context.reset_auto_parallel_context()
......
......@@ -107,6 +107,10 @@ bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
}
dump_enable_ = enable;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
// dump_enable_ is true, close mem reuse
context_ptr->set_enable_mem_reuse(!dump_enable_);
trans_flag_ = trans_flag;
dump_mode_ = mode;
dump_path_ = path;
......
......@@ -117,20 +117,12 @@ PYBIND11_MODULE(_c_expression, m) {
.def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.")
.def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
.def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
.def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
"Get whether to enable auto mixed precision.")
.def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag,
"Set whether to enable auto mixed precision.")
.def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision,
"Get whether to enable reduce precision.")
.def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision,
"Set whether to enable reduce precision.")
.def("get_save_graphs_path", &mindspore::MsContext::save_graphs_path, "Get save graphs path.")
.def("set_save_graphs_path", &mindspore::MsContext::set_save_graphs_path, "Set save graphs path.")
.def("get_loop_sink_flag", &mindspore::MsContext::loop_sink_flag, "Get whether to enable loop sink.")
.def("set_loop_sink_flag", &mindspore::MsContext::set_loop_sink_flag, "Set whether to enable loop sink.")
.def("get_enable_mem_reuse", &mindspore::MsContext::enable_mem_reuse, "Get whether to enable mem reuse.")
.def("set_enable_mem_reuse", &mindspore::MsContext::set_enable_mem_reuse, "Set whether to enable mem reuse.")
.def("get_save_ms_model_flag", &mindspore::MsContext::save_ms_model_flag, "Get whether to save ms model.")
.def("set_save_ms_model_flag", &mindspore::MsContext::set_save_ms_model_flag, "Set whether to save ms model.")
.def("get_save_ms_model_path", &mindspore::MsContext::save_ms_model_path, "Get path to save ms model.")
......
......@@ -91,7 +91,6 @@ class MsContext {
bool ir_fusion_flag() const { return ir_fusion_flag_; }
void set_loop_sink_flag(bool loop_sink_flag) { enable_loop_sink_ = loop_sink_flag; }
bool loop_sink_flag() const { return enable_loop_sink_; }
void set_enable_mem_reuse(bool enable_mem_reuse) { enable_mem_reuse_ = enable_mem_reuse; }
......@@ -106,11 +105,6 @@ class MsContext {
void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; }
bool enable_gpu_summary() const { return enable_gpu_summary_; }
void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) {
auto_mixed_precision_flag_ = auto_mixed_precision_flag;
}
bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; }
void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; }
bool enable_reduce_precision() const { return enable_reduce_precision_; }
......
......@@ -31,6 +31,8 @@ __all__ = ['GRAPH_MODE', 'PYNATIVE_MODE', 'set_context', 'get_context', 'set_aut
GRAPH_MODE = 0
PYNATIVE_MODE = 1
# The max memory size of graph plus variable.
_DEVICE_APP_MEMORY_SIZE = 31
def _make_directory(path):
......@@ -215,22 +217,6 @@ class _Context:
if not success:
raise RuntimeError("Device id set failed!!!")
@property
def enable_loop_sink(self):
return self._context_handle.get_loop_sink_flag()
@enable_loop_sink.setter
def enable_loop_sink(self, enable_loop_sink):
self._context_handle.set_loop_sink_flag(enable_loop_sink)
@property
def enable_mem_reuse(self):
return self._context_handle.get_enable_mem_reuse()
@enable_mem_reuse.setter
def enable_mem_reuse(self, enable_mem_reuse):
self._context_handle.set_enable_mem_reuse(enable_mem_reuse)
@property
def save_ms_model(self):
return self._context_handle.get_save_ms_model_flag()
......@@ -247,14 +233,6 @@ class _Context:
def save_ms_model_path(self, save_ms_model_path):
self._context_handle.set_save_ms_model_path(save_ms_model_path)
@property
def enable_auto_mixed_precision(self):
return self._context_handle.get_auto_mixed_precision_flag()
@enable_auto_mixed_precision.setter
def enable_auto_mixed_precision(self, enable_auto_mixed_precision):
self._context_handle.set_auto_mixed_precision_flag(enable_auto_mixed_precision)
@property
def enable_reduce_precision(self):
return self._context_handle.get_enable_reduce_precision_flag()
......@@ -309,29 +287,21 @@ class _Context:
"""Sets whether to save the network class name in the scope."""
self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
@property
def graph_memory_max_size(self):
return None
@graph_memory_max_size.setter
def graph_memory_max_size(self, graph_memory_max_size):
if check_input_format(graph_memory_max_size):
graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
else:
raise ValueError("Context param graph_memory_max_size should be in correct format! Such as \"26GB\"")
@property
def variable_memory_max_size(self):
return None
@variable_memory_max_size.setter
def variable_memory_max_size(self, variable_memory_max_size):
if check_input_format(variable_memory_max_size):
variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
self._context_handle.set_variable_memory_max_size(variable_memory_max_size_)
else:
if not check_input_format(variable_memory_max_size):
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
raise ValueError("Context param variable_memory_max_size should be less than 31GB.")
variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024"
graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
self._context_handle.set_variable_memory_max_size(variable_memory_max_size_)
self._context_handle.set_graph_memory_max_size(graph_memory_max_size_)
@property
def enable_ge(self):
......@@ -469,10 +439,9 @@ def reset_auto_parallel_context():
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
save_graphs_path=str, enable_loop_sink=bool, enable_mem_reuse=bool, save_ms_model=bool,
save_ms_model_path=str, enable_auto_mixed_precision=bool, enable_dump=bool, save_dump_path=str,
enable_reduce_precision=bool, graph_memory_max_size=str,
variable_memory_max_size=str, enable_profiling=bool, profiling_options=str)
save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_profiling=bool, profiling_options=str)
def set_context(**kwargs):
"""
Sets context for running environment.
......@@ -490,8 +459,6 @@ def set_context(**kwargs):
Note:
Attribute name is required for setting attributes.
If need to config graph max memory size and variable max memory size, one must make sure:
The sum of graph_memory_max_size and variable_memory_max_size should be less than total memory size of
a device, while the total memory is supposed to be no more than 256GB.
Args:
mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). Default: PYNATIVE_MODE.
......@@ -499,19 +466,15 @@ def set_context(**kwargs):
device_id (int): Id of target device, the value must be in [0, device_num_per_host-1],
while device_num_per_host should no more than 4096. Default: 0.
save_graphs (bool): Whether to save graphs. Default: False.
enable_loop_sink (bool): Whether to enable loop sink. Default: True.
enable_mem_reuse (bool): Whether to enable memory reuse. Default: True.
save_ms_model (bool): Whether to save lite model converted by graph. Default: False.
save_ms_model_path (str): Path to save converted lite model. Default: "."
save_graphs_path (str): Path to save graphs. Default: "."
enable_auto_mixed_precision (bool): Whether to enable auto mixed precision. Default: True.
reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True.
enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
enable_dump (bool): Whether to enable dump. Default: False.
save_dump_path (str): When the program is executed on Ascend, operators can dump data here.
The root dump path is configured in /home/HwHiAiUser/ide_daemon/ide_daemon.cfg.
So the real dump path is "{configured root dump path}/{`save_dump_path`}". Default: ".".
graph_memory_max_size (str): Sets graph memory max size. Default: "26GB".
variable_memory_max_size (str): Sets variable memory max size. Default: "5GB".
enable_profiling (bool): Whether to open profiling. Default: False.
profiling_options (str): Sets profiling collection options, operators can profiling data here.
......@@ -538,12 +501,10 @@ def set_context(**kwargs):
>>> context.set_context(device_target="Ascend")
>>> context.set_context(device_id=0)
>>> context.set_context(save_graphs=True, save_graphs_path="./model.ms")
>>> context.set_context(enable_mem_reuse=True)
>>> context.set_context(enable_reduce_precision=True)
>>> context.set_context(save_ms_model=True, save_ms_model_path=".")
>>> context.set_context(enable_dump=True, save_dump_path=".")
>>> context.set_context(reserve_class_name_in_scope=True)
>>> context.set_context(graph_memory_max_size="25GB")
>>> context.set_context(variable_memory_max_size="6GB")
>>> context.set_context(mode=context.GRAPH_MODE,
>>> device_target="Ascend",device_id=0, save_graphs=True,
......
......@@ -44,15 +44,18 @@ class DatasetHelper:
def __init__(self, dataset, dataset_sink_mode=True):
check_bool(dataset_sink_mode)
iterclass = _DatasetIterGE
if not dataset_sink_mode:
iterclass = _DatasetIterFeed
elif not context.get_context("enable_ge"):
if context.get_context("enable_loop_sink"):
iterclass = _DatasetIterMSLoopSink
if dataset_sink_mode:
if context.get_context("enable_ge"):
iterclass = _DatasetIterGE
else:
iterclass = _DatasetIterMS
if context.get_context("device_target") == "Ascend":
iterclass = _DatasetIterMSLoopSink
elif context.get_context("device_target") == "GPU":
iterclass = _DatasetIterMS
elif context.get_context("device_target") == "CPU":
raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
else:
iterclass = _DatasetIterFeed
self.iter = iterclass(dataset)
def __iter__(self):
......@@ -104,12 +107,12 @@ class _DatasetIter:
if dataset.get_dataset_size() % loop_size != 0:
raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
f'loop_size {loop_size} are not matched.')
loop_count = int(dataset.get_dataset_size()/loop_size)
loop_count = int(dataset.get_dataset_size() / loop_size)
return loop_count
class _DatasetIterMSLoopSink(_DatasetIter):
"""Iter for context (enable_loop_sink=True)"""
"""Iter for context (device_target=Ascend)"""
def __init__(self, dataset):
super(_DatasetIterMSLoopSink, self).__init__(dataset)
self.loop_count = self.get_loop_count(dataset)
......@@ -122,11 +125,12 @@ class _DatasetIterMSLoopSink(_DatasetIter):
def op():
return tuple()
self.op = op
class _DatasetIterMS(_DatasetIter):
"""Iter for context (enable_loop_sink=False)"""
"""Iter for context (device_target=GPU)"""
def __init__(self, dataset):
super(_DatasetIterMS, self).__init__(dataset)
self.loop_count = dataset.get_dataset_size()
......@@ -149,11 +153,12 @@ class _DatasetIterGE(_DatasetIter):
def op():
return tensor_list_run
self.op = op
class _DatasetIterFeed:
"""Iter for feed data"""
"""Iter for normal(non sink) mode, feed the data from host."""
def __init__(self, dataset):
self.dataset = dataset
self.device_num = _get_device_num()
......
......@@ -279,7 +279,7 @@ class Model:
"""
# remove later to deal with loop sink
need_wrap = False
if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
and not context.get_context("enable_ge"):
need_wrap = True
......@@ -420,9 +420,6 @@ class Model:
_device_number_check(self._parallel_mode, self._device_number)
_parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
if context.get_context("device_target") in ["CPU", "GPU"] and context.get_context("enable_loop_sink"):
raise ValueError("CPU and GPU can't support loop sink, please set enable_loop_sink=False.")
self._train(epoch,
train_dataset,
callbacks=callbacks,
......@@ -446,7 +443,7 @@ class Model:
# remove later to deal with loop sink
need_wrap = False
if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
and not context.get_context("enable_ge"):
need_wrap = True
......
......@@ -34,7 +34,6 @@ def setup_module():
np.random.seed(0)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=False)
distributedTool.init()
device_num = distributedTool.get_group_size()
rank_id = distributedTool.get_rank()
......
......@@ -47,7 +47,6 @@ def setup_module():
np.random.seed(0)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=False)
distributedTool.init()
rank_id = distributedTool.get_rank()
device_num = distributedTool.get_group_size()
......
......@@ -32,7 +32,6 @@ from mindspore.parallel import set_algo_parameters
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=int(os.getenv('DEVICE_ID')))
context.set_context(enable_loop_sink=False)
init()
context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL)
......
......@@ -54,8 +54,6 @@ data_home = args_opt.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
def create_dataset(repeat_num=1, training=True):
......
......@@ -54,8 +54,6 @@ data_home = args_opt.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=False)
def create_dataset(repeat_num=1, training=True):
......
......@@ -127,8 +127,6 @@ class ModelCallback(Callback):
def test_bert_tdt():
"""test bert tdt"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", reserve_class_name_in_scope=False)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
ds = me_de_train_dataset()
version = os.getenv('VERSION', 'large')
batch_size = int(os.getenv('BATCH_SIZE', '16'))
......
......@@ -141,7 +141,7 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_and_eval_lenet():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", enable_mem_reuse=False)
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
network = LeNet5(10)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9)
......
......@@ -20,7 +20,7 @@ from mindspore import Tensor
from mindspore.train.serialization import save, load, _check_filedir_or_create, _chg_model_file_name_if_same_exist, \
_read_file_last_line, context, export
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", enable_loop_sink=True)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
def test_resnet50_export(batch_size=1, num_classes=5):
......
......@@ -55,8 +55,6 @@ data_home = args_opt.dataset_path
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
def create_dataset(repeat_num=1, training=True):
......
......@@ -138,8 +138,6 @@ def train_process(device_id, epoch_size, num_classes, device_num, batch_size):
os.chdir(str(device_id))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
context.set_context(mode=context.GRAPH_MODE)
net = resnet50(batch_size, num_classes)
loss = CrossEntropyLoss()
......@@ -160,8 +158,6 @@ def train_process(device_id, epoch_size, num_classes, device_num, batch_size):
def eval(batch_size, num_classes):
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
net = resnet50(batch_size, num_classes)
loss = CrossEntropyLoss()
......
......@@ -148,8 +148,6 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size,
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend", save_graphs=False)
context.set_context(device_id=device_id)
context.set_context(enable_loop_sink=True)
context.set_context(enable_mem_reuse=True)
os.environ['MINDSPORE_HCCL_CONFIG_PATH'] = MINDSPORE_HCCL_CONFIG_PATH
os.environ['RANK_ID'] = str(device_id)
os.environ['RANK_SIZE'] = str(device_num)
......
......@@ -32,7 +32,6 @@ from mindspore.parallel import _cost_model_context as cost_model_context
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=0)
context.set_context(enable_loop_sink=False)
init()
......
......@@ -102,6 +102,21 @@ def test_profiling_options():
assert context.get_context("profiling_options") == "training_trace:task_trace"
def test_variable_memory_max_size():
"""test_variable_memory_max_size"""
with pytest.raises(TypeError):
context.set_context(variable_memory_max_size=True)
with pytest.raises(TypeError):
context.set_context(variable_memory_max_size=1)
with pytest.raises(ValueError):
context.set_context(variable_memory_max_size="")
with pytest.raises(ValueError):
context.set_context(variable_memory_max_size="1G")
with pytest.raises(ValueError):
context.set_context(variable_memory_max_size="31GB")
context.set_context(variable_memory_max_size="3GB")
def test_set_context():
""" test_set_context """
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册