Commit e17c7d2c, authored by mindspore-ci-bot, committed by Gitee

!1207 add enable_auto_mixed_precision

Merge pull request !1207 from jinyaohui/add_context
......@@ -117,6 +117,10 @@ PYBIND11_MODULE(_c_expression, m) {
.def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.")
.def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.")
.def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.")
.def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag,
"Get whether to enable auto mixed precision.")
.def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag,
"Set whether to enable auto mixed precision.")
.def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision,
"Get whether to enable reduce precision.")
.def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision,
......
......@@ -105,6 +105,11 @@ class MsContext {
void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; }
bool enable_gpu_summary() const { return enable_gpu_summary_; }
void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) {
auto_mixed_precision_flag_ = auto_mixed_precision_flag;
}
bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; }
void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; }
bool enable_reduce_precision() const { return enable_reduce_precision_; }
......
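The two C++ hunks above add the `auto_mixed_precision_flag_` accessor pair to `MsContext` and expose it to Python through `_c_expression`. As a rough sketch of how the new bindings would be called once the module is built (the `MSContext` class name and `get_instance()` accessor are assumptions about how the context handle is obtained, not part of this diff):

```python
# Hedged sketch: driving the new pybind11 bindings on the C++ context
# handle directly. In normal use these calls go through mindspore.context
# rather than being made by hand.
from mindspore._c_expression import MSContext  # assumed binding name

ms_ctx = MSContext.get_instance()              # assumed singleton accessor
ms_ctx.set_auto_mixed_precision_flag(True)     # setter added in this diff
assert ms_ctx.get_auto_mixed_precision_flag()  # getter added in this diff
```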
......@@ -233,6 +233,14 @@ class _Context:
def save_ms_model_path(self, save_ms_model_path):
self._context_handle.set_save_ms_model_path(save_ms_model_path)
@property
def enable_auto_mixed_precision(self):
return self._context_handle.get_auto_mixed_precision_flag()
@enable_auto_mixed_precision.setter
def enable_auto_mixed_precision(self, enable_auto_mixed_precision):
self._context_handle.set_auto_mixed_precision_flag(enable_auto_mixed_precision)
@property
def enable_reduce_precision(self):
return self._context_handle.get_enable_reduce_precision_flag()
......@@ -441,7 +449,7 @@ def reset_auto_parallel_context():
@args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool,
save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool,
save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str,
enable_profiling=bool, profiling_options=str)
enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool)
def set_context(**kwargs):
"""
Sets context for running environment.
......@@ -469,6 +477,7 @@ def set_context(**kwargs):
save_ms_model (bool): Whether to save the lite model converted from the graph. Default: False.
save_ms_model_path (str): Path to save converted lite model. Default: "."
save_graphs_path (str): Path to save graphs. Default: "."
enable_auto_mixed_precision (bool): Whether to enable auto mixed precision. Default: True.
reserve_class_name_in_scope (bool): Whether to save the network class name in the scope. Default: True.
enable_reduce_precision (bool): Whether to enable precision reduction. Default: True.
enable_dump (bool): Whether to enable dump. Default: False.
......
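Based on the updated docstring, turning the new flag on or off from user code is a single keyword argument to `set_context`. A minimal usage sketch follows (the `device_target` value is only an illustration; `enable_auto_mixed_precision` defaults to True, so passing False is the case that actually changes behaviour):

```python
# Minimal usage sketch of the keyword added by this PR.
from mindspore import context

context.set_context(device_target="Ascend",
                    enable_auto_mixed_precision=False,
                    enable_reduce_precision=True)
```

If this version also exposes `context.get_context`, the flag can be read back with `context.get_context("enable_auto_mixed_precision")`, which resolves to the `enable_auto_mixed_precision` property added above.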