Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0a858b38
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0a858b38
编写于
8月 31, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplify ms_context implementation
上级
d5e02cf4
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
190 addition
and
265 deletion
+190
-265
mindspore/ccsrc/pipeline/jit/init.cc
mindspore/ccsrc/pipeline/jit/init.cc
+0
-92
mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
+117
-0
mindspore/ccsrc/utils/context/context_extends.cc
mindspore/ccsrc/utils/context/context_extends.cc
+2
-2
mindspore/context.py
mindspore/context.py
+65
-165
mindspore/core/utils/ms_context.cc
mindspore/core/utils/ms_context.cc
+1
-1
mindspore/core/utils/ms_context.h
mindspore/core/utils/ms_context.h
+5
-5
未找到文件。
mindspore/ccsrc/pipeline/jit/init.cc
浏览文件 @
0a858b38
...
...
@@ -50,55 +50,6 @@ using ParallelContext = mindspore::parallel::ParallelContext;
using
CostModelContext
=
mindspore
::
parallel
::
CostModelContext
;
using
mindspore
::
MsCtxParam
;
namespace
mindspore
{
void
MsCtxSetParameter
(
std
::
shared_ptr
<
MsContext
>
ctx
,
MsCtxParam
param
,
const
py
::
object
&
value
)
{
MS_LOG
(
DEBUG
)
<<
"set param("
<<
param
<<
") with value '"
<<
py
::
str
(
value
)
<<
"' of type '"
<<
py
::
str
(
value
.
get_type
())
<<
"'."
;
if
(
param
>=
MS_CTX_TYPE_BOOL_BEGIN
&&
param
<
MS_CTX_TYPE_BOOL_END
&&
py
::
isinstance
<
py
::
bool_
>
(
value
))
{
ctx
->
set_param
<
bool
>
(
param
,
value
.
cast
<
bool
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_INT_BEGIN
&&
param
<
MS_CTX_TYPE_INT_END
&&
py
::
isinstance
<
py
::
int_
>
(
value
))
{
ctx
->
set_param
<
int
>
(
param
,
value
.
cast
<
int
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_UINT32_BEGIN
&&
param
<
MS_CTX_TYPE_UINT32_END
&&
py
::
isinstance
<
py
::
int_
>
(
value
))
{
ctx
->
set_param
<
uint32_t
>
(
param
,
value
.
cast
<
uint32_t
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_FLOAT_BEGIN
&&
param
<
MS_CTX_TYPE_FLOAT_END
&&
py
::
isinstance
<
py
::
float_
>
(
value
))
{
ctx
->
set_param
<
float
>
(
param
,
value
.
cast
<
float
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_STRING_BEGIN
&&
param
<
MS_CTX_TYPE_STRING_END
&&
py
::
isinstance
<
py
::
str
>
(
value
))
{
ctx
->
set_param
<
std
::
string
>
(
param
,
value
.
cast
<
std
::
string
>
());
return
;
}
MS_LOG
(
EXCEPTION
)
<<
"Got illegal param "
<<
param
<<
" and value with type "
<<
py
::
str
(
value
.
get_type
());
}
py
::
object
MsCtxGetParameter
(
const
std
::
shared_ptr
<
MsContext
>
&
ctx
,
MsCtxParam
param
)
{
if
(
param
>=
MS_CTX_TYPE_BOOL_BEGIN
&&
param
<
MS_CTX_TYPE_BOOL_END
)
{
return
py
::
bool_
(
ctx
->
get_param
<
bool
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_INT_BEGIN
&&
param
<
MS_CTX_TYPE_INT_END
)
{
return
py
::
int_
(
ctx
->
get_param
<
int
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_UINT32_BEGIN
&&
param
<
MS_CTX_TYPE_UINT32_END
)
{
return
py
::
int_
(
ctx
->
get_param
<
uint32_t
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_FLOAT_BEGIN
&&
param
<
MS_CTX_TYPE_FLOAT_END
)
{
return
py
::
float_
(
ctx
->
get_param
<
float
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_STRING_BEGIN
&&
param
<
MS_CTX_TYPE_STRING_END
)
{
return
py
::
str
(
ctx
->
get_param
<
std
::
string
>
(
param
));
}
MS_LOG
(
EXCEPTION
)
<<
"Got illegal param "
<<
param
<<
"."
;
}
}
// namespace mindspore
// Interface with python
PYBIND11_MODULE
(
_c_expression
,
m
)
{
m
.
doc
()
=
"MindSpore c plugin"
;
...
...
@@ -151,49 +102,6 @@ PYBIND11_MODULE(_c_expression, m) {
(
void
)
m
.
def
(
"export_graph"
,
&
mindspore
::
pipeline
::
ExportGraph
,
"Export Graph."
);
(
void
)
m
.
def
(
"ms_ctx_get_param"
,
&
mindspore
::
MsCtxGetParameter
,
"Get value of specified paramter."
);
(
void
)
m
.
def
(
"ms_ctx_set_param"
,
&
mindspore
::
MsCtxSetParameter
,
"Set value for specified paramter."
);
(
void
)
py
::
enum_
<
MsCtxParam
>
(
*
m
,
"ms_ctx_param"
,
py
::
arithmetic
())
.
value
(
"auto_mixed_precision_flag"
,
MsCtxParam
::
MS_CTX_AUTO_MIXED_PRECISION_FLAG
)
.
value
(
"check_bprop_flag"
,
MsCtxParam
::
MS_CTX_CHECK_BPROP_FLAG
)
.
value
(
"enable_dump"
,
MsCtxParam
::
MS_CTX_ENABLE_DUMP
)
.
value
(
"enable_dynamic_mem_pool"
,
MsCtxParam
::
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
)
.
value
(
"enable_gpu_summary"
,
MsCtxParam
::
MS_CTX_ENABLE_GPU_SUMMARY
)
.
value
(
"enable_graph_kernel"
,
MsCtxParam
::
MS_CTX_ENABLE_GRAPH_KERNEL
)
.
value
(
"enable_hccl"
,
MsCtxParam
::
MS_CTX_ENABLE_HCCL
)
.
value
(
"enable_loop_sink"
,
MsCtxParam
::
MS_CTX_ENABLE_LOOP_SINK
)
.
value
(
"enable_mem_reuse"
,
MsCtxParam
::
MS_CTX_ENABLE_MEM_REUSE
)
.
value
(
"enable_pynative_hook"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_HOOK
)
.
value
(
"enable_pynative_infer"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_INFER
)
.
value
(
"enable_reduce_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_REDUCE_PRECISION
)
.
value
(
"enable_sparse"
,
MsCtxParam
::
MS_CTX_ENABLE_SPARSE
)
.
value
(
"enable_task_sink"
,
MsCtxParam
::
MS_CTX_ENABLE_TASK_SINK
)
.
value
(
"ir_fusion_flag"
,
MsCtxParam
::
MS_CTX_IR_FUSION_FLAG
)
.
value
(
"is_multi_graph_sink"
,
MsCtxParam
::
MS_CTX_IS_MULTI_GRAPH_SINK
)
.
value
(
"is_pynative_ge_init"
,
MsCtxParam
::
MS_CTX_IS_PYNATIVE_GE_INIT
)
.
value
(
"precompile_only"
,
MsCtxParam
::
MS_CTX_PRECOMPILE_ONLY
)
.
value
(
"enable_profiling"
,
MsCtxParam
::
MS_CTX_ENABLE_PROFILING
)
.
value
(
"save_graphs_flag"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_FLAG
)
.
value
(
"max_device_memory"
,
MsCtxParam
::
MS_CTX_MAX_DEVICE_MEMORY
)
.
value
(
"execution_mode"
,
MsCtxParam
::
MS_CTX_EXECUTION_MODE
)
.
value
(
"device_target"
,
MsCtxParam
::
MS_CTX_DEVICE_TARGET
)
.
value
(
"graph_memory_max_size"
,
MsCtxParam
::
MS_CTX_GRAPH_MEMORY_MAX_SIZE
)
.
value
(
"print_file_path"
,
MsCtxParam
::
MS_CTX_PRINT_FILE_PATH
)
.
value
(
"profiling_options"
,
MsCtxParam
::
MS_CTX_PROFILING_OPTIONS
)
.
value
(
"save_dump_path"
,
MsCtxParam
::
MS_CTX_SAVE_DUMP_PATH
)
.
value
(
"save_graphs_path"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_PATH
)
.
value
(
"variable_memory_max_size"
,
MsCtxParam
::
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
)
.
value
(
"device_id"
,
MsCtxParam
::
MS_CTX_DEVICE_ID
)
.
value
(
"ge_ref"
,
MsCtxParam
::
MS_CTX_GE_REF
)
.
value
(
"max_call_depth"
,
MsCtxParam
::
MS_CTX_MAX_CALL_DEPTH
)
.
value
(
"tsd_ref"
,
MsCtxParam
::
MS_CTX_TSD_REF
);
(
void
)
py
::
class_
<
mindspore
::
MsContext
,
std
::
shared_ptr
<
mindspore
::
MsContext
>>
(
m
,
"MSContext"
)
.
def_static
(
"get_instance"
,
&
mindspore
::
MsContext
::
GetInstance
,
"Get ms context instance."
)
.
def
(
"get_backend_policy"
,
&
mindspore
::
MsContext
::
backend_policy
,
"Get backend policy."
)
.
def
(
"set_backend_policy"
,
&
mindspore
::
MsContext
::
set_backend_policy
,
"Set backend policy."
);
(
void
)
py
::
class_
<
mindspore
::
MpiConfig
,
std
::
shared_ptr
<
mindspore
::
MpiConfig
>>
(
m
,
"MpiConfig"
)
.
def_static
(
"get_instance"
,
&
mindspore
::
MpiConfig
::
GetInstance
,
"Get mpi config instance."
)
.
def
(
"get_enable_mpi"
,
&
mindspore
::
MpiConfig
::
enable_mpi
,
"Get whether enable mpi."
)
...
...
mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
0 → 100644
浏览文件 @
0a858b38
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include "utils/ms_context.h"
#include "utils/log_adapter.h"
#include "pybind_api/api_register.h"
namespace
mindspore
{
namespace
{
void
MsCtxSetParameter
(
std
::
shared_ptr
<
MsContext
>
ctx
,
MsCtxParam
param
,
const
py
::
object
&
value
)
{
MS_LOG
(
DEBUG
)
<<
"set param("
<<
param
<<
") with value '"
<<
py
::
str
(
value
).
cast
<
std
::
string
>
()
<<
"' of type '"
<<
py
::
str
(
value
.
get_type
()).
cast
<
std
::
string
>
()
<<
"'."
;
if
(
param
>=
MS_CTX_TYPE_BOOL_BEGIN
&&
param
<
MS_CTX_TYPE_BOOL_END
&&
py
::
isinstance
<
py
::
bool_
>
(
value
))
{
ctx
->
set_param
<
bool
>
(
param
,
value
.
cast
<
bool
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_INT_BEGIN
&&
param
<
MS_CTX_TYPE_INT_END
&&
py
::
isinstance
<
py
::
int_
>
(
value
))
{
ctx
->
set_param
<
int
>
(
param
,
value
.
cast
<
int
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_UINT32_BEGIN
&&
param
<
MS_CTX_TYPE_UINT32_END
&&
py
::
isinstance
<
py
::
int_
>
(
value
))
{
ctx
->
set_param
<
uint32_t
>
(
param
,
value
.
cast
<
uint32_t
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_FLOAT_BEGIN
&&
param
<
MS_CTX_TYPE_FLOAT_END
&&
py
::
isinstance
<
py
::
float_
>
(
value
))
{
ctx
->
set_param
<
float
>
(
param
,
value
.
cast
<
float
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_STRING_BEGIN
&&
param
<
MS_CTX_TYPE_STRING_END
&&
py
::
isinstance
<
py
::
str
>
(
value
))
{
ctx
->
set_param
<
std
::
string
>
(
param
,
value
.
cast
<
std
::
string
>
());
return
;
}
MS_LOG
(
EXCEPTION
)
<<
"Got illegal param "
<<
param
<<
" and value with type "
<<
py
::
str
(
value
.
get_type
()).
cast
<
std
::
string
>
();
}
py
::
object
MsCtxGetParameter
(
const
std
::
shared_ptr
<
MsContext
>
&
ctx
,
MsCtxParam
param
)
{
if
(
param
>=
MS_CTX_TYPE_BOOL_BEGIN
&&
param
<
MS_CTX_TYPE_BOOL_END
)
{
return
py
::
bool_
(
ctx
->
get_param
<
bool
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_INT_BEGIN
&&
param
<
MS_CTX_TYPE_INT_END
)
{
return
py
::
int_
(
ctx
->
get_param
<
int
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_UINT32_BEGIN
&&
param
<
MS_CTX_TYPE_UINT32_END
)
{
return
py
::
int_
(
ctx
->
get_param
<
uint32_t
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_FLOAT_BEGIN
&&
param
<
MS_CTX_TYPE_FLOAT_END
)
{
return
py
::
float_
(
ctx
->
get_param
<
float
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_STRING_BEGIN
&&
param
<
MS_CTX_TYPE_STRING_END
)
{
return
py
::
str
(
ctx
->
get_param
<
std
::
string
>
(
param
));
}
MS_LOG
(
EXCEPTION
)
<<
"Got illegal param "
<<
param
<<
"."
;
}
}
// namespace
REGISTER_PYBIND_DEFINE
(
MsContextPy
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
enum_
<
MsCtxParam
>
(
*
m
,
"ms_ctx_param"
,
py
::
arithmetic
())
.
value
(
"enable_auto_mixed_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_AUTO_MIXED_PRECISION
)
.
value
(
"check_bprop"
,
MsCtxParam
::
MS_CTX_CHECK_BPROP_FLAG
)
.
value
(
"enable_dump"
,
MsCtxParam
::
MS_CTX_ENABLE_DUMP
)
.
value
(
"enable_dynamic_mem_pool"
,
MsCtxParam
::
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
)
.
value
(
"enable_gpu_summary"
,
MsCtxParam
::
MS_CTX_ENABLE_GPU_SUMMARY
)
.
value
(
"enable_graph_kernel"
,
MsCtxParam
::
MS_CTX_ENABLE_GRAPH_KERNEL
)
.
value
(
"enable_hccl"
,
MsCtxParam
::
MS_CTX_ENABLE_HCCL
)
.
value
(
"enable_loop_sink"
,
MsCtxParam
::
MS_CTX_ENABLE_LOOP_SINK
)
.
value
(
"enable_mem_reuse"
,
MsCtxParam
::
MS_CTX_ENABLE_MEM_REUSE
)
.
value
(
"enable_pynative_hook"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_HOOK
)
.
value
(
"enable_pynative_infer"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_INFER
)
.
value
(
"enable_reduce_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_REDUCE_PRECISION
)
.
value
(
"enable_sparse"
,
MsCtxParam
::
MS_CTX_ENABLE_SPARSE
)
.
value
(
"enable_task_sink"
,
MsCtxParam
::
MS_CTX_ENABLE_TASK_SINK
)
.
value
(
"ir_fusion_flag"
,
MsCtxParam
::
MS_CTX_IR_FUSION_FLAG
)
.
value
(
"is_multi_graph_sink"
,
MsCtxParam
::
MS_CTX_IS_MULTI_GRAPH_SINK
)
.
value
(
"is_pynative_ge_init"
,
MsCtxParam
::
MS_CTX_IS_PYNATIVE_GE_INIT
)
.
value
(
"precompile_only"
,
MsCtxParam
::
MS_CTX_PRECOMPILE_ONLY
)
.
value
(
"enable_profiling"
,
MsCtxParam
::
MS_CTX_ENABLE_PROFILING
)
.
value
(
"save_graphs"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_FLAG
)
.
value
(
"max_device_memory"
,
MsCtxParam
::
MS_CTX_MAX_DEVICE_MEMORY
)
.
value
(
"mode"
,
MsCtxParam
::
MS_CTX_EXECUTION_MODE
)
.
value
(
"device_target"
,
MsCtxParam
::
MS_CTX_DEVICE_TARGET
)
.
value
(
"graph_memory_max_size"
,
MsCtxParam
::
MS_CTX_GRAPH_MEMORY_MAX_SIZE
)
.
value
(
"print_file_path"
,
MsCtxParam
::
MS_CTX_PRINT_FILE_PATH
)
.
value
(
"profiling_options"
,
MsCtxParam
::
MS_CTX_PROFILING_OPTIONS
)
.
value
(
"save_dump_path"
,
MsCtxParam
::
MS_CTX_SAVE_DUMP_PATH
)
.
value
(
"save_graphs_path"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_PATH
)
.
value
(
"variable_memory_max_size"
,
MsCtxParam
::
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
)
.
value
(
"device_id"
,
MsCtxParam
::
MS_CTX_DEVICE_ID
)
.
value
(
"ge_ref"
,
MsCtxParam
::
MS_CTX_GE_REF
)
.
value
(
"max_call_depth"
,
MsCtxParam
::
MS_CTX_MAX_CALL_DEPTH
)
.
value
(
"tsd_ref"
,
MsCtxParam
::
MS_CTX_TSD_REF
);
(
void
)
py
::
class_
<
mindspore
::
MsContext
,
std
::
shared_ptr
<
mindspore
::
MsContext
>>
(
*
m
,
"MSContext"
)
.
def_static
(
"get_instance"
,
&
mindspore
::
MsContext
::
GetInstance
,
"Get ms context instance."
)
.
def
(
"get_param"
,
&
mindspore
::
MsCtxGetParameter
,
"Get value of specified paramter."
)
.
def
(
"set_param"
,
&
mindspore
::
MsCtxSetParameter
,
"Set value for specified paramter."
)
.
def
(
"get_backend_policy"
,
&
mindspore
::
MsContext
::
backend_policy
,
"Get backend policy."
)
.
def
(
"set_backend_policy"
,
&
mindspore
::
MsContext
::
set_backend_policy
,
"Set backend policy."
);
}));
}
// namespace mindspore
mindspore/ccsrc/utils/context/context_extends.cc
浏览文件 @
0a858b38
...
...
@@ -225,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
}
// Enable auto mixed precision according to the context options
if
(
ms_context_ptr
->
get_param
<
bool
>
(
MS_CTX_
AUTO_MIXED_PRECISION_FLAG
))
{
if
(
ms_context_ptr
->
get_param
<
bool
>
(
MS_CTX_
ENABLE_AUTO_MIXED_PRECISION
))
{
(
*
ge_options
)[
"ge.exec.precision_mode"
]
=
"allow_mix_precision"
;
}
else
{
(
*
ge_options
)[
"ge.exec.precision_mode"
]
=
"allow_fp32_to_fp16"
;
...
...
@@ -337,7 +337,7 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
if
(
ge
::
GEFinalize
()
!=
ge
::
GRAPH_SUCCESS
)
{
MS_LOG
(
WARNING
)
<<
"Finalize GE failed!"
;
}
ms_context_ptr
->
set_p
ynative_ge_init
(
false
);
ms_context_ptr
->
set_p
aram
<
bool
>
(
MS_CTX_IS_PYNATIVE_GE_INIT
,
false
);
}
else
{
MS_LOG
(
INFO
)
<<
"Ge is used, no need to finalize, tsd reference = "
<<
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_GE_REF
)
<<
"."
;
...
...
mindspore/context.py
浏览文件 @
0a858b38
...
...
@@ -22,7 +22,7 @@ import threading
from
collections
import
namedtuple
from
types
import
FunctionType
from
mindspore
import
log
as
logger
from
mindspore._c_expression
import
MSContext
,
ms_ctx_param
,
ms_ctx_get_param
,
ms_ctx_set_param
from
mindspore._c_expression
import
MSContext
,
ms_ctx_param
from
mindspore._checkparam
import
args_type_check
from
mindspore.parallel._auto_parallel_context
import
_set_auto_parallel_context
,
_get_auto_parallel_context
,
\
_reset_auto_parallel_context
...
...
@@ -158,17 +158,12 @@ class _Context:
return
value
def
get_param
(
self
,
param
):
return
ms_ctx_get_param
(
self
.
_context_handle
,
param
)
return
self
.
_context_handle
.
get_param
(
param
)
def
set_param
(
self
,
param
,
value
):
ms_ctx_set_param
(
self
.
_context_handle
,
param
,
value
)
self
.
_context_handle
.
set_param
(
param
,
value
)
@
property
def
mode
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
execution_mode
)
@
mode
.
setter
def
mode
(
self
,
mode
):
def
set_mode
(
self
,
mode
):
"""
Switch between Graph mode and PyNative mode.
...
...
@@ -185,43 +180,17 @@ class _Context:
self
.
_context_switches
.
push
(
False
,
None
)
else
:
raise
ValueError
(
f
'The execution mode
{
mode
}
is invalid!'
)
self
.
set_param
(
ms_ctx_param
.
execution_
mode
,
mode
)
self
.
set_param
(
ms_ctx_param
.
mode
,
mode
)
def
set_backend_policy
(
self
,
policy
):
success
=
self
.
_context_handle
.
set_backend_policy
(
policy
)
if
not
success
:
raise
RuntimeError
(
"Backend policy must be one of ge, vm, ms."
)
@
property
def
precompile_only
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
precompile_only
)
@
precompile_only
.
setter
def
precompile_only
(
self
,
precompile_only
):
self
.
set_param
(
ms_ctx_param
.
precompile_only
,
precompile_only
)
@
property
def
save_graphs
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
save_graphs_flag
)
@
save_graphs
.
setter
def
save_graphs
(
self
,
save_graphs_flag
):
self
.
set_param
(
ms_ctx_param
.
save_graphs_flag
,
save_graphs_flag
)
@
property
def
save_graphs_path
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
save_graphs_path
)
@
save_graphs_path
.
setter
def
save_graphs_path
(
self
,
save_graphs_path
):
def
set_save_graphs_path
(
self
,
save_graphs_path
):
self
.
set_param
(
ms_ctx_param
.
save_graphs_path
,
_make_directory
(
save_graphs_path
))
@
property
def
device_target
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
device_target
)
@
device_target
.
setter
def
device_target
(
self
,
target
):
def
set_device_target
(
self
,
target
):
valid_targets
=
[
"CPU"
,
"GPU"
,
"Ascend"
,
"Davinci"
]
if
not
target
in
valid_targets
:
raise
ValueError
(
f
"Target device name
{
target
}
is invalid! It must be one of
{
valid_targets
}
"
)
...
...
@@ -231,72 +200,17 @@ class _Context:
if
self
.
enable_debug_runtime
and
target
==
"CPU"
:
self
.
set_backend_policy
(
"vm"
)
@
property
def
device_id
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
device_id
)
@
device_id
.
setter
def
device_id
(
self
,
device_id
):
def
set_device_id
(
self
,
device_id
):
if
device_id
<
0
or
device_id
>
4095
:
raise
ValueError
(
f
"Device id must be in [0, 4095], but got
{
device_id
}
"
)
self
.
set_param
(
ms_ctx_param
.
device_id
,
device_id
)
@
property
def
max_call_depth
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
max_call_depth
)
@
max_call_depth
.
setter
def
max_call_depth
(
self
,
max_call_depth
):
def
set_max_call_depth
(
self
,
max_call_depth
):
if
max_call_depth
<=
0
:
raise
ValueError
(
f
"Max call depth must be greater than 0, but got
{
max_call_depth
}
"
)
self
.
set_param
(
ms_ctx_param
.
max_call_depth
,
max_call_depth
)
@
property
def
enable_auto_mixed_precision
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
auto_mixed_precision_flag
)
@
enable_auto_mixed_precision
.
setter
def
enable_auto_mixed_precision
(
self
,
enable_auto_mixed_precision
):
self
.
set_param
(
ms_ctx_param
.
auto_mixed_precision_flag
,
enable_auto_mixed_precision
)
@
property
def
enable_reduce_precision
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
enable_reduce_precision_flag
)
@
enable_reduce_precision
.
setter
def
enable_reduce_precision
(
self
,
enable_reduce_precision
):
self
.
set_param
(
ms_ctx_param
.
enable_reduce_precision_flag
,
enable_reduce_precision
)
@
property
def
enable_dump
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
enable_dump
)
@
enable_dump
.
setter
def
enable_dump
(
self
,
enable_dump
):
self
.
set_param
(
ms_ctx_param
.
enable_dump
,
enable_dump
)
@
property
def
save_dump_path
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
save_dump_path
)
@
save_dump_path
.
setter
def
save_dump_path
(
self
,
save_dump_path
):
self
.
set_param
(
ms_ctx_param
.
save_dump_path
,
save_dump_path
)
@
property
def
enable_profiling
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
enable_profiling
)
@
enable_profiling
.
setter
def
enable_profiling
(
self
,
flag
):
self
.
set_param
(
ms_ctx_param
.
enable_profiling
,
flag
)
@
property
def
profiling_options
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
profiling_options
)
@
profiling_options
.
setter
def
profiling_options
(
self
,
option
):
def
set_profiling_options
(
self
,
option
):
options
=
[
"training_trace"
,
"task_trace"
,
"task_trace:training_trace"
,
"training_trace:task_trace"
,
"op_trace"
]
if
option
not
in
options
:
...
...
@@ -304,30 +218,7 @@ class _Context:
"'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'."
)
self
.
set_param
(
ms_ctx_param
.
profiling_options
,
option
)
@
property
def
enable_graph_kernel
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
enable_graph_kernel
)
@
enable_graph_kernel
.
setter
def
enable_graph_kernel
(
self
,
graph_kernel_switch_
):
self
.
set_param
(
ms_ctx_param
.
enable_graph_kernel
,
graph_kernel_switch_
)
@
property
def
reserve_class_name_in_scope
(
self
):
"""Gets whether to save the network class name in the scope."""
return
self
.
_thread_local_info
.
reserve_class_name_in_scope
@
reserve_class_name_in_scope
.
setter
def
reserve_class_name_in_scope
(
self
,
reserve_class_name_in_scope
):
"""Sets whether to save the network class name in the scope."""
self
.
_thread_local_info
.
reserve_class_name_in_scope
=
reserve_class_name_in_scope
@
property
def
variable_memory_max_size
(
self
):
return
None
@
variable_memory_max_size
.
setter
def
variable_memory_max_size
(
self
,
variable_memory_max_size
):
def
set_variable_memory_max_size
(
self
,
variable_memory_max_size
):
if
not
check_input_format
(
variable_memory_max_size
):
raise
ValueError
(
"Context param variable_memory_max_size should be in correct format! Such as
\"
5GB
\"
"
)
if
int
(
variable_memory_max_size
[:
-
2
])
>=
_DEVICE_APP_MEMORY_SIZE
:
...
...
@@ -338,33 +229,7 @@ class _Context:
self
.
set_param
(
ms_ctx_param
.
variable_memory_max_size
,
variable_memory_max_size_
)
self
.
set_param
(
ms_ctx_param
.
graph_memory_max_size
,
graph_memory_max_size_
)
@
property
def
enable_ge
(
self
):
return
self
.
_context_handle
.
get_backend_policy
()
==
'ge'
@
property
def
enable_debug_runtime
(
self
):
return
self
.
_thread_local_info
.
debug_runtime
@
enable_debug_runtime
.
setter
def
enable_debug_runtime
(
self
,
enable
):
thread_info
=
self
.
_thread_local_info
thread_info
.
debug_runtime
=
enable
@
property
def
check_bprop
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
check_bprop_flag
)
@
check_bprop
.
setter
def
check_bprop
(
self
,
check_bprop_flag
):
self
.
set_param
(
ms_ctx_param
.
check_bprop_flag
,
check_bprop_flag
)
@
property
def
max_device_memory
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
max_device_memory
)
@
max_device_memory
.
setter
def
max_device_memory
(
self
,
max_device_memory
):
def
set_max_device_memory
(
self
,
max_device_memory
):
if
not
check_input_format
(
max_device_memory
):
raise
ValueError
(
"Context param max_device_memory should be in correct format! Such as
\"
3.5GB
\"
"
)
max_device_memory_value
=
float
(
max_device_memory
[:
-
2
])
...
...
@@ -372,12 +237,7 @@ class _Context:
raise
ValueError
(
"Context param max_device_memory should be in correct format! Such as
\"
3.5GB
\"
"
)
self
.
set_param
(
ms_ctx_param
.
max_device_memory
,
max_device_memory_value
)
@
property
def
print_file_path
(
self
):
return
None
@
print_file_path
.
setter
def
print_file_path
(
self
,
file_path
):
def
set_print_file_path
(
self
,
file_path
):
"""Add timestamp suffix to file name. Sets print file path."""
print_file_path
=
os
.
path
.
realpath
(
file_path
)
if
os
.
path
.
isdir
(
print_file_path
):
...
...
@@ -392,13 +252,42 @@ class _Context:
full_file_name
=
print_file_path
self
.
set_param
(
ms_ctx_param
.
print_file_path
,
full_file_name
)
setters
=
{
'mode'
:
set_mode
,
'backend_policy'
:
set_backend_policy
,
'save_graphs_path'
:
set_save_graphs_path
,
'device_target'
:
set_device_target
,
'device_id'
:
set_device_id
,
'max_call_depth'
:
set_max_call_depth
,
'profiling_options'
:
set_profiling_options
,
'variable_memory_max_size'
:
set_variable_memory_max_size
,
'max_device_memory'
:
set_max_device_memory
,
'print_file_path'
:
set_print_file_path
}
@
property
def
enable_sparse
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
enable_sparse
)
def
reserve_class_name_in_scope
(
self
):
"""Gets whether to save the network class name in the scope."""
return
self
.
_thread_local_info
.
reserve_class_name_in_scope
@
reserve_class_name_in_scope
.
setter
def
reserve_class_name_in_scope
(
self
,
reserve_class_name_in_scope
):
"""Sets whether to save the network class name in the scope."""
self
.
_thread_local_info
.
reserve_class_name_in_scope
=
reserve_class_name_in_scope
@
property
def
enable_ge
(
self
):
return
self
.
_context_handle
.
get_backend_policy
()
==
'ge'
@
property
def
enable_debug_runtime
(
self
):
return
self
.
_thread_local_info
.
debug_runtime
@
enable_debug_runtime
.
setter
def
enable_debug_runtime
(
self
,
enable
):
thread_info
=
self
.
_thread_local_info
thread_info
.
debug_runtime
=
enable
@
enable_sparse
.
setter
def
enable_sparse
(
self
,
enable_sparse
):
self
.
set_param
(
ms_ctx_param
.
enable_sparse
,
enable_sparse
)
def
check_input_format
(
x
):
import
re
...
...
@@ -621,10 +510,18 @@ def set_context(**kwargs):
>>> context.set_context(print_file_path="print.pb")
>>> context.set_context(max_call_depth=80)
"""
ctx
=
_context
()
for
key
,
value
in
kwargs
.
items
():
if
not
hasattr
(
_context
(),
key
):
if
hasattr
(
ctx
,
key
):
setattr
(
ctx
,
key
,
value
)
continue
if
key
in
ctx
.
setters
:
ctx
.
setters
[
key
](
ctx
,
value
)
continue
if
key
in
ms_ctx_param
.
__members__
:
ctx
.
set_param
(
ms_ctx_param
.
__members__
[
key
],
value
)
continue
raise
ValueError
(
"Set context keyword %s is not recognized!"
%
key
)
setattr
(
_context
(),
key
,
value
)
def
get_context
(
attr_key
):
...
...
@@ -640,10 +537,13 @@ def get_context(attr_key):
Raises:
ValueError: If input key is not an attribute in context.
"""
if
not
hasattr
(
_context
(),
attr_key
):
raise
ValueError
(
"Get context keyword %s is not recognized!"
%
attr_key
)
return
getattr
(
_context
(),
attr_key
)
ctx
=
_context
()
if
hasattr
(
ctx
,
attr_key
):
return
getattr
(
ctx
,
attr_key
)
if
attr_key
in
ms_ctx_param
.
__members__
:
return
ctx
.
get_param
(
ms_ctx_param
.
__members__
[
attr_key
])
raise
ValueError
(
"Get context keyword %s is not recognized!"
%
attr_key
)
class
ParallelMode
:
"""
...
...
mindspore/core/utils/ms_context.cc
浏览文件 @
0a858b38
...
...
@@ -60,7 +60,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
#endif
set_param
<
bool
>
(
MS_CTX_ENABLE_GPU_SUMMARY
,
true
);
set_param
<
bool
>
(
MS_CTX_PRECOMPILE_ONLY
,
false
);
set_param
<
bool
>
(
MS_CTX_
AUTO_MIXED_PRECISION_FLAG
,
false
);
set_param
<
bool
>
(
MS_CTX_
ENABLE_AUTO_MIXED_PRECISION
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_HOOK
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
,
true
);
...
...
mindspore/core/utils/ms_context.h
浏览文件 @
0a858b38
...
...
@@ -53,7 +53,7 @@ const float kDefaultMaxDeviceMemory = 1024;
enum
MsCtxParam
:
unsigned
{
// paramater of type bool
MS_CTX_TYPE_BOOL_BEGIN
,
MS_CTX_
AUTO_MIXED_PRECISION_FLAG
=
MS_CTX_TYPE_BOOL_BEGIN
,
MS_CTX_
ENABLE_AUTO_MIXED_PRECISION
=
MS_CTX_TYPE_BOOL_BEGIN
,
MS_CTX_CHECK_BPROP_FLAG
,
MS_CTX_ENABLE_DUMP
,
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
,
...
...
@@ -132,22 +132,22 @@ class MsContext {
template
<
typename
T
>
void
set_param
(
MsCtxParam
param
,
const
T
&
value
)
{
MS_LOG
(
EXCEPTION
)
<<
"Need
impleme
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
MS_LOG
(
EXCEPTION
)
<<
"Need
to implemen
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
}
template
<
typename
T
>
const
T
&
get_param
(
MsCtxParam
param
)
const
{
MS_LOG
(
EXCEPTION
)
<<
"Need
impleme
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
MS_LOG
(
EXCEPTION
)
<<
"Need
to implemen
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
}
template
<
typename
T
>
void
increase_param
(
MsCtxParam
param
)
{
MS_LOG
(
EXCEPTION
)
<<
"Need
impleme
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
MS_LOG
(
EXCEPTION
)
<<
"Need
to implemen
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
}
template
<
typename
T
>
void
decrease_param
(
MsCtxParam
param
)
{
MS_LOG
(
EXCEPTION
)
<<
"Need
impleme
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
MS_LOG
(
EXCEPTION
)
<<
"Need
to implemen
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
}
private:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录