Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b0f89685
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b0f89685
编写于
9月 04, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add device specific config check
上级
5a63dac0
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
60 addition
and
18 deletion
+60
-18
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
+2
-1
mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
+3
-14
mindspore/context.py
mindspore/context.py
+55
-3
未找到文件。
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
浏览文件 @
b0f89685
...
@@ -120,7 +120,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
...
@@ -120,7 +120,8 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
engine
->
IncreaseFunctionCallDepth
();
engine
->
IncreaseFunctionCallDepth
();
if
(
engine
->
function_call_depth
()
>
MsContext
::
GetInstance
()
->
get_param
<
uint32_t
>
(
MS_CTX_MAX_CALL_DEPTH
))
{
if
(
engine
->
function_call_depth
()
>
MsContext
::
GetInstance
()
->
get_param
<
uint32_t
>
(
MS_CTX_MAX_CALL_DEPTH
))
{
MS_LOG
(
EXCEPTION
)
<<
"Exceed function call depth limit "
MS_LOG
(
EXCEPTION
)
<<
"Exceed function call depth limit "
<<
MsContext
::
GetInstance
()
->
get_param
<
uint32_t
>
(
MS_CTX_MAX_CALL_DEPTH
)
<<
"."
;
<<
MsContext
::
GetInstance
()
->
get_param
<
uint32_t
>
(
MS_CTX_MAX_CALL_DEPTH
)
<<
", please call 'context.set_context(max_call_depth=value)' to adjust this value."
;
}
}
std
::
vector
<
AnfNodePtr
>
nodes
=
FastShadowSort
(
func_node
);
std
::
vector
<
AnfNodePtr
>
nodes
=
FastShadowSort
(
func_node
);
for
(
auto
it
=
nodes
.
crbegin
();
it
!=
nodes
.
crend
();
it
++
)
{
for
(
auto
it
=
nodes
.
crbegin
();
it
!=
nodes
.
crend
();
it
++
)
{
...
...
mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
浏览文件 @
b0f89685
...
@@ -71,40 +71,29 @@ py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam p
...
@@ -71,40 +71,29 @@ py::object MsCtxGetParameter(const std::shared_ptr<MsContext> &ctx, MsCtxParam p
}
}
}
// namespace
}
// namespace
// Note: exported python enum variables begining with '_' are for internal use
REGISTER_PYBIND_DEFINE
(
MsContextPy
,
([](
const
py
::
module
*
m
)
{
REGISTER_PYBIND_DEFINE
(
MsContextPy
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
enum_
<
MsCtxParam
>
(
*
m
,
"ms_ctx_param"
,
py
::
arithmetic
())
(
void
)
py
::
enum_
<
MsCtxParam
>
(
*
m
,
"ms_ctx_param"
,
py
::
arithmetic
())
.
value
(
"enable_auto_mixed_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_AUTO_MIXED_PRECISION
)
.
value
(
"enable_auto_mixed_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_AUTO_MIXED_PRECISION
)
.
value
(
"check_bprop"
,
MsCtxParam
::
MS_CTX_CHECK_BPROP_FLAG
)
.
value
(
"check_bprop"
,
MsCtxParam
::
MS_CTX_CHECK_BPROP_FLAG
)
.
value
(
"enable_dump"
,
MsCtxParam
::
MS_CTX_ENABLE_DUMP
)
.
value
(
"enable_dump"
,
MsCtxParam
::
MS_CTX_ENABLE_DUMP
)
.
value
(
"enable_dynamic_mem_pool"
,
MsCtxParam
::
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
)
.
value
(
"enable_gpu_summary"
,
MsCtxParam
::
MS_CTX_ENABLE_GPU_SUMMARY
)
.
value
(
"enable_graph_kernel"
,
MsCtxParam
::
MS_CTX_ENABLE_GRAPH_KERNEL
)
.
value
(
"enable_graph_kernel"
,
MsCtxParam
::
MS_CTX_ENABLE_GRAPH_KERNEL
)
.
value
(
"enable_hccl"
,
MsCtxParam
::
MS_CTX_ENABLE_HCCL
)
.
value
(
"enable_mem_reuse"
,
MsCtxParam
::
MS_CTX_ENABLE_MEM_REUSE
)
.
value
(
"enable_pynative_hook"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_HOOK
)
.
value
(
"enable_pynative_infer"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_INFER
)
.
value
(
"enable_reduce_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_REDUCE_PRECISION
)
.
value
(
"enable_reduce_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_REDUCE_PRECISION
)
.
value
(
"enable_sparse"
,
MsCtxParam
::
MS_CTX_ENABLE_SPARSE
)
.
value
(
"enable_sparse"
,
MsCtxParam
::
MS_CTX_ENABLE_SPARSE
)
.
value
(
"enable_task_sink"
,
MsCtxParam
::
MS_CTX_ENABLE_TASK_SINK
)
.
value
(
"ir_fusion_flag"
,
MsCtxParam
::
MS_CTX_IR_FUSION_FLAG
)
.
value
(
"is_multi_graph_sink"
,
MsCtxParam
::
MS_CTX_IS_MULTI_GRAPH_SINK
)
.
value
(
"is_pynative_ge_init"
,
MsCtxParam
::
MS_CTX_IS_PYNATIVE_GE_INIT
)
.
value
(
"precompile_only"
,
MsCtxParam
::
MS_CTX_PRECOMPILE_ONLY
)
.
value
(
"precompile_only"
,
MsCtxParam
::
MS_CTX_PRECOMPILE_ONLY
)
.
value
(
"enable_profiling"
,
MsCtxParam
::
MS_CTX_ENABLE_PROFILING
)
.
value
(
"enable_profiling"
,
MsCtxParam
::
MS_CTX_ENABLE_PROFILING
)
.
value
(
"save_graphs"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_FLAG
)
.
value
(
"save_graphs"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_FLAG
)
.
value
(
"max_device_memory"
,
MsCtxParam
::
MS_CTX_MAX_DEVICE_MEMORY
)
.
value
(
"max_device_memory"
,
MsCtxParam
::
MS_CTX_MAX_DEVICE_MEMORY
)
.
value
(
"mode"
,
MsCtxParam
::
MS_CTX_EXECUTION_MODE
)
.
value
(
"mode"
,
MsCtxParam
::
MS_CTX_EXECUTION_MODE
)
.
value
(
"device_target"
,
MsCtxParam
::
MS_CTX_DEVICE_TARGET
)
.
value
(
"device_target"
,
MsCtxParam
::
MS_CTX_DEVICE_TARGET
)
.
value
(
"graph_memory_max_size"
,
MsCtxParam
::
MS_CTX_GRAPH_MEMORY_MAX_SIZE
)
.
value
(
"
_
graph_memory_max_size"
,
MsCtxParam
::
MS_CTX_GRAPH_MEMORY_MAX_SIZE
)
.
value
(
"print_file_path"
,
MsCtxParam
::
MS_CTX_PRINT_FILE_PATH
)
.
value
(
"print_file_path"
,
MsCtxParam
::
MS_CTX_PRINT_FILE_PATH
)
.
value
(
"profiling_options"
,
MsCtxParam
::
MS_CTX_PROFILING_OPTIONS
)
.
value
(
"profiling_options"
,
MsCtxParam
::
MS_CTX_PROFILING_OPTIONS
)
.
value
(
"save_dump_path"
,
MsCtxParam
::
MS_CTX_SAVE_DUMP_PATH
)
.
value
(
"save_dump_path"
,
MsCtxParam
::
MS_CTX_SAVE_DUMP_PATH
)
.
value
(
"save_graphs_path"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_PATH
)
.
value
(
"save_graphs_path"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_PATH
)
.
value
(
"variable_memory_max_size"
,
MsCtxParam
::
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
)
.
value
(
"variable_memory_max_size"
,
MsCtxParam
::
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
)
.
value
(
"device_id"
,
MsCtxParam
::
MS_CTX_DEVICE_ID
)
.
value
(
"device_id"
,
MsCtxParam
::
MS_CTX_DEVICE_ID
)
.
value
(
"ge_ref"
,
MsCtxParam
::
MS_CTX_GE_REF
)
.
value
(
"max_call_depth"
,
MsCtxParam
::
MS_CTX_MAX_CALL_DEPTH
);
.
value
(
"max_call_depth"
,
MsCtxParam
::
MS_CTX_MAX_CALL_DEPTH
)
.
value
(
"tsd_ref"
,
MsCtxParam
::
MS_CTX_TSD_REF
);
(
void
)
py
::
class_
<
mindspore
::
MsContext
,
std
::
shared_ptr
<
mindspore
::
MsContext
>>
(
*
m
,
"MSContext"
)
(
void
)
py
::
class_
<
mindspore
::
MsContext
,
std
::
shared_ptr
<
mindspore
::
MsContext
>>
(
*
m
,
"MSContext"
)
.
def_static
(
"get_instance"
,
&
mindspore
::
MsContext
::
GetInstance
,
"Get ms context instance."
)
.
def_static
(
"get_instance"
,
&
mindspore
::
MsContext
::
GetInstance
,
"Get ms context instance."
)
...
...
mindspore/context.py
浏览文件 @
b0f89685
...
@@ -219,6 +219,7 @@ class _Context:
...
@@ -219,6 +219,7 @@ class _Context:
self
.
set_param
(
ms_ctx_param
.
profiling_options
,
option
)
self
.
set_param
(
ms_ctx_param
.
profiling_options
,
option
)
def
set_variable_memory_max_size
(
self
,
variable_memory_max_size
):
def
set_variable_memory_max_size
(
self
,
variable_memory_max_size
):
"""set values of variable_memory_max_size and graph_memory_max_size"""
if
not
_check_input_format
(
variable_memory_max_size
):
if
not
_check_input_format
(
variable_memory_max_size
):
raise
ValueError
(
"Context param variable_memory_max_size should be in correct format! Such as
\"
5GB
\"
"
)
raise
ValueError
(
"Context param variable_memory_max_size should be in correct format! Such as
\"
5GB
\"
"
)
if
int
(
variable_memory_max_size
[:
-
2
])
>=
_DEVICE_APP_MEMORY_SIZE
:
if
int
(
variable_memory_max_size
[:
-
2
])
>=
_DEVICE_APP_MEMORY_SIZE
:
...
@@ -227,7 +228,8 @@ class _Context:
...
@@ -227,7 +228,8 @@ class _Context:
graph_memory_max_size
=
_DEVICE_APP_MEMORY_SIZE
-
int
(
variable_memory_max_size
[:
-
2
])
graph_memory_max_size
=
_DEVICE_APP_MEMORY_SIZE
-
int
(
variable_memory_max_size
[:
-
2
])
graph_memory_max_size_
=
str
(
graph_memory_max_size
)
+
" * 1024 * 1024 * 1024"
graph_memory_max_size_
=
str
(
graph_memory_max_size
)
+
" * 1024 * 1024 * 1024"
self
.
set_param
(
ms_ctx_param
.
variable_memory_max_size
,
variable_memory_max_size_
)
self
.
set_param
(
ms_ctx_param
.
variable_memory_max_size
,
variable_memory_max_size_
)
self
.
set_param
(
ms_ctx_param
.
graph_memory_max_size
,
graph_memory_max_size_
)
# pylint: disable=protected-access
self
.
set_param
(
ms_ctx_param
.
_graph_memory_max_size
,
graph_memory_max_size_
)
def
set_max_device_memory
(
self
,
max_device_memory
):
def
set_max_device_memory
(
self
,
max_device_memory
):
if
not
_check_input_format
(
max_device_memory
):
if
not
_check_input_format
(
max_device_memory
):
...
@@ -425,6 +427,26 @@ def reset_auto_parallel_context():
...
@@ -425,6 +427,26 @@ def reset_auto_parallel_context():
_reset_auto_parallel_context
()
_reset_auto_parallel_context
()
def
_check_target_specific_cfgs
(
device
,
arg_key
):
"""Checking whether a config is sutable for a specified device"""
device_cfgs
=
{
'enable_auto_mixed_precision'
:
[
'Ascend'
],
'enable_dump'
:
[
'Ascend'
],
'enable_profiling'
:
[
'Ascend'
],
'variable_memory_max_size'
:
[
'Ascend'
],
'max_device_memory'
:
[
'GPU'
]
}
# configs not in map device_cfgs are supposed to be suitable for all devices
if
not
arg_key
in
device_cfgs
:
return
True
supported_devices
=
device_cfgs
[
arg_key
]
if
device
in
supported_devices
:
return
True
logger
.
warning
(
f
"Config '
{
arg_key
}
' only supports devices in
{
supported_devices
}
, current device is '
{
device
}
'"
", ignore it."
)
return
False
@
args_type_check
(
mode
=
int
,
precompile_only
=
bool
,
device_target
=
str
,
device_id
=
int
,
save_graphs
=
bool
,
@
args_type_check
(
mode
=
int
,
precompile_only
=
bool
,
device_target
=
str
,
device_id
=
int
,
save_graphs
=
bool
,
save_graphs_path
=
str
,
enable_dump
=
bool
,
save_graphs_path
=
str
,
enable_dump
=
bool
,
save_dump_path
=
str
,
enable_reduce_precision
=
bool
,
variable_memory_max_size
=
str
,
save_dump_path
=
str
,
enable_reduce_precision
=
bool
,
variable_memory_max_size
=
str
,
...
@@ -450,6 +472,26 @@ def set_context(**kwargs):
...
@@ -450,6 +472,26 @@ def set_context(**kwargs):
The mode is not recommended to be changed after net was initilized because the implementations of some
The mode is not recommended to be changed after net was initilized because the implementations of some
operations are different in graph mode and pynative mode. Default: PYNATIVE_MODE.
operations are different in graph mode and pynative mode. Default: PYNATIVE_MODE.
Some configurations are device specific, see the bellow table for details:
=========================== =========================== =================
Common(CPU/GPU/Asecend) Ascend GPU
=========================== =========================== =================
check_bprop enable_auto_mixed_precision max_device_memory
device_id enable_dump
device_target enable_profiling
enable_graph_kernel variable_memory_max_size
enable_reduce_precision
enable_sparse
mode
print_file_path
profiling_options
reserve_class_name_in_scope
save_dump_path
save_graphs
save_graphs_path
=========================== =========================== =================
Args:
Args:
mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).
mode (int): Running in GRAPH_MODE(0) or PYNATIVE_MODE(1).
device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend".
device_target (str): The target device to run, support "Ascend", "GPU", "CPU". Default: "Ascend".
...
@@ -513,14 +555,21 @@ def set_context(**kwargs):
...
@@ -513,14 +555,21 @@ def set_context(**kwargs):
>>> context.set_context(max_call_depth=80)
>>> context.set_context(max_call_depth=80)
"""
"""
ctx
=
_context
()
ctx
=
_context
()
# set device target first
if
'device_target'
in
kwargs
:
ctx
.
set_device_target
(
kwargs
[
'device_target'
])
device
=
ctx
.
get_param
(
ms_ctx_param
.
device_target
)
for
key
,
value
in
kwargs
.
items
():
for
key
,
value
in
kwargs
.
items
():
if
not
_check_target_specific_cfgs
(
device
,
key
):
continue
if
hasattr
(
ctx
,
key
):
if
hasattr
(
ctx
,
key
):
setattr
(
ctx
,
key
,
value
)
setattr
(
ctx
,
key
,
value
)
continue
continue
if
key
in
ctx
.
setters
:
if
key
in
ctx
.
setters
:
ctx
.
setters
[
key
](
ctx
,
value
)
ctx
.
setters
[
key
](
ctx
,
value
)
continue
continue
if
key
in
ms_ctx_param
.
__members__
:
# enum variables begining with '_' are for internal use
if
key
in
ms_ctx_param
.
__members__
and
key
[
0
]
!=
'_'
:
ctx
.
set_param
(
ms_ctx_param
.
__members__
[
key
],
value
)
ctx
.
set_param
(
ms_ctx_param
.
__members__
[
key
],
value
)
continue
continue
raise
ValueError
(
"Set context keyword %s is not recognized!"
%
key
)
raise
ValueError
(
"Set context keyword %s is not recognized!"
%
key
)
...
@@ -540,9 +589,12 @@ def get_context(attr_key):
...
@@ -540,9 +589,12 @@ def get_context(attr_key):
ValueError: If input key is not an attribute in context.
ValueError: If input key is not an attribute in context.
"""
"""
ctx
=
_context
()
ctx
=
_context
()
device
=
ctx
.
get_param
(
ms_ctx_param
.
device_target
)
_
=
_check_target_specific_cfgs
(
device
,
attr_key
)
if
hasattr
(
ctx
,
attr_key
):
if
hasattr
(
ctx
,
attr_key
):
return
getattr
(
ctx
,
attr_key
)
return
getattr
(
ctx
,
attr_key
)
if
attr_key
in
ms_ctx_param
.
__members__
:
# enum variables begining with '_' are for internal use
if
attr_key
in
ms_ctx_param
.
__members__
and
attr_key
[
0
]
!=
'_'
:
return
ctx
.
get_param
(
ms_ctx_param
.
__members__
[
attr_key
])
return
ctx
.
get_param
(
ms_ctx_param
.
__members__
[
attr_key
])
raise
ValueError
(
"Get context keyword %s is not recognized!"
%
attr_key
)
raise
ValueError
(
"Get context keyword %s is not recognized!"
%
attr_key
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录