提交 25a528ae 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5736 Add device specific config key checking

Merge pull request !5736 from fary86/add_device_specific_config_check
...@@ -120,7 +120,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr ...@@ -120,7 +120,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
engine->IncreaseFunctionCallDepth(); engine->IncreaseFunctionCallDepth();
if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) { if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit " MS_LOG(EXCEPTION) << "Exceed function call depth limit "
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) << "."; << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
<< ", please call 'context.set_context(max_call_depth=value)' to adjust this value.";
} }
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node); std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) { for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
......
...@@ -71,40 +71,29 @@ py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam p ...@@ -71,40 +71,29 @@ py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam p
} }
} // namespace } // namespace
// Note: exported python enum variables begining with '_' are for internal use
REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) { REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
(void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic()) (void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
.value("enable_auto_mixed_precision", MsCtxParam::MS_CTX_ENABLE_AUTO_MIXED_PRECISION) .value("enable_auto_mixed_precision", MsCtxParam::MS_CTX_ENABLE_AUTO_MIXED_PRECISION)
.value("check_bprop", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG) .value("check_bprop", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
.value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP) .value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
.value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
.value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL) .value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
.value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
.value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
.value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
.value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION) .value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
.value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE) .value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
.value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
.value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
.value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
.value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
.value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY) .value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
.value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING) .value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
.value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG) .value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
.value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY) .value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
.value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE) .value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET) .value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
.value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE) .value("_graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
.value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH) .value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
.value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS) .value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
.value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH) .value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
.value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH) .value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE) .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID) .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
.value("ge_ref", MsCtxParam::MS_CTX_GE_REF) .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH);
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
.value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext") (void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
......
...@@ -221,6 +221,7 @@ class _Context: ...@@ -221,6 +221,7 @@ class _Context:
self.set_param(ms_ctx_param.profiling_options, option) self.set_param(ms_ctx_param.profiling_options, option)
def set_variable_memory_max_size(self, variable_memory_max_size): def set_variable_memory_max_size(self, variable_memory_max_size):
"""set values of variable_memory_max_size and graph_memory_max_size"""
if not _check_input_format(variable_memory_max_size): if not _check_input_format(variable_memory_max_size):
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
...@@ -229,7 +230,8 @@ class _Context: ...@@ -229,7 +230,8 @@ class _Context:
graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2]) graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024" graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_) self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_) # pylint: disable=protected-access
self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)
def set_max_device_memory(self, max_device_memory): def set_max_device_memory(self, max_device_memory):
if not _check_input_format(max_device_memory): if not _check_input_format(max_device_memory):
...@@ -427,6 +429,26 @@ def reset_auto_parallel_context(): ...@@ -427,6 +429,26 @@ def reset_auto_parallel_context():
_reset_auto_parallel_context() _reset_auto_parallel_context()
def _check_target_specific_cfgs(device, arg_key):
"""Checking whether a config is sutable for a specified device"""
device_cfgs = {
'enable_auto_mixed_precision': ['Ascend'],
'enable_dump': ['Ascend'],
'enable_profiling': ['Ascend'],
'variable_memory_max_size': ['Ascend'],
'max_device_memory': ['GPU']
}
# configs not in map device_cfgs are supposed to be suitable for all devices
if not arg_key in device_cfgs:
return True
supported_devices = device_cfgs[arg_key]
if device in supported_devices:
return True
logger.warning(f"Config '{arg_key}' only supports devices in {supported_devices}, current device is '{device}'"
", ignore it.")
return False
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool, @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
save_graphs_path=str, enable_dump=bool, save_graphs_path=str, enable_dump=bool,
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
...@@ -452,6 +474,26 @@ def set_context(**kwargs): ...@@ -452,6 +474,26 @@ def set_context(**kwargs):
The mode is not recommended to be changed after net was initilized because the implementations of some The mode is not recommended to be changed after net was initilized because the implementations of some
operations are different in graph mode and pynative mode. Default: PYNATIVE_MODE. operations are different in graph mode and pynative mode. Default: PYNATIVE_MODE.
Some configurations are device specific, see the bellow table for details:
=========================== =========================== =================
Common(CPU/GPU/Asecend) Ascend GPU
=========================== =========================== =================
check_bprop enable_auto_mixed_precision max_device_memory
device_id enable_dump
device_target enable_profiling
enable_graph_kernel variable_memory_max_size
enable_reduce_precision
enable_sparse
mode
print_file_path
profiling_options
reserve_class_name_in_scope
save_dump_path
save_graphs
save_graphs_path
=========================== =========================== =================
Args: Args:
mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1). mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).
device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend". device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend".
...@@ -515,14 +557,21 @@ def set_context(**kwargs): ...@@ -515,14 +557,21 @@ def set_context(**kwargs):
>>> context.set_context(max_call_depth=80) >>> context.set_context(max_call_depth=80)
""" """
ctx = _context() ctx = _context()
# set device target first
if 'device_target' in kwargs:
ctx.set_device_target(kwargs['device_target'])
device = ctx.get_param(ms_ctx_param.device_target)
for key, value in kwargs.items(): for key, value in kwargs.items():
if not _check_target_specific_cfgs(device, key):
continue
if hasattr(ctx, key): if hasattr(ctx, key):
setattr(ctx, key, value) setattr(ctx, key, value)
continue continue
if key in ctx.setters: if key in ctx.setters:
ctx.setters[key](ctx, value) ctx.setters[key](ctx, value)
continue continue
if key in ms_ctx_param.__members__: # enum variables begining with '_' are for internal use
if key in ms_ctx_param.__members__ and key[0] != '_':
ctx.set_param(ms_ctx_param.__members__[key], value) ctx.set_param(ms_ctx_param.__members__[key], value)
continue continue
raise ValueError("Set context keyword %s is not recognized!" % key) raise ValueError("Set context keyword %s is not recognized!" % key)
...@@ -542,9 +591,12 @@ def get_context(attr_key): ...@@ -542,9 +591,12 @@ def get_context(attr_key):
ValueError: If input key is not an attribute in context. ValueError: If input key is not an attribute in context.
""" """
ctx = _context() ctx = _context()
device = ctx.get_param(ms_ctx_param.device_target)
_ = _check_target_specific_cfgs(device, attr_key)
if hasattr(ctx, attr_key): if hasattr(ctx, attr_key):
return getattr(ctx, attr_key) return getattr(ctx, attr_key)
if attr_key in ms_ctx_param.__members__: # enum variables begining with '_' are for internal use
if attr_key in ms_ctx_param.__members__ and attr_key[0] != '_':
return ctx.get_param(ms_ctx_param.__members__[attr_key]) return ctx.get_param(ms_ctx_param.__members__[attr_key])
raise ValueError("Get context keyword %s is not recognized!" % attr_key) raise ValueError("Get context keyword %s is not recognized!" % attr_key)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册