diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc index 2584d169589b821a02005db53b7ec57ff664d951..978e4dc5d37c32d14b3aa2b65dde1cba96f35aaf 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc @@ -120,7 +120,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr engine->IncreaseFunctionCallDepth(); if (engine->function_call_depth() > MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH)) { MS_LOG(EXCEPTION) << "Exceed function call depth limit " - << MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH) << "."; + << MsContext::GetInstance()->get_param(MS_CTX_MAX_CALL_DEPTH) + << ", please call 'context.set_context(max_call_depth=value)' to adjust this value."; } std::vector nodes = FastShadowSort(func_node); for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { diff --git a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc index fcd9781faccf56a76799ceb7d182dee3e2872547..140569415ea8c4be4cd2c1adf19b258f3ad3fb2f 100644 --- a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc +++ b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc @@ -71,40 +71,29 @@ py::object MsCtxGetParameter(const std::shared_ptr &ctx, MsCtxParam p } } // namespace +// Note: exported python enum variables begining with '_' are for internal use REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) { (void)py::enum_(*m, "ms_ctx_param", py::arithmetic()) .value("enable_auto_mixed_precision", MsCtxParam::MS_CTX_ENABLE_AUTO_MIXED_PRECISION) .value("check_bprop", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG) .value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP) - .value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL) - .value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY) .value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL) - .value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL) - .value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE) - .value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK) - .value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER) .value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION) .value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE) - .value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK) - .value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG) - .value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK) - .value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT) .value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY) .value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING) .value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG) .value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY) .value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE) .value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET) - .value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE) + .value("_graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE) .value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH) .value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS) .value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH) .value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH) .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE) .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID) - .value("ge_ref", MsCtxParam::MS_CTX_GE_REF) - .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH) - .value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF); + .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH); (void)py::class_>(*m, "MSContext") .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") diff --git a/mindspore/context.py b/mindspore/context.py index 9661102d15ef06d1d952172ad71ce2c810f543d3..ebdd79f565789b16d3971e60d8b0ee0dc3e75088 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -221,6 +221,7 @@ class _Context: self.set_param(ms_ctx_param.profiling_options, option) def set_variable_memory_max_size(self, variable_memory_max_size): + """set values of variable_memory_max_size and graph_memory_max_size""" if not _check_input_format(variable_memory_max_size): raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: @@ -229,7 +230,8 @@ class _Context: graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2]) graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024" self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_) - self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_) + # pylint: disable=protected-access + self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_) def set_max_device_memory(self, max_device_memory): if not _check_input_format(max_device_memory): @@ -427,6 +429,26 @@ def reset_auto_parallel_context(): _reset_auto_parallel_context() +def _check_target_specific_cfgs(device, arg_key): + """Checking whether a config is sutable for a specified device""" + device_cfgs = { + 'enable_auto_mixed_precision': ['Ascend'], + 'enable_dump': ['Ascend'], + 'enable_profiling': ['Ascend'], + 'variable_memory_max_size': ['Ascend'], + 'max_device_memory': ['GPU'] + } + # configs not in map device_cfgs are supposed to be suitable for all devices + if not arg_key in device_cfgs: + return True + supported_devices = device_cfgs[arg_key] + if device in supported_devices: + return True + logger.warning(f"Config '{arg_key}' only supports devices in {supported_devices}, current device is '{device}'" + ", ignore it.") + return False + + @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool, save_graphs_path=str, enable_dump=bool, save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, @@ -452,6 +474,26 @@ def set_context(**kwargs): The mode is not recommended to be changed after net was initilized because the implementations of some operations are different in graph mode and pynative mode. Default: PYNATIVE_MODE. + Some configurations are device specific, see the bellow table for details: + + =========================== =========================== ================= + Common(CPU/GPU/Asecend) Ascend GPU + =========================== =========================== ================= + check_bprop enable_auto_mixed_precision max_device_memory + device_id enable_dump + device_target enable_profiling + enable_graph_kernel variable_memory_max_size + enable_reduce_precision + enable_sparse + mode + print_file_path + profiling_options + reserve_class_name_in_scope + save_dump_path + save_graphs + save_graphs_path + =========================== =========================== ================= + Args: mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend". @@ -515,14 +557,21 @@ def set_context(**kwargs): >>> context.set_context(max_call_depth=80) """ ctx = _context() + # set device target first + if 'device_target' in kwargs: + ctx.set_device_target(kwargs['device_target']) + device = ctx.get_param(ms_ctx_param.device_target) for key, value in kwargs.items(): + if not _check_target_specific_cfgs(device, key): + continue if hasattr(ctx, key): setattr(ctx, key, value) continue if key in ctx.setters: ctx.setters[key](ctx, value) continue - if key in ms_ctx_param.__members__: + # enum variables begining with '_' are for internal use + if key in ms_ctx_param.__members__ and key[0] != '_': ctx.set_param(ms_ctx_param.__members__[key], value) continue raise ValueError("Set context keyword %s is not recognized!" % key) @@ -542,9 +591,12 @@ def get_context(attr_key): ValueError: If input key is not an attribute in context. """ ctx = _context() + device = ctx.get_param(ms_ctx_param.device_target) + _ = _check_target_specific_cfgs(device, attr_key) if hasattr(ctx, attr_key): return getattr(ctx, attr_key) - if attr_key in ms_ctx_param.__members__: + # enum variables begining with '_' are for internal use + if attr_key in ms_ctx_param.__members__ and attr_key[0] != '_': return ctx.get_param(ms_ctx_param.__members__[attr_key]) raise ValueError("Get context keyword %s is not recognized!" % attr_key)