提交 0a858b38 编写于 作者: F fary86

Simplify ms_context implementation

上级 d5e02cf4
......@@ -50,55 +50,6 @@ using ParallelContext = mindspore::parallel::ParallelContext;
using CostModelContext = mindspore::parallel::CostModelContext;
using mindspore::MsCtxParam;
namespace mindspore {
void MsCtxSetParameter(std::shared_ptr<MsContext> ctx, MsCtxParam param, const py::object &value) {
MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value) << "' of type '"
<< py::str(value.get_type()) << "'.";
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance<py::bool_>(value)) {
ctx->set_param<bool>(param, value.cast<bool>());
return;
}
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance<py::int_>(value)) {
ctx->set_param<int>(param, value.cast<int>());
return;
}
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance<py::int_>(value)) {
ctx->set_param<uint32_t>(param, value.cast<uint32_t>());
return;
}
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance<py::float_>(value)) {
ctx->set_param<float>(param, value.cast<float>());
return;
}
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance<py::str>(value)) {
ctx->set_param<std::string>(param, value.cast<std::string>());
return;
}
MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type " << py::str(value.get_type());
}
py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam param) {
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) {
return py::bool_(ctx->get_param<bool>(param));
}
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) {
return py::int_(ctx->get_param<int>(param));
}
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) {
return py::int_(ctx->get_param<uint32_t>(param));
}
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) {
return py::float_(ctx->get_param<float>(param));
}
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) {
return py::str(ctx->get_param<std::string>(param));
}
MS_LOG(EXCEPTION) << "Got illegal param " << param << ".";
}
} // namespace mindspore
// Interface with python
PYBIND11_MODULE(_c_expression, m) {
m.doc() = "MindSpore c plugin";
......@@ -151,49 +102,6 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
(void)m.def("ms_ctx_get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter.");
(void)m.def("ms_ctx_set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter.");
(void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
.value("auto_mixed_precision_flag", MsCtxParam::MS_CTX_AUTO_MIXED_PRECISION_FLAG)
.value("check_bprop_flag", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
.value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
.value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
.value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
.value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
.value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK)
.value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
.value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
.value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
.value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
.value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
.value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
.value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
.value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
.value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
.value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
.value("save_graphs_flag", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
.value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
.value("execution_mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
.value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
.value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
.value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
.value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
.value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
.value("ge_ref", MsCtxParam::MS_CTX_GE_REF)
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
.value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
.def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.")
.def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.");
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
.def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.")
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include "utils/ms_context.h"
#include "utils/log_adapter.h"
#include "pybind_api/api_register.h"
namespace mindspore {
namespace {
void MsCtxSetParameter(std::shared_ptr<MsContext> ctx, MsCtxParam param, const py::object &value) {
MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value).cast<std::string>() << "' of type '"
<< py::str(value.get_type()).cast<std::string>() << "'.";
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance<py::bool_>(value)) {
ctx->set_param<bool>(param, value.cast<bool>());
return;
}
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance<py::int_>(value)) {
ctx->set_param<int>(param, value.cast<int>());
return;
}
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance<py::int_>(value)) {
ctx->set_param<uint32_t>(param, value.cast<uint32_t>());
return;
}
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance<py::float_>(value)) {
ctx->set_param<float>(param, value.cast<float>());
return;
}
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance<py::str>(value)) {
ctx->set_param<std::string>(param, value.cast<std::string>());
return;
}
MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type "
<< py::str(value.get_type()).cast<std::string>();
}
py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam param) {
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) {
return py::bool_(ctx->get_param<bool>(param));
}
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) {
return py::int_(ctx->get_param<int>(param));
}
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) {
return py::int_(ctx->get_param<uint32_t>(param));
}
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) {
return py::float_(ctx->get_param<float>(param));
}
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) {
return py::str(ctx->get_param<std::string>(param));
}
MS_LOG(EXCEPTION) << "Got illegal param " << param << ".";
}
} // namespace
REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
(void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
.value("enable_auto_mixed_precision", MsCtxParam::MS_CTX_ENABLE_AUTO_MIXED_PRECISION)
.value("check_bprop", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
.value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
.value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
.value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
.value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
.value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK)
.value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
.value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
.value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
.value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
.value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
.value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
.value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
.value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
.value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
.value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
.value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
.value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
.value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
.value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
.value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
.value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
.value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
.value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
.value("ge_ref", MsCtxParam::MS_CTX_GE_REF)
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
.value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
.def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter.")
.def("set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter.")
.def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.")
.def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.");
}));
} // namespace mindspore
......@@ -225,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
}
// Enable auto mixed precision according to the context options
if (ms_context_ptr->get_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG)) {
if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_AUTO_MIXED_PRECISION)) {
(*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
} else {
(*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
......@@ -337,7 +337,7 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
if (ge::GEFinalize() != ge::GRAPH_SUCCESS) {
MS_LOG(WARNING) << "Finalize GE failed!";
}
ms_context_ptr->set_pynative_ge_init(false);
ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
} else {
MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
......
......@@ -22,7 +22,7 @@ import threading
from collections import namedtuple
from types import FunctionType
from mindspore import log as logger
from mindspore._c_expression import MSContext, ms_ctx_param, ms_ctx_get_param, ms_ctx_set_param
from mindspore._c_expression import MSContext, ms_ctx_param
from mindspore._checkparam import args_type_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context
......@@ -158,17 +158,12 @@ class _Context:
return value
def get_param(self, param):
return ms_ctx_get_param(self._context_handle, param)
return self._context_handle.get_param(param)
def set_param(self, param, value):
ms_ctx_set_param(self._context_handle, param, value)
self._context_handle.set_param(param, value)
@property
def mode(self):
return self.get_param(ms_ctx_param.execution_mode)
@mode.setter
def mode(self, mode):
def set_mode(self, mode):
"""
Switch between Graph mode and PyNative mode.
......@@ -185,43 +180,17 @@ class _Context:
self._context_switches.push(False, None)
else:
raise ValueError(f'The execution mode {mode} is invalid!')
self.set_param(ms_ctx_param.execution_mode, mode)
self.set_param(ms_ctx_param.mode, mode)
def set_backend_policy(self, policy):
success = self._context_handle.set_backend_policy(policy)
if not success:
raise RuntimeError("Backend policy must be one of ge, vm, ms.")
@property
def precompile_only(self):
return self.get_param(ms_ctx_param.precompile_only)
@precompile_only.setter
def precompile_only(self, precompile_only):
self.set_param(ms_ctx_param.precompile_only, precompile_only)
@property
def save_graphs(self):
return self.get_param(ms_ctx_param.save_graphs_flag)
@save_graphs.setter
def save_graphs(self, save_graphs_flag):
self.set_param(ms_ctx_param.save_graphs_flag, save_graphs_flag)
@property
def save_graphs_path(self):
return self.get_param(ms_ctx_param.save_graphs_path)
@save_graphs_path.setter
def save_graphs_path(self, save_graphs_path):
def set_save_graphs_path(self, save_graphs_path):
self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path))
@property
def device_target(self):
return self.get_param(ms_ctx_param.device_target)
@device_target.setter
def device_target(self, target):
def set_device_target(self, target):
valid_targets = ["CPU", "GPU", "Ascend", "Davinci"]
if not target in valid_targets:
raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}")
......@@ -231,72 +200,17 @@ class _Context:
if self.enable_debug_runtime and target == "CPU":
self.set_backend_policy("vm")
@property
def device_id(self):
return self.get_param(ms_ctx_param.device_id)
@device_id.setter
def device_id(self, device_id):
def set_device_id(self, device_id):
if device_id < 0 or device_id > 4095:
raise ValueError(f"Device id must be in [0, 4095], but got {device_id}")
self.set_param(ms_ctx_param.device_id, device_id)
@property
def max_call_depth(self):
return self.get_param(ms_ctx_param.max_call_depth)
@max_call_depth.setter
def max_call_depth(self, max_call_depth):
def set_max_call_depth(self, max_call_depth):
if max_call_depth <= 0:
raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}")
self.set_param(ms_ctx_param.max_call_depth, max_call_depth)
@property
def enable_auto_mixed_precision(self):
return self.get_param(ms_ctx_param.auto_mixed_precision_flag)
@enable_auto_mixed_precision.setter
def enable_auto_mixed_precision(self, enable_auto_mixed_precision):
self.set_param(ms_ctx_param.auto_mixed_precision_flag, enable_auto_mixed_precision)
@property
def enable_reduce_precision(self):
return self.get_param(ms_ctx_param.enable_reduce_precision_flag)
@enable_reduce_precision.setter
def enable_reduce_precision(self, enable_reduce_precision):
self.set_param(ms_ctx_param.enable_reduce_precision_flag, enable_reduce_precision)
@property
def enable_dump(self):
return self.get_param(ms_ctx_param.enable_dump)
@enable_dump.setter
def enable_dump(self, enable_dump):
self.set_param(ms_ctx_param.enable_dump, enable_dump)
@property
def save_dump_path(self):
return self.get_param(ms_ctx_param.save_dump_path)
@save_dump_path.setter
def save_dump_path(self, save_dump_path):
self.set_param(ms_ctx_param.save_dump_path, save_dump_path)
@property
def enable_profiling(self):
return self.get_param(ms_ctx_param.enable_profiling)
@enable_profiling.setter
def enable_profiling(self, flag):
self.set_param(ms_ctx_param.enable_profiling, flag)
@property
def profiling_options(self):
return self.get_param(ms_ctx_param.profiling_options)
@profiling_options.setter
def profiling_options(self, option):
def set_profiling_options(self, option):
options = ["training_trace", "task_trace",
"task_trace:training_trace", "training_trace:task_trace", "op_trace"]
if option not in options:
......@@ -304,30 +218,7 @@ class _Context:
"'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.")
self.set_param(ms_ctx_param.profiling_options, option)
@property
def enable_graph_kernel(self):
return self.get_param(ms_ctx_param.enable_graph_kernel)
@enable_graph_kernel.setter
def enable_graph_kernel(self, graph_kernel_switch_):
self.set_param(ms_ctx_param.enable_graph_kernel, graph_kernel_switch_)
@property
def reserve_class_name_in_scope(self):
"""Gets whether to save the network class name in the scope."""
return self._thread_local_info.reserve_class_name_in_scope
@reserve_class_name_in_scope.setter
def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
"""Sets whether to save the network class name in the scope."""
self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
@property
def variable_memory_max_size(self):
return None
@variable_memory_max_size.setter
def variable_memory_max_size(self, variable_memory_max_size):
def set_variable_memory_max_size(self, variable_memory_max_size):
if not check_input_format(variable_memory_max_size):
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
......@@ -338,33 +229,7 @@ class _Context:
self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
@property
def enable_ge(self):
return self._context_handle.get_backend_policy() == 'ge'
@property
def enable_debug_runtime(self):
return self._thread_local_info.debug_runtime
@enable_debug_runtime.setter
def enable_debug_runtime(self, enable):
thread_info = self._thread_local_info
thread_info.debug_runtime = enable
@property
def check_bprop(self):
return self.get_param(ms_ctx_param.check_bprop_flag)
@check_bprop.setter
def check_bprop(self, check_bprop_flag):
self.set_param(ms_ctx_param.check_bprop_flag, check_bprop_flag)
@property
def max_device_memory(self):
return self.get_param(ms_ctx_param.max_device_memory)
@max_device_memory.setter
def max_device_memory(self, max_device_memory):
def set_max_device_memory(self, max_device_memory):
if not check_input_format(max_device_memory):
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
max_device_memory_value = float(max_device_memory[:-2])
......@@ -372,12 +237,7 @@ class _Context:
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)
@property
def print_file_path(self):
return None
@print_file_path.setter
def print_file_path(self, file_path):
def set_print_file_path(self, file_path):
"""Add timestamp suffix to file name. Sets print file path."""
print_file_path = os.path.realpath(file_path)
if os.path.isdir(print_file_path):
......@@ -392,13 +252,42 @@ class _Context:
full_file_name = print_file_path
self.set_param(ms_ctx_param.print_file_path, full_file_name)
setters = {
'mode': set_mode,
'backend_policy': set_backend_policy,
'save_graphs_path': set_save_graphs_path,
'device_target': set_device_target,
'device_id': set_device_id,
'max_call_depth': set_max_call_depth,
'profiling_options': set_profiling_options,
'variable_memory_max_size': set_variable_memory_max_size,
'max_device_memory': set_max_device_memory,
'print_file_path': set_print_file_path
}
@property
def enable_sparse(self):
return self.get_param(ms_ctx_param.enable_sparse)
def reserve_class_name_in_scope(self):
"""Gets whether to save the network class name in the scope."""
return self._thread_local_info.reserve_class_name_in_scope
@reserve_class_name_in_scope.setter
def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
"""Sets whether to save the network class name in the scope."""
self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
@property
def enable_ge(self):
return self._context_handle.get_backend_policy() == 'ge'
@property
def enable_debug_runtime(self):
return self._thread_local_info.debug_runtime
@enable_debug_runtime.setter
def enable_debug_runtime(self, enable):
thread_info = self._thread_local_info
thread_info.debug_runtime = enable
@enable_sparse.setter
def enable_sparse(self, enable_sparse):
self.set_param(ms_ctx_param.enable_sparse, enable_sparse)
def check_input_format(x):
import re
......@@ -621,10 +510,18 @@ def set_context(**kwargs):
>>> context.set_context(print_file_path="print.pb")
>>> context.set_context(max_call_depth=80)
"""
ctx = _context()
for key, value in kwargs.items():
if not hasattr(_context(), key):
if hasattr(ctx, key):
setattr(ctx, key, value)
continue
if key in ctx.setters:
ctx.setters[key](ctx, value)
continue
if key in ms_ctx_param.__members__:
ctx.set_param(ms_ctx_param.__members__[key], value)
continue
raise ValueError("Set context keyword %s is not recognized!" % key)
setattr(_context(), key, value)
def get_context(attr_key):
......@@ -640,10 +537,13 @@ def get_context(attr_key):
Raises:
ValueError: If input key is not an attribute in context.
"""
if not hasattr(_context(), attr_key):
raise ValueError(
"Get context keyword %s is not recognized!" % attr_key)
return getattr(_context(), attr_key)
ctx = _context()
if hasattr(ctx, attr_key):
return getattr(ctx, attr_key)
if attr_key in ms_ctx_param.__members__:
return ctx.get_param(ms_ctx_param.__members__[attr_key])
raise ValueError("Get context keyword %s is not recognized!" % attr_key)
class ParallelMode:
"""
......
......@@ -60,7 +60,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
#endif
set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true);
set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false);
set_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG, false);
set_param<bool>(MS_CTX_ENABLE_AUTO_MIXED_PRECISION, false);
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false);
set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true);
......
......@@ -53,7 +53,7 @@ const float kDefaultMaxDeviceMemory = 1024;
enum MsCtxParam : unsigned {
// paramater of type bool
MS_CTX_TYPE_BOOL_BEGIN,
MS_CTX_AUTO_MIXED_PRECISION_FLAG = MS_CTX_TYPE_BOOL_BEGIN,
MS_CTX_ENABLE_AUTO_MIXED_PRECISION = MS_CTX_TYPE_BOOL_BEGIN,
MS_CTX_CHECK_BPROP_FLAG,
MS_CTX_ENABLE_DUMP,
MS_CTX_ENABLE_DYNAMIC_MEM_POOL,
......@@ -132,22 +132,22 @@ class MsContext {
template <typename T>
void set_param(MsCtxParam param, const T &value) {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
template <typename T>
const T &get_param(MsCtxParam param) const {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
template <typename T>
void increase_param(MsCtxParam param) {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
template <typename T>
void decrease_param(MsCtxParam param) {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册