Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b0f89685
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b0f89685
编写于
9月 04, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add device specific config check
上级
5a63dac0
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
60 addition
and
18 deletion
+60
-18
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
+2
-1
mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
+3
-14
mindspore/context.py
mindspore/context.py
+55
-3
未找到文件。
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
浏览文件 @
b0f89685
...
...
@@ -120,7 +120,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
engine
->
IncreaseFunctionCallDepth
();
if
(
engine
->
function_call_depth
()
>
MsContext
::
GetInstance
()
->
get_param
<
uint32_t
>
(
MS_CTX_MAX_CALL_DEPTH
))
{
MS_LOG
(
EXCEPTION
)
<<
"Exceed function call depth limit "
<<
MsContext
::
GetInstance
()
->
get_param
<
uint32_t
>
(
MS_CTX_MAX_CALL_DEPTH
)
<<
"."
;
<<
MsContext
::
GetInstance
()
->
get_param
<
uint32_t
>
(
MS_CTX_MAX_CALL_DEPTH
)
<<
", please call 'context.set_context(max_call_depth=value)' to adjust this value."
;
}
std
::
vector
<
AnfNodePtr
>
nodes
=
FastShadowSort
(
func_node
);
for
(
auto
it
=
nodes
.
crbegin
();
it
!=
nodes
.
crend
();
it
++
)
{
...
...
mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
浏览文件 @
b0f89685
...
...
@@ -71,40 +71,29 @@ py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam p
}
}
// namespace
// Note: exported python enum variables begining with '_' are for internal use
REGISTER_PYBIND_DEFINE
(
MsContextPy
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
enum_
<
MsCtxParam
>
(
*
m
,
"ms_ctx_param"
,
py
::
arithmetic
())
.
value
(
"enable_auto_mixed_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_AUTO_MIXED_PRECISION
)
.
value
(
"check_bprop"
,
MsCtxParam
::
MS_CTX_CHECK_BPROP_FLAG
)
.
value
(
"enable_dump"
,
MsCtxParam
::
MS_CTX_ENABLE_DUMP
)
.
value
(
"enable_dynamic_mem_pool"
,
MsCtxParam
::
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
)
.
value
(
"enable_gpu_summary"
,
MsCtxParam
::
MS_CTX_ENABLE_GPU_SUMMARY
)
.
value
(
"enable_graph_kernel"
,
MsCtxParam
::
MS_CTX_ENABLE_GRAPH_KERNEL
)
.
value
(
"enable_hccl"
,
MsCtxParam
::
MS_CTX_ENABLE_HCCL
)
.
value
(
"enable_mem_reuse"
,
MsCtxParam
::
MS_CTX_ENABLE_MEM_REUSE
)
.
value
(
"enable_pynative_hook"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_HOOK
)
.
value
(
"enable_pynative_infer"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_INFER
)
.
value
(
"enable_reduce_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_REDUCE_PRECISION
)
.
value
(
"enable_sparse"
,
MsCtxParam
::
MS_CTX_ENABLE_SPARSE
)
.
value
(
"enable_task_sink"
,
MsCtxParam
::
MS_CTX_ENABLE_TASK_SINK
)
.
value
(
"ir_fusion_flag"
,
MsCtxParam
::
MS_CTX_IR_FUSION_FLAG
)
.
value
(
"is_multi_graph_sink"
,
MsCtxParam
::
MS_CTX_IS_MULTI_GRAPH_SINK
)
.
value
(
"is_pynative_ge_init"
,
MsCtxParam
::
MS_CTX_IS_PYNATIVE_GE_INIT
)
.
value
(
"precompile_only"
,
MsCtxParam
::
MS_CTX_PRECOMPILE_ONLY
)
.
value
(
"enable_profiling"
,
MsCtxParam
::
MS_CTX_ENABLE_PROFILING
)
.
value
(
"save_graphs"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_FLAG
)
.
value
(
"max_device_memory"
,
MsCtxParam
::
MS_CTX_MAX_DEVICE_MEMORY
)
.
value
(
"mode"
,
MsCtxParam
::
MS_CTX_EXECUTION_MODE
)
.
value
(
"device_target"
,
MsCtxParam
::
MS_CTX_DEVICE_TARGET
)
.
value
(
"graph_memory_max_size"
,
MsCtxParam
::
MS_CTX_GRAPH_MEMORY_MAX_SIZE
)
.
value
(
"
_
graph_memory_max_size"
,
MsCtxParam
::
MS_CTX_GRAPH_MEMORY_MAX_SIZE
)
.
value
(
"print_file_path"
,
MsCtxParam
::
MS_CTX_PRINT_FILE_PATH
)
.
value
(
"profiling_options"
,
MsCtxParam
::
MS_CTX_PROFILING_OPTIONS
)
.
value
(
"save_dump_path"
,
MsCtxParam
::
MS_CTX_SAVE_DUMP_PATH
)
.
value
(
"save_graphs_path"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_PATH
)
.
value
(
"variable_memory_max_size"
,
MsCtxParam
::
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
)
.
value
(
"device_id"
,
MsCtxParam
::
MS_CTX_DEVICE_ID
)
.
value
(
"ge_ref"
,
MsCtxParam
::
MS_CTX_GE_REF
)
.
value
(
"max_call_depth"
,
MsCtxParam
::
MS_CTX_MAX_CALL_DEPTH
)
.
value
(
"tsd_ref"
,
MsCtxParam
::
MS_CTX_TSD_REF
);
.
value
(
"max_call_depth"
,
MsCtxParam
::
MS_CTX_MAX_CALL_DEPTH
);
(
void
)
py
::
class_
<
mindspore
::
MsContext
,
std
::
shared_ptr
<
mindspore
::
MsContext
>>
(
*
m
,
"MSContext"
)
.
def_static
(
"get_instance"
,
&
mindspore
::
MsContext
::
GetInstance
,
"Get ms context instance."
)
...
...
mindspore/context.py
浏览文件 @
b0f89685
...
...
@@ -219,6 +219,7 @@ class _Context:
self
.
set_param
(
ms_ctx_param
.
profiling_options
,
option
)
def
set_variable_memory_max_size
(
self
,
variable_memory_max_size
):
"""set values of variable_memory_max_size and graph_memory_max_size"""
if
not
_check_input_format
(
variable_memory_max_size
):
raise
ValueError
(
"Context param variable_memory_max_size should be in correct format! Such as
\"
5GB
\"
"
)
if
int
(
variable_memory_max_size
[:
-
2
])
>=
_DEVICE_APP_MEMORY_SIZE
:
...
...
@@ -227,7 +228,8 @@ class _Context:
graph_memory_max_size
=
_DEVICE_APP_MEMORY_SIZE
-
int
(
variable_memory_max_size
[:
-
2
])
graph_memory_max_size_
=
str
(
graph_memory_max_size
)
+
" * 1024 * 1024 * 1024"
self
.
set_param
(
ms_ctx_param
.
variable_memory_max_size
,
variable_memory_max_size_
)
self
.
set_param
(
ms_ctx_param
.
graph_memory_max_size
,
graph_memory_max_size_
)
# pylint: disable=protected-access
self
.
set_param
(
ms_ctx_param
.
_graph_memory_max_size
,
graph_memory_max_size_
)
def
set_max_device_memory
(
self
,
max_device_memory
):
if
not
_check_input_format
(
max_device_memory
):
...
...
@@ -425,6 +427,26 @@ def reset_auto_parallel_context():
_reset_auto_parallel_context
()
def
_check_target_specific_cfgs
(
device
,
arg_key
):
"""Checking whether a config is sutable for a specified device"""
device_cfgs
=
{
'enable_auto_mixed_precision'
:
[
'Ascend'
],
'enable_dump'
:
[
'Ascend'
],
'enable_profiling'
:
[
'Ascend'
],
'variable_memory_max_size'
:
[
'Ascend'
],
'max_device_memory'
:
[
'GPU'
]
}
# configs not in map device_cfgs are supposed to be suitable for all devices
if
not
arg_key
in
device_cfgs
:
return
True
supported_devices
=
device_cfgs
[
arg_key
]
if
device
in
supported_devices
:
return
True
logger
.
warning
(
f
"Config '
{
arg_key
}
' only supports devices in
{
supported_devices
}
, current device is '
{
device
}
'"
", ignore it."
)
return
False
@
args_type_check
(
mode
=
int
,
precompile_only
=
bool
,
device_target
=
str
,
device_id
=
int
,
save_graphs
=
bool
,
save_graphs_path
=
str
,
enable_dump
=
bool
,
save_dump_path
=
str
,
enable_reduce_precision
=
bool
,
variable_memory_max_size
=
str
,
...
...
@@ -450,6 +472,26 @@ def set_context(**kwargs):
The mode is not recommended to be changed after net was initilized because the implementations of some
operations are different in graph mode and pynative mode. Default: PYNATIVE_MODE.
Some configurations are device specific, see the bellow table for details:
=========================== =========================== =================
Common(CPU/GPU/Asecend) Ascend GPU
=========================== =========================== =================
check_bprop enable_auto_mixed_precision max_device_memory
device_id enable_dump
device_target enable_profiling
enable_graph_kernel variable_memory_max_size
enable_reduce_precision
enable_sparse
mode
print_file_path
profiling_options
reserve_class_name_in_scope
save_dump_path
save_graphs
save_graphs_path
=========================== =========================== =================
Args:
mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).
device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend".
...
...
@@ -513,14 +555,21 @@ def set_context(**kwargs):
>>> context.set_context(max_call_depth=80)
"""
ctx
=
_context
()
# set device target first
if
'device_target'
in
kwargs
:
ctx
.
set_device_target
(
kwargs
[
'device_target'
])
device
=
ctx
.
get_param
(
ms_ctx_param
.
device_target
)
for
key
,
value
in
kwargs
.
items
():
if
not
_check_target_specific_cfgs
(
device
,
key
):
continue
if
hasattr
(
ctx
,
key
):
setattr
(
ctx
,
key
,
value
)
continue
if
key
in
ctx
.
setters
:
ctx
.
setters
[
key
](
ctx
,
value
)
continue
if
key
in
ms_ctx_param
.
__members__
:
# enum variables begining with '_' are for internal use
if
key
in
ms_ctx_param
.
__members__
and
key
[
0
]
!=
'_'
:
ctx
.
set_param
(
ms_ctx_param
.
__members__
[
key
],
value
)
continue
raise
ValueError
(
"Set context keyword %s is not recognized!"
%
key
)
...
...
@@ -540,9 +589,12 @@ def get_context(attr_key):
ValueError: If input key is not an attribute in context.
"""
ctx
=
_context
()
device
=
ctx
.
get_param
(
ms_ctx_param
.
device_target
)
_
=
_check_target_specific_cfgs
(
device
,
attr_key
)
if
hasattr
(
ctx
,
attr_key
):
return
getattr
(
ctx
,
attr_key
)
if
attr_key
in
ms_ctx_param
.
__members__
:
# enum variables begining with '_' are for internal use
if
attr_key
in
ms_ctx_param
.
__members__
and
attr_key
[
0
]
!=
'_'
:
return
ctx
.
get_param
(
ms_ctx_param
.
__members__
[
attr_key
])
raise
ValueError
(
"Get context keyword %s is not recognized!"
%
attr_key
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录