diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index 7c663291c043dafe8bd9a4d11fc63a5b5404304d..61ea7eb6aa3d3c60ed1644a0486b5d66875c4ba5 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -117,6 +117,10 @@ PYBIND11_MODULE(_c_expression, m) { .def("close_tsd", &mindspore::MsContext::CloseTsd, "Close tdt dataset client.") .def("get_save_graphs_flag", &mindspore::MsContext::save_graphs_flag, "Get whether to save graphs.") .def("set_save_graphs_flag", &mindspore::MsContext::set_save_graphs_flag, "Set whether to save graphs.") + .def("get_auto_mixed_precision_flag", &mindspore::MsContext::auto_mixed_precision_flag, + "Get whether to enable auto mixed precision.") + .def("set_auto_mixed_precision_flag", &mindspore::MsContext::set_auto_mixed_precision_flag, + "Set whether to enable auto mixed precision.") .def("get_enable_reduce_precision_flag", &mindspore::MsContext::enable_reduce_precision, "Get whether to enable reduce precision.") .def("set_enable_reduce_precision_flag", &mindspore::MsContext::set_enable_reduce_precision, diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h index f4e52950cd6cc7401cfb971e35c8c7f8506eccf4..36d684e020bd47cef02e3d7186399e351d6a93a9 100644 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ b/mindspore/ccsrc/utils/context/ms_context.h @@ -105,6 +105,11 @@ class MsContext { void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; } bool enable_gpu_summary() const { return enable_gpu_summary_; } + void set_auto_mixed_precision_flag(bool auto_mixed_precision_flag) { + auto_mixed_precision_flag_ = auto_mixed_precision_flag; + } + bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; } + void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; } bool enable_reduce_precision() const { return enable_reduce_precision_; } diff --git a/mindspore/context.py b/mindspore/context.py index 0c56a069416ec437292384b20c75fb8072a7a14e..7b6b881968e432d690b3b9e488262bf98664b5d0 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -233,6 +233,14 @@ class _Context: def save_ms_model_path(self, save_ms_model_path): self._context_handle.set_save_ms_model_path(save_ms_model_path) + @property + def enable_auto_mixed_precision(self): + return self._context_handle.get_auto_mixed_precision_flag() + + @enable_auto_mixed_precision.setter + def enable_auto_mixed_precision(self, enable_auto_mixed_precision): + self._context_handle.set_auto_mixed_precision_flag(enable_auto_mixed_precision) + @property def enable_reduce_precision(self): return self._context_handle.get_enable_reduce_precision_flag() @@ -441,7 +449,7 @@ def reset_auto_parallel_context(): @args_type_check(mode=int, precompile_only=bool, device_target=str, device_id=int, save_graphs=bool, save_graphs_path=str, save_ms_model=bool, save_ms_model_path=str, enable_dump=bool, save_dump_path=str, enable_reduce_precision=bool, variable_memory_max_size=str, - enable_profiling=bool, profiling_options=str) + enable_profiling=bool, profiling_options=str, enable_auto_mixed_precision=bool) def set_context(**kwargs): """ Sets context for running environment. @@ -469,6 +477,7 @@ def set_context(**kwargs): save_ms_model (bool): Whether to save lite model converted by graph. Default: False. save_ms_model_path (str): Path to save converted lite model. Default: "." save_graphs_path (str): Path to save graphs. Default: "." + enable_auto_mixed_precision (bool): Whether to enable auto mixed precision. Default: True. reserve_class_name_in_scope (bool) : Whether to save the network class name in the scope. Default: True. enable_reduce_precision (bool): Whether to enable precision reduction. Default: True. enable_dump (bool): Whether to enable dump. Default: False.