From 0a858b38780d3377a52335098f9691213fa39616 Mon Sep 17 00:00:00 2001 From: fary86 Date: Mon, 31 Aug 2020 15:07:38 +0800 Subject: [PATCH] Simplify ms_context implementation --- mindspore/ccsrc/pipeline/jit/init.cc | 92 ------- .../ccsrc/pybind_api/utils/ms_context_py.cc | 117 +++++++++ .../ccsrc/utils/context/context_extends.cc | 4 +- mindspore/context.py | 230 +++++------------- mindspore/core/utils/ms_context.cc | 2 +- mindspore/core/utils/ms_context.h | 10 +- 6 files changed, 190 insertions(+), 265 deletions(-) create mode 100644 mindspore/ccsrc/pybind_api/utils/ms_context_py.cc diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc index 2ee67882f..97a7489aa 100644 --- a/mindspore/ccsrc/pipeline/jit/init.cc +++ b/mindspore/ccsrc/pipeline/jit/init.cc @@ -50,55 +50,6 @@ using ParallelContext = mindspore::parallel::ParallelContext; using CostModelContext = mindspore::parallel::CostModelContext; using mindspore::MsCtxParam; -namespace mindspore { -void MsCtxSetParameter(std::shared_ptr ctx, MsCtxParam param, const py::object &value) { - MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value) << "' of type '" - << py::str(value.get_type()) << "'."; - if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance(value)) { - ctx->set_param(param, value.cast()); - return; - } - if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance(value)) { - ctx->set_param(param, value.cast()); - return; - } - if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance(value)) { - ctx->set_param(param, value.cast()); - return; - } - if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance(value)) { - ctx->set_param(param, value.cast()); - return; - } - if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance(value)) { - ctx->set_param(param, value.cast()); - return; - } - - MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type " << py::str(value.get_type()); -} - -py::object MsCtxGetParameter(const std::shared_ptr &ctx, MsCtxParam param) { - if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) { - return py::bool_(ctx->get_param(param)); - } - if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) { - return py::int_(ctx->get_param(param)); - } - if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) { - return py::int_(ctx->get_param(param)); - } - if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) { - return py::float_(ctx->get_param(param)); - } - if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) { - return py::str(ctx->get_param(param)); - } - - MS_LOG(EXCEPTION) << "Got illegal param " << param << "."; -} -} // namespace mindspore - // Interface with python PYBIND11_MODULE(_c_expression, m) { m.doc() = "MindSpore c plugin"; @@ -151,49 +102,6 @@ PYBIND11_MODULE(_c_expression, m) { (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); - (void)m.def("ms_ctx_get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter."); - (void)m.def("ms_ctx_set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter."); - - (void)py::enum_(*m, "ms_ctx_param", py::arithmetic()) - .value("auto_mixed_precision_flag", MsCtxParam::MS_CTX_AUTO_MIXED_PRECISION_FLAG) - .value("check_bprop_flag", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG) - .value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP) - .value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL) - .value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY) - .value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL) - .value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL) - .value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK) - .value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE) - .value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK) - .value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER) - .value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION) - .value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE) - .value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK) - .value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG) - .value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK) - .value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT) - .value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY) - .value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING) - .value("save_graphs_flag", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG) - .value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY) - .value("execution_mode", MsCtxParam::MS_CTX_EXECUTION_MODE) - .value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET) - .value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE) - .value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH) - .value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS) - .value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH) - .value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH) - .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE) - .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID) - .value("ge_ref", MsCtxParam::MS_CTX_GE_REF) - .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH) - .value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF); - - (void)py::class_>(m, "MSContext") - .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") - .def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.") - .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy."); - (void)py::class_>(m, "MpiConfig") .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.") diff --git a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc new file mode 100644 index 000000000..4931069b5 --- /dev/null +++ b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc @@ -0,0 +1,117 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "utils/ms_context.h" +#include "utils/log_adapter.h" +#include "pybind_api/api_register.h" + +namespace mindspore { +namespace { +void MsCtxSetParameter(std::shared_ptr ctx, MsCtxParam param, const py::object &value) { + MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value).cast() << "' of type '" + << py::str(value.get_type()).cast() << "'."; + if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance(value)) { + ctx->set_param(param, value.cast()); + return; + } + if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance(value)) { + ctx->set_param(param, value.cast()); + return; + } + if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance(value)) { + ctx->set_param(param, value.cast()); + return; + } + if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance(value)) { + ctx->set_param(param, value.cast()); + return; + } + if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance(value)) { + ctx->set_param(param, value.cast()); + return; + } + + MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type " + << py::str(value.get_type()).cast(); +} + +py::object MsCtxGetParameter(const std::shared_ptr &ctx, MsCtxParam param) { + if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) { + return py::bool_(ctx->get_param(param)); + } + if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) { + return py::int_(ctx->get_param(param)); + } + if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) { + return py::int_(ctx->get_param(param)); + } + if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) { + return py::float_(ctx->get_param(param)); + } + if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) { + return py::str(ctx->get_param(param)); + } + + MS_LOG(EXCEPTION) << "Got illegal param " << param << "."; +} +} // namespace + +REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) { + (void)py::enum_(*m, "ms_ctx_param", py::arithmetic()) + .value("enable_auto_mixed_precision", MsCtxParam::MS_CTX_ENABLE_AUTO_MIXED_PRECISION) + .value("check_bprop", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG) + .value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP) + .value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL) + .value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY) + .value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL) + .value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL) + .value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK) + .value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE) + .value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK) + .value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER) + .value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION) + .value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE) + .value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK) + .value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG) + .value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK) + .value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT) + .value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY) + .value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING) + .value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG) + .value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY) + .value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE) + .value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET) + .value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE) + .value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH) + .value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS) + .value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH) + .value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH) + .value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE) + .value("device_id", MsCtxParam::MS_CTX_DEVICE_ID) + .value("ge_ref", MsCtxParam::MS_CTX_GE_REF) + .value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH) + .value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF); + + (void)py::class_>(*m, "MSContext") + .def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.") + .def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter.") + .def("set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter.") + .def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.") + .def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy."); + })); +} // namespace mindspore diff --git a/mindspore/ccsrc/utils/context/context_extends.cc b/mindspore/ccsrc/utils/context/context_extends.cc index efc1d3d85..ec5fd5da5 100644 --- a/mindspore/ccsrc/utils/context/context_extends.cc +++ b/mindspore/ccsrc/utils/context/context_extends.cc @@ -225,7 +225,7 @@ void GetGeOptions(const std::shared_ptr &ms_context_ptr, std::mapget_param(MS_CTX_AUTO_MIXED_PRECISION_FLAG)) { + if (ms_context_ptr->get_param(MS_CTX_ENABLE_AUTO_MIXED_PRECISION)) { (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision"; } else { (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16"; @@ -337,7 +337,7 @@ bool FinalizeGe(const std::shared_ptr &ms_context_ptr, bool force) { if (ge::GEFinalize() != ge::GRAPH_SUCCESS) { MS_LOG(WARNING) << "Finalize GE failed!"; } - ms_context_ptr->set_pynative_ge_init(false); + ms_context_ptr->set_param(MS_CTX_IS_PYNATIVE_GE_INIT, false); } else { MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " << ms_context_ptr->get_param(MS_CTX_GE_REF) << "."; diff --git a/mindspore/context.py b/mindspore/context.py index 7788a4f36..98607b863 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -22,7 +22,7 @@ import threading from collections import namedtuple from types import FunctionType from mindspore import log as logger -from mindspore._c_expression import MSContext, ms_ctx_param, ms_ctx_get_param, ms_ctx_set_param +from mindspore._c_expression import MSContext, ms_ctx_param from mindspore._checkparam import args_type_check from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ _reset_auto_parallel_context @@ -158,17 +158,12 @@ class _Context: return value def get_param(self, param): - return ms_ctx_get_param(self._context_handle, param) + return self._context_handle.get_param(param) def set_param(self, param, value): - ms_ctx_set_param(self._context_handle, param, value) + self._context_handle.set_param(param, value) - @property - def mode(self): - return self.get_param(ms_ctx_param.execution_mode) - - @mode.setter - def mode(self, mode): + def set_mode(self, mode): """ Switch between Graph mode and PyNative mode. @@ -185,43 +180,17 @@ class _Context: self._context_switches.push(False, None) else: raise ValueError(f'The execution mode {mode} is invalid!') - self.set_param(ms_ctx_param.execution_mode, mode) + self.set_param(ms_ctx_param.mode, mode) def set_backend_policy(self, policy): success = self._context_handle.set_backend_policy(policy) if not success: raise RuntimeError("Backend policy must be one of ge, vm, ms.") - @property - def precompile_only(self): - return self.get_param(ms_ctx_param.precompile_only) - - @precompile_only.setter - def precompile_only(self, precompile_only): - self.set_param(ms_ctx_param.precompile_only, precompile_only) - - @property - def save_graphs(self): - return self.get_param(ms_ctx_param.save_graphs_flag) - - @save_graphs.setter - def save_graphs(self, save_graphs_flag): - self.set_param(ms_ctx_param.save_graphs_flag, save_graphs_flag) - - @property - def save_graphs_path(self): - return self.get_param(ms_ctx_param.save_graphs_path) - - @save_graphs_path.setter - def save_graphs_path(self, save_graphs_path): + def set_save_graphs_path(self, save_graphs_path): self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path)) - @property - def device_target(self): - return self.get_param(ms_ctx_param.device_target) - - @device_target.setter - def device_target(self, target): + def set_device_target(self, target): valid_targets = ["CPU", "GPU", "Ascend", "Davinci"] if not target in valid_targets: raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}") @@ -231,72 +200,17 @@ class _Context: if self.enable_debug_runtime and target == "CPU": self.set_backend_policy("vm") - @property - def device_id(self): - return self.get_param(ms_ctx_param.device_id) - - @device_id.setter - def device_id(self, device_id): + def set_device_id(self, device_id): if device_id < 0 or device_id > 4095: raise ValueError(f"Device id must be in [0, 4095], but got {device_id}") self.set_param(ms_ctx_param.device_id, device_id) - @property - def max_call_depth(self): - return self.get_param(ms_ctx_param.max_call_depth) - - @max_call_depth.setter - def max_call_depth(self, max_call_depth): + def set_max_call_depth(self, max_call_depth): if max_call_depth <= 0: raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}") self.set_param(ms_ctx_param.max_call_depth, max_call_depth) - @property - def enable_auto_mixed_precision(self): - return self.get_param(ms_ctx_param.auto_mixed_precision_flag) - - @enable_auto_mixed_precision.setter - def enable_auto_mixed_precision(self, enable_auto_mixed_precision): - self.set_param(ms_ctx_param.auto_mixed_precision_flag, enable_auto_mixed_precision) - - @property - def enable_reduce_precision(self): - return self.get_param(ms_ctx_param.enable_reduce_precision_flag) - - @enable_reduce_precision.setter - def enable_reduce_precision(self, enable_reduce_precision): - self.set_param(ms_ctx_param.enable_reduce_precision_flag, enable_reduce_precision) - - @property - def enable_dump(self): - return self.get_param(ms_ctx_param.enable_dump) - - @enable_dump.setter - def enable_dump(self, enable_dump): - self.set_param(ms_ctx_param.enable_dump, enable_dump) - - @property - def save_dump_path(self): - return self.get_param(ms_ctx_param.save_dump_path) - - @save_dump_path.setter - def save_dump_path(self, save_dump_path): - self.set_param(ms_ctx_param.save_dump_path, save_dump_path) - - @property - def enable_profiling(self): - return self.get_param(ms_ctx_param.enable_profiling) - - @enable_profiling.setter - def enable_profiling(self, flag): - self.set_param(ms_ctx_param.enable_profiling, flag) - - @property - def profiling_options(self): - return self.get_param(ms_ctx_param.profiling_options) - - @profiling_options.setter - def profiling_options(self, option): + def set_profiling_options(self, option): options = ["training_trace", "task_trace", "task_trace:training_trace", "training_trace:task_trace", "op_trace"] if option not in options: @@ -304,30 +218,7 @@ class _Context: "'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.") self.set_param(ms_ctx_param.profiling_options, option) - @property - def enable_graph_kernel(self): - return self.get_param(ms_ctx_param.enable_graph_kernel) - - @enable_graph_kernel.setter - def enable_graph_kernel(self, graph_kernel_switch_): - self.set_param(ms_ctx_param.enable_graph_kernel, graph_kernel_switch_) - - @property - def reserve_class_name_in_scope(self): - """Gets whether to save the network class name in the scope.""" - return self._thread_local_info.reserve_class_name_in_scope - - @reserve_class_name_in_scope.setter - def reserve_class_name_in_scope(self, reserve_class_name_in_scope): - """Sets whether to save the network class name in the scope.""" - self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope - - @property - def variable_memory_max_size(self): - return None - - @variable_memory_max_size.setter - def variable_memory_max_size(self, variable_memory_max_size): + def set_variable_memory_max_size(self, variable_memory_max_size): if not check_input_format(variable_memory_max_size): raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: @@ -338,33 +229,7 @@ class _Context: self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_) self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_) - @property - def enable_ge(self): - return self._context_handle.get_backend_policy() == 'ge' - - @property - def enable_debug_runtime(self): - return self._thread_local_info.debug_runtime - - @enable_debug_runtime.setter - def enable_debug_runtime(self, enable): - thread_info = self._thread_local_info - thread_info.debug_runtime = enable - - @property - def check_bprop(self): - return self.get_param(ms_ctx_param.check_bprop_flag) - - @check_bprop.setter - def check_bprop(self, check_bprop_flag): - self.set_param(ms_ctx_param.check_bprop_flag, check_bprop_flag) - - @property - def max_device_memory(self): - return self.get_param(ms_ctx_param.max_device_memory) - - @max_device_memory.setter - def max_device_memory(self, max_device_memory): + def set_max_device_memory(self, max_device_memory): if not check_input_format(max_device_memory): raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") max_device_memory_value = float(max_device_memory[:-2]) @@ -372,12 +237,7 @@ class _Context: raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value) - @property - def print_file_path(self): - return None - - @print_file_path.setter - def print_file_path(self, file_path): + def set_print_file_path(self, file_path): """Add timestamp suffix to file name. Sets print file path.""" print_file_path = os.path.realpath(file_path) if os.path.isdir(print_file_path): @@ -392,13 +252,42 @@ class _Context: full_file_name = print_file_path self.set_param(ms_ctx_param.print_file_path, full_file_name) + setters = { + 'mode': set_mode, + 'backend_policy': set_backend_policy, + 'save_graphs_path': set_save_graphs_path, + 'device_target': set_device_target, + 'device_id': set_device_id, + 'max_call_depth': set_max_call_depth, + 'profiling_options': set_profiling_options, + 'variable_memory_max_size': set_variable_memory_max_size, + 'max_device_memory': set_max_device_memory, + 'print_file_path': set_print_file_path + } + @property - def enable_sparse(self): - return self.get_param(ms_ctx_param.enable_sparse) + def reserve_class_name_in_scope(self): + """Gets whether to save the network class name in the scope.""" + return self._thread_local_info.reserve_class_name_in_scope + + @reserve_class_name_in_scope.setter + def reserve_class_name_in_scope(self, reserve_class_name_in_scope): + """Sets whether to save the network class name in the scope.""" + self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope + + @property + def enable_ge(self): + return self._context_handle.get_backend_policy() == 'ge' + + @property + def enable_debug_runtime(self): + return self._thread_local_info.debug_runtime + + @enable_debug_runtime.setter + def enable_debug_runtime(self, enable): + thread_info = self._thread_local_info + thread_info.debug_runtime = enable - @enable_sparse.setter - def enable_sparse(self, enable_sparse): - self.set_param(ms_ctx_param.enable_sparse, enable_sparse) def check_input_format(x): import re @@ -621,10 +510,18 @@ def set_context(**kwargs): >>> context.set_context(print_file_path="print.pb") >>> context.set_context(max_call_depth=80) """ + ctx = _context() for key, value in kwargs.items(): - if not hasattr(_context(), key): - raise ValueError("Set context keyword %s is not recognized!" % key) - setattr(_context(), key, value) + if hasattr(ctx, key): + setattr(ctx, key, value) + continue + if key in ctx.setters: + ctx.setters[key](ctx, value) + continue + if key in ms_ctx_param.__members__: + ctx.set_param(ms_ctx_param.__members__[key], value) + continue + raise ValueError("Set context keyword %s is not recognized!" % key) def get_context(attr_key): @@ -640,10 +537,13 @@ def get_context(attr_key): Raises: ValueError: If input key is not an attribute in context. """ - if not hasattr(_context(), attr_key): - raise ValueError( - "Get context keyword %s is not recognized!" % attr_key) - return getattr(_context(), attr_key) + ctx = _context() + if hasattr(ctx, attr_key): + return getattr(ctx, attr_key) + if attr_key in ms_ctx_param.__members__: + return ctx.get_param(ms_ctx_param.__members__[attr_key]) + raise ValueError("Get context keyword %s is not recognized!" % attr_key) + class ParallelMode: """ diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc index 064a76b06..4e73fd824 100644 --- a/mindspore/core/utils/ms_context.cc +++ b/mindspore/core/utils/ms_context.cc @@ -60,7 +60,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { #endif set_param(MS_CTX_ENABLE_GPU_SUMMARY, true); set_param(MS_CTX_PRECOMPILE_ONLY, false); - set_param(MS_CTX_AUTO_MIXED_PRECISION_FLAG, false); + set_param(MS_CTX_ENABLE_AUTO_MIXED_PRECISION, false); set_param(MS_CTX_ENABLE_PYNATIVE_INFER, false); set_param(MS_CTX_ENABLE_PYNATIVE_HOOK, false); set_param(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true); diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h index 9d461eaa4..38eb2d4d3 100644 --- a/mindspore/core/utils/ms_context.h +++ b/mindspore/core/utils/ms_context.h @@ -53,7 +53,7 @@ const float kDefaultMaxDeviceMemory = 1024; enum MsCtxParam : unsigned { // paramater of type bool MS_CTX_TYPE_BOOL_BEGIN, - MS_CTX_AUTO_MIXED_PRECISION_FLAG = MS_CTX_TYPE_BOOL_BEGIN, + MS_CTX_ENABLE_AUTO_MIXED_PRECISION = MS_CTX_TYPE_BOOL_BEGIN, MS_CTX_CHECK_BPROP_FLAG, MS_CTX_ENABLE_DUMP, MS_CTX_ENABLE_DYNAMIC_MEM_POOL, @@ -132,22 +132,22 @@ class MsContext { template void set_param(MsCtxParam param, const T &value) { - MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; + MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; } template const T &get_param(MsCtxParam param) const { - MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; + MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; } template void increase_param(MsCtxParam param) { - MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; + MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; } template void decrease_param(MsCtxParam param) { - MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; + MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << "."; } private: -- GitLab