提交 0a858b38 编写于 作者: F fary86

Simplify ms_context implementation

上级 d5e02cf4
...@@ -50,55 +50,6 @@ using ParallelContext = mindspore::parallel::ParallelContext; ...@@ -50,55 +50,6 @@ using ParallelContext = mindspore::parallel::ParallelContext;
using CostModelContext = mindspore::parallel::CostModelContext; using CostModelContext = mindspore::parallel::CostModelContext;
using mindspore::MsCtxParam; using mindspore::MsCtxParam;
namespace mindspore {
void MsCtxSetParameter(std::shared_ptr<MsContext> ctx, MsCtxParam param, const py::object &value) {
MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value) << "' of type '"
<< py::str(value.get_type()) << "'.";
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance<py::bool_>(value)) {
ctx->set_param<bool>(param, value.cast<bool>());
return;
}
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance<py::int_>(value)) {
ctx->set_param<int>(param, value.cast<int>());
return;
}
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance<py::int_>(value)) {
ctx->set_param<uint32_t>(param, value.cast<uint32_t>());
return;
}
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance<py::float_>(value)) {
ctx->set_param<float>(param, value.cast<float>());
return;
}
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance<py::str>(value)) {
ctx->set_param<std::string>(param, value.cast<std::string>());
return;
}
MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type " << py::str(value.get_type());
}
py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam param) {
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) {
return py::bool_(ctx->get_param<bool>(param));
}
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) {
return py::int_(ctx->get_param<int>(param));
}
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) {
return py::int_(ctx->get_param<uint32_t>(param));
}
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) {
return py::float_(ctx->get_param<float>(param));
}
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) {
return py::str(ctx->get_param<std::string>(param));
}
MS_LOG(EXCEPTION) << "Got illegal param " << param << ".";
}
} // namespace mindspore
// Interface with python // Interface with python
PYBIND11_MODULE(_c_expression, m) { PYBIND11_MODULE(_c_expression, m) {
m.doc() = "MindSpore c plugin"; m.doc() = "MindSpore c plugin";
...@@ -151,49 +102,6 @@ PYBIND11_MODULE(_c_expression, m) { ...@@ -151,49 +102,6 @@ PYBIND11_MODULE(_c_expression, m) {
(void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph."); (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
(void)m.def("ms_ctx_get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter.");
(void)m.def("ms_ctx_set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter.");
(void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
.value("auto_mixed_precision_flag", MsCtxParam::MS_CTX_AUTO_MIXED_PRECISION_FLAG)
.value("check_bprop_flag", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
.value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
.value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
.value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
.value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
.value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK)
.value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
.value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
.value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
.value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
.value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
.value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
.value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
.value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
.value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
.value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
.value("save_graphs_flag", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
.value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
.value("execution_mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
.value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
.value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
.value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
.value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
.value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
.value("ge_ref", MsCtxParam::MS_CTX_GE_REF)
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
.value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
.def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.")
.def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.");
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig") (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.") .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
.def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.") .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.")
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include "utils/ms_context.h"
#include "utils/log_adapter.h"
#include "pybind_api/api_register.h"
namespace mindspore {
namespace {
void MsCtxSetParameter(std::shared_ptr<MsContext> ctx, MsCtxParam param, const py::object &value) {
MS_LOG(DEBUG) << "set param(" << param << ") with value '" << py::str(value).cast<std::string>() << "' of type '"
<< py::str(value.get_type()).cast<std::string>() << "'.";
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END && py::isinstance<py::bool_>(value)) {
ctx->set_param<bool>(param, value.cast<bool>());
return;
}
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END && py::isinstance<py::int_>(value)) {
ctx->set_param<int>(param, value.cast<int>());
return;
}
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END && py::isinstance<py::int_>(value)) {
ctx->set_param<uint32_t>(param, value.cast<uint32_t>());
return;
}
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END && py::isinstance<py::float_>(value)) {
ctx->set_param<float>(param, value.cast<float>());
return;
}
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END && py::isinstance<py::str>(value)) {
ctx->set_param<std::string>(param, value.cast<std::string>());
return;
}
MS_LOG(EXCEPTION) << "Got illegal param " << param << " and value with type "
<< py::str(value.get_type()).cast<std::string>();
}
py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam param) {
if (param >= MS_CTX_TYPE_BOOL_BEGIN && param < MS_CTX_TYPE_BOOL_END) {
return py::bool_(ctx->get_param<bool>(param));
}
if (param >= MS_CTX_TYPE_INT_BEGIN && param < MS_CTX_TYPE_INT_END) {
return py::int_(ctx->get_param<int>(param));
}
if (param >= MS_CTX_TYPE_UINT32_BEGIN && param < MS_CTX_TYPE_UINT32_END) {
return py::int_(ctx->get_param<uint32_t>(param));
}
if (param >= MS_CTX_TYPE_FLOAT_BEGIN && param < MS_CTX_TYPE_FLOAT_END) {
return py::float_(ctx->get_param<float>(param));
}
if (param >= MS_CTX_TYPE_STRING_BEGIN && param < MS_CTX_TYPE_STRING_END) {
return py::str(ctx->get_param<std::string>(param));
}
MS_LOG(EXCEPTION) << "Got illegal param " << param << ".";
}
} // namespace
REGISTER_PYBIND_DEFINE(MsContextPy, ([](const py::module *m) {
(void)py::enum_<MsCtxParam>(*m, "ms_ctx_param", py::arithmetic())
.value("enable_auto_mixed_precision", MsCtxParam::MS_CTX_ENABLE_AUTO_MIXED_PRECISION)
.value("check_bprop", MsCtxParam::MS_CTX_CHECK_BPROP_FLAG)
.value("enable_dump", MsCtxParam::MS_CTX_ENABLE_DUMP)
.value("enable_dynamic_mem_pool", MsCtxParam::MS_CTX_ENABLE_DYNAMIC_MEM_POOL)
.value("enable_gpu_summary", MsCtxParam::MS_CTX_ENABLE_GPU_SUMMARY)
.value("enable_graph_kernel", MsCtxParam::MS_CTX_ENABLE_GRAPH_KERNEL)
.value("enable_hccl", MsCtxParam::MS_CTX_ENABLE_HCCL)
.value("enable_loop_sink", MsCtxParam::MS_CTX_ENABLE_LOOP_SINK)
.value("enable_mem_reuse", MsCtxParam::MS_CTX_ENABLE_MEM_REUSE)
.value("enable_pynative_hook", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_HOOK)
.value("enable_pynative_infer", MsCtxParam::MS_CTX_ENABLE_PYNATIVE_INFER)
.value("enable_reduce_precision", MsCtxParam::MS_CTX_ENABLE_REDUCE_PRECISION)
.value("enable_sparse", MsCtxParam::MS_CTX_ENABLE_SPARSE)
.value("enable_task_sink", MsCtxParam::MS_CTX_ENABLE_TASK_SINK)
.value("ir_fusion_flag", MsCtxParam::MS_CTX_IR_FUSION_FLAG)
.value("is_multi_graph_sink", MsCtxParam::MS_CTX_IS_MULTI_GRAPH_SINK)
.value("is_pynative_ge_init", MsCtxParam::MS_CTX_IS_PYNATIVE_GE_INIT)
.value("precompile_only", MsCtxParam::MS_CTX_PRECOMPILE_ONLY)
.value("enable_profiling", MsCtxParam::MS_CTX_ENABLE_PROFILING)
.value("save_graphs", MsCtxParam::MS_CTX_SAVE_GRAPHS_FLAG)
.value("max_device_memory", MsCtxParam::MS_CTX_MAX_DEVICE_MEMORY)
.value("mode", MsCtxParam::MS_CTX_EXECUTION_MODE)
.value("device_target", MsCtxParam::MS_CTX_DEVICE_TARGET)
.value("graph_memory_max_size", MsCtxParam::MS_CTX_GRAPH_MEMORY_MAX_SIZE)
.value("print_file_path", MsCtxParam::MS_CTX_PRINT_FILE_PATH)
.value("profiling_options", MsCtxParam::MS_CTX_PROFILING_OPTIONS)
.value("save_dump_path", MsCtxParam::MS_CTX_SAVE_DUMP_PATH)
.value("save_graphs_path", MsCtxParam::MS_CTX_SAVE_GRAPHS_PATH)
.value("variable_memory_max_size", MsCtxParam::MS_CTX_VARIABLE_MEMORY_MAX_SIZE)
.value("device_id", MsCtxParam::MS_CTX_DEVICE_ID)
.value("ge_ref", MsCtxParam::MS_CTX_GE_REF)
.value("max_call_depth", MsCtxParam::MS_CTX_MAX_CALL_DEPTH)
.value("tsd_ref", MsCtxParam::MS_CTX_TSD_REF);
(void)py::class_<mindspore::MsContext, std::shared_ptr<mindspore::MsContext>>(*m, "MSContext")
.def_static("get_instance", &mindspore::MsContext::GetInstance, "Get ms context instance.")
.def("get_param", &mindspore::MsCtxGetParameter, "Get value of specified paramter.")
.def("set_param", &mindspore::MsCtxSetParameter, "Set value for specified paramter.")
.def("get_backend_policy", &mindspore::MsContext::backend_policy, "Get backend policy.")
.def("set_backend_policy", &mindspore::MsContext::set_backend_policy, "Set backend policy.");
}));
} // namespace mindspore
...@@ -225,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std ...@@ -225,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
} }
// Enable auto mixed precision according to the context options // Enable auto mixed precision according to the context options
if (ms_context_ptr->get_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG)) { if (ms_context_ptr->get_param<bool>(MS_CTX_ENABLE_AUTO_MIXED_PRECISION)) {
(*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision"; (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
} else { } else {
(*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16"; (*ge_options)["ge.exec.precision_mode"] = "allow_fp32_to_fp16";
...@@ -337,7 +337,7 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) { ...@@ -337,7 +337,7 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
if (ge::GEFinalize() != ge::GRAPH_SUCCESS) { if (ge::GEFinalize() != ge::GRAPH_SUCCESS) {
MS_LOG(WARNING) << "Finalize GE failed!"; MS_LOG(WARNING) << "Finalize GE failed!";
} }
ms_context_ptr->set_pynative_ge_init(false); ms_context_ptr->set_param<bool>(MS_CTX_IS_PYNATIVE_GE_INIT, false);
} else { } else {
MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = " MS_LOG(INFO) << "Ge is used, no need to finalize, tsd reference = "
<< ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << "."; << ms_context_ptr->get_param<uint32_t>(MS_CTX_GE_REF) << ".";
......
...@@ -22,7 +22,7 @@ import threading ...@@ -22,7 +22,7 @@ import threading
from collections import namedtuple from collections import namedtuple
from types import FunctionType from types import FunctionType
from mindspore import log as logger from mindspore import log as logger
from mindspore._c_expression import MSContext, ms_ctx_param, ms_ctx_get_param, ms_ctx_set_param from mindspore._c_expression import MSContext, ms_ctx_param
from mindspore._checkparam import args_type_check from mindspore._checkparam import args_type_check
from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \
_reset_auto_parallel_context _reset_auto_parallel_context
...@@ -158,17 +158,12 @@ class _Context: ...@@ -158,17 +158,12 @@ class _Context:
return value return value
def get_param(self, param): def get_param(self, param):
return ms_ctx_get_param(self._context_handle, param) return self._context_handle.get_param(param)
def set_param(self, param, value): def set_param(self, param, value):
ms_ctx_set_param(self._context_handle, param, value) self._context_handle.set_param(param, value)
@property def set_mode(self, mode):
def mode(self):
return self.get_param(ms_ctx_param.execution_mode)
@mode.setter
def mode(self, mode):
""" """
Switch between Graph mode and PyNative mode. Switch between Graph mode and PyNative mode.
...@@ -185,43 +180,17 @@ class _Context: ...@@ -185,43 +180,17 @@ class _Context:
self._context_switches.push(False, None) self._context_switches.push(False, None)
else: else:
raise ValueError(f'The execution mode {mode} is invalid!') raise ValueError(f'The execution mode {mode} is invalid!')
self.set_param(ms_ctx_param.execution_mode, mode) self.set_param(ms_ctx_param.mode, mode)
def set_backend_policy(self, policy): def set_backend_policy(self, policy):
success = self._context_handle.set_backend_policy(policy) success = self._context_handle.set_backend_policy(policy)
if not success: if not success:
raise RuntimeError("Backend policy must be one of ge, vm, ms.") raise RuntimeError("Backend policy must be one of ge, vm, ms.")
@property def set_save_graphs_path(self, save_graphs_path):
def precompile_only(self):
return self.get_param(ms_ctx_param.precompile_only)
@precompile_only.setter
def precompile_only(self, precompile_only):
self.set_param(ms_ctx_param.precompile_only, precompile_only)
@property
def save_graphs(self):
return self.get_param(ms_ctx_param.save_graphs_flag)
@save_graphs.setter
def save_graphs(self, save_graphs_flag):
self.set_param(ms_ctx_param.save_graphs_flag, save_graphs_flag)
@property
def save_graphs_path(self):
return self.get_param(ms_ctx_param.save_graphs_path)
@save_graphs_path.setter
def save_graphs_path(self, save_graphs_path):
self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path)) self.set_param(ms_ctx_param.save_graphs_path, _make_directory(save_graphs_path))
@property def set_device_target(self, target):
def device_target(self):
return self.get_param(ms_ctx_param.device_target)
@device_target.setter
def device_target(self, target):
valid_targets = ["CPU", "GPU", "Ascend", "Davinci"] valid_targets = ["CPU", "GPU", "Ascend", "Davinci"]
if not target in valid_targets: if not target in valid_targets:
raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}") raise ValueError(f"Target device name {target} is invalid! It must be one of {valid_targets}")
...@@ -231,72 +200,17 @@ class _Context: ...@@ -231,72 +200,17 @@ class _Context:
if self.enable_debug_runtime and target == "CPU": if self.enable_debug_runtime and target == "CPU":
self.set_backend_policy("vm") self.set_backend_policy("vm")
@property def set_device_id(self, device_id):
def device_id(self):
return self.get_param(ms_ctx_param.device_id)
@device_id.setter
def device_id(self, device_id):
if device_id < 0 or device_id > 4095: if device_id < 0 or device_id > 4095:
raise ValueError(f"Device id must be in [0, 4095], but got {device_id}") raise ValueError(f"Device id must be in [0, 4095], but got {device_id}")
self.set_param(ms_ctx_param.device_id, device_id) self.set_param(ms_ctx_param.device_id, device_id)
@property def set_max_call_depth(self, max_call_depth):
def max_call_depth(self):
return self.get_param(ms_ctx_param.max_call_depth)
@max_call_depth.setter
def max_call_depth(self, max_call_depth):
if max_call_depth <= 0: if max_call_depth <= 0:
raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}") raise ValueError(f"Max call depth must be greater than 0, but got {max_call_depth}")
self.set_param(ms_ctx_param.max_call_depth, max_call_depth) self.set_param(ms_ctx_param.max_call_depth, max_call_depth)
@property def set_profiling_options(self, option):
def enable_auto_mixed_precision(self):
return self.get_param(ms_ctx_param.auto_mixed_precision_flag)
@enable_auto_mixed_precision.setter
def enable_auto_mixed_precision(self, enable_auto_mixed_precision):
self.set_param(ms_ctx_param.auto_mixed_precision_flag, enable_auto_mixed_precision)
@property
def enable_reduce_precision(self):
return self.get_param(ms_ctx_param.enable_reduce_precision_flag)
@enable_reduce_precision.setter
def enable_reduce_precision(self, enable_reduce_precision):
self.set_param(ms_ctx_param.enable_reduce_precision_flag, enable_reduce_precision)
@property
def enable_dump(self):
return self.get_param(ms_ctx_param.enable_dump)
@enable_dump.setter
def enable_dump(self, enable_dump):
self.set_param(ms_ctx_param.enable_dump, enable_dump)
@property
def save_dump_path(self):
return self.get_param(ms_ctx_param.save_dump_path)
@save_dump_path.setter
def save_dump_path(self, save_dump_path):
self.set_param(ms_ctx_param.save_dump_path, save_dump_path)
@property
def enable_profiling(self):
return self.get_param(ms_ctx_param.enable_profiling)
@enable_profiling.setter
def enable_profiling(self, flag):
self.set_param(ms_ctx_param.enable_profiling, flag)
@property
def profiling_options(self):
return self.get_param(ms_ctx_param.profiling_options)
@profiling_options.setter
def profiling_options(self, option):
options = ["training_trace", "task_trace", options = ["training_trace", "task_trace",
"task_trace:training_trace", "training_trace:task_trace", "op_trace"] "task_trace:training_trace", "training_trace:task_trace", "op_trace"]
if option not in options: if option not in options:
...@@ -304,30 +218,7 @@ class _Context: ...@@ -304,30 +218,7 @@ class _Context:
"'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.") "'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'.")
self.set_param(ms_ctx_param.profiling_options, option) self.set_param(ms_ctx_param.profiling_options, option)
@property def set_variable_memory_max_size(self, variable_memory_max_size):
def enable_graph_kernel(self):
return self.get_param(ms_ctx_param.enable_graph_kernel)
@enable_graph_kernel.setter
def enable_graph_kernel(self, graph_kernel_switch_):
self.set_param(ms_ctx_param.enable_graph_kernel, graph_kernel_switch_)
@property
def reserve_class_name_in_scope(self):
"""Gets whether to save the network class name in the scope."""
return self._thread_local_info.reserve_class_name_in_scope
@reserve_class_name_in_scope.setter
def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
"""Sets whether to save the network class name in the scope."""
self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
@property
def variable_memory_max_size(self):
return None
@variable_memory_max_size.setter
def variable_memory_max_size(self, variable_memory_max_size):
if not check_input_format(variable_memory_max_size): if not check_input_format(variable_memory_max_size):
raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"") raise ValueError("Context param variable_memory_max_size should be in correct format! Such as \"5GB\"")
if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE: if int(variable_memory_max_size[:-2]) >= _DEVICE_APP_MEMORY_SIZE:
...@@ -338,33 +229,7 @@ class _Context: ...@@ -338,33 +229,7 @@ class _Context:
self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_) self.set_param(ms_ctx_param.variable_memory_max_size, variable_memory_max_size_)
self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_) self.set_param(ms_ctx_param.graph_memory_max_size, graph_memory_max_size_)
@property def set_max_device_memory(self, max_device_memory):
def enable_ge(self):
return self._context_handle.get_backend_policy() == 'ge'
@property
def enable_debug_runtime(self):
return self._thread_local_info.debug_runtime
@enable_debug_runtime.setter
def enable_debug_runtime(self, enable):
thread_info = self._thread_local_info
thread_info.debug_runtime = enable
@property
def check_bprop(self):
return self.get_param(ms_ctx_param.check_bprop_flag)
@check_bprop.setter
def check_bprop(self, check_bprop_flag):
self.set_param(ms_ctx_param.check_bprop_flag, check_bprop_flag)
@property
def max_device_memory(self):
return self.get_param(ms_ctx_param.max_device_memory)
@max_device_memory.setter
def max_device_memory(self, max_device_memory):
if not check_input_format(max_device_memory): if not check_input_format(max_device_memory):
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
max_device_memory_value = float(max_device_memory[:-2]) max_device_memory_value = float(max_device_memory[:-2])
...@@ -372,12 +237,7 @@ class _Context: ...@@ -372,12 +237,7 @@ class _Context:
raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"") raise ValueError("Context param max_device_memory should be in correct format! Such as \"3.5GB\"")
self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value) self.set_param(ms_ctx_param.max_device_memory, max_device_memory_value)
@property def set_print_file_path(self, file_path):
def print_file_path(self):
return None
@print_file_path.setter
def print_file_path(self, file_path):
"""Add timestamp suffix to file name. Sets print file path.""" """Add timestamp suffix to file name. Sets print file path."""
print_file_path = os.path.realpath(file_path) print_file_path = os.path.realpath(file_path)
if os.path.isdir(print_file_path): if os.path.isdir(print_file_path):
...@@ -392,13 +252,42 @@ class _Context: ...@@ -392,13 +252,42 @@ class _Context:
full_file_name = print_file_path full_file_name = print_file_path
self.set_param(ms_ctx_param.print_file_path, full_file_name) self.set_param(ms_ctx_param.print_file_path, full_file_name)
setters = {
'mode': set_mode,
'backend_policy': set_backend_policy,
'save_graphs_path': set_save_graphs_path,
'device_target': set_device_target,
'device_id': set_device_id,
'max_call_depth': set_max_call_depth,
'profiling_options': set_profiling_options,
'variable_memory_max_size': set_variable_memory_max_size,
'max_device_memory': set_max_device_memory,
'print_file_path': set_print_file_path
}
@property @property
def enable_sparse(self): def reserve_class_name_in_scope(self):
return self.get_param(ms_ctx_param.enable_sparse) """Gets whether to save the network class name in the scope."""
return self._thread_local_info.reserve_class_name_in_scope
@reserve_class_name_in_scope.setter
def reserve_class_name_in_scope(self, reserve_class_name_in_scope):
"""Sets whether to save the network class name in the scope."""
self._thread_local_info.reserve_class_name_in_scope = reserve_class_name_in_scope
@property
def enable_ge(self):
return self._context_handle.get_backend_policy() == 'ge'
@property
def enable_debug_runtime(self):
return self._thread_local_info.debug_runtime
@enable_debug_runtime.setter
def enable_debug_runtime(self, enable):
thread_info = self._thread_local_info
thread_info.debug_runtime = enable
@enable_sparse.setter
def enable_sparse(self, enable_sparse):
self.set_param(ms_ctx_param.enable_sparse, enable_sparse)
def check_input_format(x): def check_input_format(x):
import re import re
...@@ -621,10 +510,18 @@ def set_context(**kwargs): ...@@ -621,10 +510,18 @@ def set_context(**kwargs):
>>> context.set_context(print_file_path="print.pb") >>> context.set_context(print_file_path="print.pb")
>>> context.set_context(max_call_depth=80) >>> context.set_context(max_call_depth=80)
""" """
ctx = _context()
for key, value in kwargs.items(): for key, value in kwargs.items():
if not hasattr(_context(), key): if hasattr(ctx, key):
raise ValueError("Set context keyword %s is not recognized!" % key) setattr(ctx, key, value)
setattr(_context(), key, value) continue
if key in ctx.setters:
ctx.setters[key](ctx, value)
continue
if key in ms_ctx_param.__members__:
ctx.set_param(ms_ctx_param.__members__[key], value)
continue
raise ValueError("Set context keyword %s is not recognized!" % key)
def get_context(attr_key): def get_context(attr_key):
...@@ -640,10 +537,13 @@ def get_context(attr_key): ...@@ -640,10 +537,13 @@ def get_context(attr_key):
Raises: Raises:
ValueError: If input key is not an attribute in context. ValueError: If input key is not an attribute in context.
""" """
if not hasattr(_context(), attr_key): ctx = _context()
raise ValueError( if hasattr(ctx, attr_key):
"Get context keyword %s is not recognized!" % attr_key) return getattr(ctx, attr_key)
return getattr(_context(), attr_key) if attr_key in ms_ctx_param.__members__:
return ctx.get_param(ms_ctx_param.__members__[attr_key])
raise ValueError("Get context keyword %s is not recognized!" % attr_key)
class ParallelMode: class ParallelMode:
""" """
......
...@@ -60,7 +60,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) { ...@@ -60,7 +60,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
#endif #endif
set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true); set_param<bool>(MS_CTX_ENABLE_GPU_SUMMARY, true);
set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false); set_param<bool>(MS_CTX_PRECOMPILE_ONLY, false);
set_param<bool>(MS_CTX_AUTO_MIXED_PRECISION_FLAG, false); set_param<bool>(MS_CTX_ENABLE_AUTO_MIXED_PRECISION, false);
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false); set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false); set_param<bool>(MS_CTX_ENABLE_PYNATIVE_HOOK, false);
set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true); set_param<bool>(MS_CTX_ENABLE_DYNAMIC_MEM_POOL, true);
......
...@@ -53,7 +53,7 @@ const float kDefaultMaxDeviceMemory = 1024; ...@@ -53,7 +53,7 @@ const float kDefaultMaxDeviceMemory = 1024;
enum MsCtxParam : unsigned { enum MsCtxParam : unsigned {
// paramater of type bool // paramater of type bool
MS_CTX_TYPE_BOOL_BEGIN, MS_CTX_TYPE_BOOL_BEGIN,
MS_CTX_AUTO_MIXED_PRECISION_FLAG = MS_CTX_TYPE_BOOL_BEGIN, MS_CTX_ENABLE_AUTO_MIXED_PRECISION = MS_CTX_TYPE_BOOL_BEGIN,
MS_CTX_CHECK_BPROP_FLAG, MS_CTX_CHECK_BPROP_FLAG,
MS_CTX_ENABLE_DUMP, MS_CTX_ENABLE_DUMP,
MS_CTX_ENABLE_DYNAMIC_MEM_POOL, MS_CTX_ENABLE_DYNAMIC_MEM_POOL,
...@@ -132,22 +132,22 @@ class MsContext { ...@@ -132,22 +132,22 @@ class MsContext {
template <typename T> template <typename T>
void set_param(MsCtxParam param, const T &value) { void set_param(MsCtxParam param, const T &value) {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
} }
template <typename T> template <typename T>
const T &get_param(MsCtxParam param) const { const T &get_param(MsCtxParam param) const {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
} }
template <typename T> template <typename T>
void increase_param(MsCtxParam param) { void increase_param(MsCtxParam param) {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
} }
template <typename T> template <typename T>
void decrease_param(MsCtxParam param) { void decrease_param(MsCtxParam param) {
MS_LOG(EXCEPTION) << "Need implemet " << __FUNCTION__ << " for type " << typeid(T).name() << "."; MS_LOG(EXCEPTION) << "Need to implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
} }
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册