Commit b0f89685 authored by: F fary86

Add device specific config check

Parent 5a63dac0
@@ -120,7 +120,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
engine->IncreaseFunctionCallDepth();
if (engine->function_call_depth() > MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)) {
MS_LOG(EXCEPTION) << "Exceed function call depth limit "
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH) << ".";
<< MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
<< ", please call 'context.set_context(max_call_depth=value)' to adjust this value.";
}
std::vector<AnfNodePtr> nodes = FastShadowSort(func_node);
for (auto it = nodes.crbegin(); it != nodes.crend(); it++) {
......
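The reworded exception above now points users at the Python context API. A minimal usage sketch of the call it recommends, assuming the standard mindspore.context module; the depth value is purely illustrative:

from mindspore import context

# Raise the graph-compilation call-depth limit named in the error message.
# 2000 is an arbitrary illustrative value, not a recommended setting.
context.set_context(max_call_depth=2000)
print(context.get_context("max_call_depth"))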
@@ -71,40 +71,29 @@ py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam p
}
} // namespace
// Note: exported python enum variables beginning with '_' are for internal use
REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
(void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
.value("enable_auto_mixed_precision", MsCtxParam::MS_CTX_ENABLE_AUTO_MIXED_PRECISION)
.value("check_bprop", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
.value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
.value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
.value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
.value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
.value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
.value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
.value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
.value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
.value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
.value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
.value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
.value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
.value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
.value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
.value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
.value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
.value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
.value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
.value("_graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
.value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
.value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
.value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
.value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
.value("ge_ref", MsCtxParam::MS_CTX_GE_REF)
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
.value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
......
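The block above exports MsCtxParam to Python as the ms_ctx_param enum, which the context.py changes later in this diff consume. A hedged sketch of that plumbing; the import path mindspore._c_expression and the set_param/get_param bindings on MSContext are assumptions based on the surrounding code (only get_instance is visible in this excerpt):

# Assumed import path for the pybind11 bindings in this codebase.
from mindspore._c_expression import MSContext, ms_ctx_param

ctx = MSContext.get_instance()
# Route a context value through the exported enum, mirroring the
# ctx.set_param(ms_ctx_param.__members__[key], value) call in context.py.
ctx.set_param(ms_ctx_param.max_call_depth, 2000)   # set_param binding assumed
print(ctx.get_param(ms_ctx_param.max_call_depth))  # get_param binding assumed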
@@ -219,6 +219,7 @@ class _Context:
self.set_param(ms_ctx_param.profiling_options, option)
def set_variable_memory_max_size(self, variable_memory_max_size):
"""set values of variable_memory_max_size and graph_memory_max_size"""
if not _check_input_format(variable_memory_max_size):
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
@@ -227,7 +228,8 @@ class _Context:
graph_memory_max_size = _DEVICE_APP_MEMORY_SIZE - int(variable_memory_max_size[:-2])
graph_memory_max_size_ = str(graph_memory_max_size) + " * 1024 * 1024 * 1024"
self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
# pylint: disable=protected-access
self.set_param(ms_ctx_param._graph_memory_max_size, graph_memory_max_size_)
def set_max_device_memory(self, max_device_memory):
if not _check_input_format(max_device_memory):
@@ -425,6 +427,26 @@ def reset_auto_parallel_context():
_reset_auto_parallel_context()
def _check_target_specific_cfgs(device, arg_key):
"""Checking whether a config is sutable for a specified device"""
device_cfgs = {
'enable_auto_mixed_precision': ['Ascend'],
'enable_dump': ['Ascend'],
'enable_profiling': ['Ascend'],
'variable_memory_max_size': ['Ascend'],
'max_device_memory': ['GPU']
}
# configs not listed in device_cfgs are assumed to be suitable for all devices
if arg_key not in device_cfgs:
return True
supported_devices = device_cfgs[arg_key]
if device in supported_devices:
return True
logger.warning(f"Config '{arg_key}' only supports devices in {supported_devices}, current device is '{device}'"
", ignore it.")
return False
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
save_graphs_path=str, enable_dump=bool,
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
@@ -450,6 +472,26 @@ def set_context(**kwargs):
Changing the mode after the net has been initialized is not recommended, because the implementations of some
operations differ between graph mode and pynative mode. Default: PYNATIVE_MODE.
Some configurations are device specific; see the table below for details:
=========================== =========================== =================
Common(CPU/GPU/Ascend)      Ascend                      GPU
=========================== =========================== =================
check_bprop enable_auto_mixed_precision max_device_memory
device_id enable_dump
device_target enable_profiling
enable_graph_kernel variable_memory_max_size
enable_reduce_precision
enable_sparse
mode
print_file_path
profiling_options
reserve_class_name_in_scope
save_dump_path
save_graphs
save_graphs_path
=========================== =========================== =================
Args:
mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).
device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend".
@@ -513,14 +555,21 @@ def set_context(**kwargs):
>>> context.set_context(max_call_depth=80)
"""
ctx = _context()
# set device target first
if 'device_target' in kwargs:
ctx.set_device_target(kwargs['device_target'])
device = ctx.get_param(ms_ctx_param.device_target)
for key, value in kwargs.items():
if not _check_target_specific_cfgs(device, key):
continue
if hasattr(ctx, key):
setattr(ctx, key, value)
continue
if key in ctx.setters:
ctx.setters[key](ctx, value)
continue
if key in ms_ctx_param.__members__:
# enum variables beginning with '_' are for internal use
if key in ms_ctx_param.__members__ and key[0] != '_':
ctx.set_param(ms_ctx_param.__members__[key], value)
continue
raise ValueError("Set context keyword %s is not recognized!" % key)
@@ -540,9 +589,12 @@ def get_context(attr_key):
ValueError: If input key is not an attribute in context.
"""
ctx = _context()
device = ctx.get_param(ms_ctx_param.device_target)
_ = _check_target_specific_cfgs(device, attr_key)
if hasattr(ctx, attr_key):
return getattr(ctx, attr_key)
if attr_key in ms_ctx_param.__members__:
# enum variables beginning with '_' are for internal use
if attr_key in ms_ctx_param.__members__ and attr_key[0] != '_':
return ctx.get_param(ms_ctx_param.__members__[attr_key])
raise ValueError("Get context keyword %s is not recognized!" % attr_key)
......
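On the read side, get_context now runs the same device check but only for its warning; the result is deliberately discarded, so a device-mismatched key still resolves. A small sketch, assuming an Ascend target is currently configured:

from mindspore import context

# Querying the GPU-only key on an Ascend target logs the device-mismatch
# warning, but the stored value is still returned.
value = context.get_context("max_device_memory")
print(value)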