Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0a858b38
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0a858b38
编写于
8月 31, 2020
作者:
F
fary86
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplify ms_context implementation
上级
d5e02cf4
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
190 addition
and
265 deletion
+190
-265
mindspore/ccsrc/pipeline/jit/init.cc
mindspore/ccsrc/pipeline/jit/init.cc
+0
-92
mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
+117
-0
mindspore/ccsrc/utils/context/context_extends.cc
mindspore/ccsrc/utils/context/context_extends.cc
+2
-2
mindspore/context.py
mindspore/context.py
+65
-165
mindspore/core/utils/ms_context.cc
mindspore/core/utils/ms_context.cc
+1
-1
mindspore/core/utils/ms_context.h
mindspore/core/utils/ms_context.h
+5
-5
未找到文件。
mindspore/ccsrc/pipeline/jit/init.cc
浏览文件 @
0a858b38
...
@@ -50,55 +50,6 @@ using ParallelContext = mindspore::parallel::ParallelContext;
...
@@ -50,55 +50,6 @@ using ParallelContext = mindspore::parallel::ParallelContext;
using
CostModelContext
=
mindspore
::
parallel
::
CostModelContext
;
using
CostModelContext
=
mindspore
::
parallel
::
CostModelContext
;
using
mindspore
::
MsCtxParam
;
using
mindspore
::
MsCtxParam
;
namespace
mindspore
{
void
MsCtxSetParameter
(
std
::
shared_ptr
<
MsContext
>
ctx
,
MsCtxParam
param
,
const
py
::
object
&
value
)
{
MS_LOG
(
DEBUG
)
<<
"set param("
<<
param
<<
") with value '"
<<
py
::
str
(
value
)
<<
"' of type '"
<<
py
::
str
(
value
.
get_type
())
<<
"'."
;
if
(
param
>=
MS_CTX_TYPE_BOOL_BEGIN
&&
param
<
MS_CTX_TYPE_BOOL_END
&&
py
::
isinstance
<
py
::
bool_
>
(
value
))
{
ctx
->
set_param
<
bool
>
(
param
,
value
.
cast
<
bool
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_INT_BEGIN
&&
param
<
MS_CTX_TYPE_INT_END
&&
py
::
isinstance
<
py
::
int_
>
(
value
))
{
ctx
->
set_param
<
int
>
(
param
,
value
.
cast
<
int
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_UINT32_BEGIN
&&
param
<
MS_CTX_TYPE_UINT32_END
&&
py
::
isinstance
<
py
::
int_
>
(
value
))
{
ctx
->
set_param
<
uint32_t
>
(
param
,
value
.
cast
<
uint32_t
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_FLOAT_BEGIN
&&
param
<
MS_CTX_TYPE_FLOAT_END
&&
py
::
isinstance
<
py
::
float_
>
(
value
))
{
ctx
->
set_param
<
float
>
(
param
,
value
.
cast
<
float
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_STRING_BEGIN
&&
param
<
MS_CTX_TYPE_STRING_END
&&
py
::
isinstance
<
py
::
str
>
(
value
))
{
ctx
->
set_param
<
std
::
string
>
(
param
,
value
.
cast
<
std
::
string
>
());
return
;
}
MS_LOG
(
EXCEPTION
)
<<
"Got illegal param "
<<
param
<<
" and value with type "
<<
py
::
str
(
value
.
get_type
());
}
py
::
object
MsCtxGetParameter
(
const
std
::
shared_ptr
<
MsContext
>
&
ctx
,
MsCtxParam
param
)
{
if
(
param
>=
MS_CTX_TYPE_BOOL_BEGIN
&&
param
<
MS_CTX_TYPE_BOOL_END
)
{
return
py
::
bool_
(
ctx
->
get_param
<
bool
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_INT_BEGIN
&&
param
<
MS_CTX_TYPE_INT_END
)
{
return
py
::
int_
(
ctx
->
get_param
<
int
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_UINT32_BEGIN
&&
param
<
MS_CTX_TYPE_UINT32_END
)
{
return
py
::
int_
(
ctx
->
get_param
<
uint32_t
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_FLOAT_BEGIN
&&
param
<
MS_CTX_TYPE_FLOAT_END
)
{
return
py
::
float_
(
ctx
->
get_param
<
float
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_STRING_BEGIN
&&
param
<
MS_CTX_TYPE_STRING_END
)
{
return
py
::
str
(
ctx
->
get_param
<
std
::
string
>
(
param
));
}
MS_LOG
(
EXCEPTION
)
<<
"Got illegal param "
<<
param
<<
"."
;
}
}
// namespace mindspore
// Interface with python
// Interface with python
PYBIND11_MODULE
(
_c_expression
,
m
)
{
PYBIND11_MODULE
(
_c_expression
,
m
)
{
m
.
doc
()
=
"MindSpore c plugin"
;
m
.
doc
()
=
"MindSpore c plugin"
;
...
@@ -151,49 +102,6 @@ PYBIND11_MODULE(_c_expression, m) {
...
@@ -151,49 +102,6 @@ PYBIND11_MODULE(_c_expression, m) {
(
void
)
m
.
def
(
"export_graph"
,
&
mindspore
::
pipeline
::
ExportGraph
,
"Export Graph."
);
(
void
)
m
.
def
(
"export_graph"
,
&
mindspore
::
pipeline
::
ExportGraph
,
"Export Graph."
);
(
void
)
m
.
def
(
"ms_ctx_get_param"
,
&
mindspore
::
MsCtxGetParameter
,
"Get value of specified paramter."
);
(
void
)
m
.
def
(
"ms_ctx_set_param"
,
&
mindspore
::
MsCtxSetParameter
,
"Set value for specified paramter."
);
(
void
)
py
::
enum_
<
MsCtxParam
>
(
*
m
,
"ms_ctx_param"
,
py
::
arithmetic
())
.
value
(
"auto_mixed_precision_flag"
,
MsCtxParam
::
MS_CTX_AUTO_MIXED_PRECISION_FLAG
)
.
value
(
"check_bprop_flag"
,
MsCtxParam
::
MS_CTX_CHECK_BPROP_FLAG
)
.
value
(
"enable_dump"
,
MsCtxParam
::
MS_CTX_ENABLE_DUMP
)
.
value
(
"enable_dynamic_mem_pool"
,
MsCtxParam
::
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
)
.
value
(
"enable_gpu_summary"
,
MsCtxParam
::
MS_CTX_ENABLE_GPU_SUMMARY
)
.
value
(
"enable_graph_kernel"
,
MsCtxParam
::
MS_CTX_ENABLE_GRAPH_KERNEL
)
.
value
(
"enable_hccl"
,
MsCtxParam
::
MS_CTX_ENABLE_HCCL
)
.
value
(
"enable_loop_sink"
,
MsCtxParam
::
MS_CTX_ENABLE_LOOP_SINK
)
.
value
(
"enable_mem_reuse"
,
MsCtxParam
::
MS_CTX_ENABLE_MEM_REUSE
)
.
value
(
"enable_pynative_hook"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_HOOK
)
.
value
(
"enable_pynative_infer"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_INFER
)
.
value
(
"enable_reduce_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_REDUCE_PRECISION
)
.
value
(
"enable_sparse"
,
MsCtxParam
::
MS_CTX_ENABLE_SPARSE
)
.
value
(
"enable_task_sink"
,
MsCtxParam
::
MS_CTX_ENABLE_TASK_SINK
)
.
value
(
"ir_fusion_flag"
,
MsCtxParam
::
MS_CTX_IR_FUSION_FLAG
)
.
value
(
"is_multi_graph_sink"
,
MsCtxParam
::
MS_CTX_IS_MULTI_GRAPH_SINK
)
.
value
(
"is_pynative_ge_init"
,
MsCtxParam
::
MS_CTX_IS_PYNATIVE_GE_INIT
)
.
value
(
"precompile_only"
,
MsCtxParam
::
MS_CTX_PRECOMPILE_ONLY
)
.
value
(
"enable_profiling"
,
MsCtxParam
::
MS_CTX_ENABLE_PROFILING
)
.
value
(
"save_graphs_flag"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_FLAG
)
.
value
(
"max_device_memory"
,
MsCtxParam
::
MS_CTX_MAX_DEVICE_MEMORY
)
.
value
(
"execution_mode"
,
MsCtxParam
::
MS_CTX_EXECUTION_MODE
)
.
value
(
"device_target"
,
MsCtxParam
::
MS_CTX_DEVICE_TARGET
)
.
value
(
"graph_memory_max_size"
,
MsCtxParam
::
MS_CTX_GRAPH_MEMORY_MAX_SIZE
)
.
value
(
"print_file_path"
,
MsCtxParam
::
MS_CTX_PRINT_FILE_PATH
)
.
value
(
"profiling_options"
,
MsCtxParam
::
MS_CTX_PROFILING_OPTIONS
)
.
value
(
"save_dump_path"
,
MsCtxParam
::
MS_CTX_SAVE_DUMP_PATH
)
.
value
(
"save_graphs_path"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_PATH
)
.
value
(
"variable_memory_max_size"
,
MsCtxParam
::
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
)
.
value
(
"device_id"
,
MsCtxParam
::
MS_CTX_DEVICE_ID
)
.
value
(
"ge_ref"
,
MsCtxParam
::
MS_CTX_GE_REF
)
.
value
(
"max_call_depth"
,
MsCtxParam
::
MS_CTX_MAX_CALL_DEPTH
)
.
value
(
"tsd_ref"
,
MsCtxParam
::
MS_CTX_TSD_REF
);
(
void
)
py
::
class_
<
mindspore
::
MsContext
,
std
::
shared_ptr
<
mindspore
::
MsContext
>>
(
m
,
"MSContext"
)
.
def_static
(
"get_instance"
,
&
mindspore
::
MsContext
::
GetInstance
,
"Get ms context instance."
)
.
def
(
"get_backend_policy"
,
&
mindspore
::
MsContext
::
backend_policy
,
"Get backend policy."
)
.
def
(
"set_backend_policy"
,
&
mindspore
::
MsContext
::
set_backend_policy
,
"Set backend policy."
);
(
void
)
py
::
class_
<
mindspore
::
MpiConfig
,
std
::
shared_ptr
<
mindspore
::
MpiConfig
>>
(
m
,
"MpiConfig"
)
(
void
)
py
::
class_
<
mindspore
::
MpiConfig
,
std
::
shared_ptr
<
mindspore
::
MpiConfig
>>
(
m
,
"MpiConfig"
)
.
def_static
(
"get_instance"
,
&
mindspore
::
MpiConfig
::
GetInstance
,
"Get mpi config instance."
)
.
def_static
(
"get_instance"
,
&
mindspore
::
MpiConfig
::
GetInstance
,
"Get mpi config instance."
)
.
def
(
"get_enable_mpi"
,
&
mindspore
::
MpiConfig
::
enable_mpi
,
"Get whether enable mpi."
)
.
def
(
"get_enable_mpi"
,
&
mindspore
::
MpiConfig
::
enable_mpi
,
"Get whether enable mpi."
)
...
...
mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
0 → 100644
浏览文件 @
0a858b38
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <memory>
#include <string>
#include "utils/ms_context.h"
#include "utils/log_adapter.h"
#include "pybind_api/api_register.h"
namespace
mindspore
{
namespace
{
void
MsCtxSetParameter
(
std
::
shared_ptr
<
MsContext
>
ctx
,
MsCtxParam
param
,
const
py
::
object
&
value
)
{
MS_LOG
(
DEBUG
)
<<
"set param("
<<
param
<<
") with value '"
<<
py
::
str
(
value
).
cast
<
std
::
string
>
()
<<
"' of type '"
<<
py
::
str
(
value
.
get_type
()).
cast
<
std
::
string
>
()
<<
"'."
;
if
(
param
>=
MS_CTX_TYPE_BOOL_BEGIN
&&
param
<
MS_CTX_TYPE_BOOL_END
&&
py
::
isinstance
<
py
::
bool_
>
(
value
))
{
ctx
->
set_param
<
bool
>
(
param
,
value
.
cast
<
bool
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_INT_BEGIN
&&
param
<
MS_CTX_TYPE_INT_END
&&
py
::
isinstance
<
py
::
int_
>
(
value
))
{
ctx
->
set_param
<
int
>
(
param
,
value
.
cast
<
int
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_UINT32_BEGIN
&&
param
<
MS_CTX_TYPE_UINT32_END
&&
py
::
isinstance
<
py
::
int_
>
(
value
))
{
ctx
->
set_param
<
uint32_t
>
(
param
,
value
.
cast
<
uint32_t
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_FLOAT_BEGIN
&&
param
<
MS_CTX_TYPE_FLOAT_END
&&
py
::
isinstance
<
py
::
float_
>
(
value
))
{
ctx
->
set_param
<
float
>
(
param
,
value
.
cast
<
float
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_STRING_BEGIN
&&
param
<
MS_CTX_TYPE_STRING_END
&&
py
::
isinstance
<
py
::
str
>
(
value
))
{
ctx
->
set_param
<
std
::
string
>
(
param
,
value
.
cast
<
std
::
string
>
());
return
;
}
MS_LOG
(
EXCEPTION
)
<<
"Got illegal param "
<<
param
<<
" and value with type "
<<
py
::
str
(
value
.
get_type
()).
cast
<
std
::
string
>
();
}
py
::
object
MsCtxGetParameter
(
const
std
::
shared_ptr
<
MsContext
>
&
ctx
,
MsCtxParam
param
)
{
if
(
param
>=
MS_CTX_TYPE_BOOL_BEGIN
&&
param
<
MS_CTX_TYPE_BOOL_END
)
{
return
py
::
bool_
(
ctx
->
get_param
<
bool
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_INT_BEGIN
&&
param
<
MS_CTX_TYPE_INT_END
)
{
return
py
::
int_
(
ctx
->
get_param
<
int
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_UINT32_BEGIN
&&
param
<
MS_CTX_TYPE_UINT32_END
)
{
return
py
::
int_
(
ctx
->
get_param
<
uint32_t
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_FLOAT_BEGIN
&&
param
<
MS_CTX_TYPE_FLOAT_END
)
{
return
py
::
float_
(
ctx
->
get_param
<
float
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_STRING_BEGIN
&&
param
<
MS_CTX_TYPE_STRING_END
)
{
return
py
::
str
(
ctx
->
get_param
<
std
::
string
>
(
param
));
}
MS_LOG
(
EXCEPTION
)
<<
"Got illegal param "
<<
param
<<
"."
;
}
}
// namespace
REGISTER_PYBIND_DEFINE
(
MsContextPy
,
([](
const
py
::
module
*
m
)
{
(
void
)
py
::
enum_
<
MsCtxParam
>
(
*
m
,
"ms_ctx_param"
,
py
::
arithmetic
())
.
value
(
"enable_auto_mixed_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_AUTO_MIXED_PRECISION
)
.
value
(
"check_bprop"
,
MsCtxParam
::
MS_CTX_CHECK_BPROP_FLAG
)
.
value
(
"enable_dump"
,
MsCtxParam
::
MS_CTX_ENABLE_DUMP
)
.
value
(
"enable_dynamic_mem_pool"
,
MsCtxParam
::
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
)
.
value
(
"enable_gpu_summary"
,
MsCtxParam
::
MS_CTX_ENABLE_GPU_SUMMARY
)
.
value
(
"enable_graph_kernel"
,
MsCtxParam
::
MS_CTX_ENABLE_GRAPH_KERNEL
)
.
value
(
"enable_hccl"
,
MsCtxParam
::
MS_CTX_ENABLE_HCCL
)
.
value
(
"enable_loop_sink"
,
MsCtxParam
::
MS_CTX_ENABLE_LOOP_SINK
)
.
value
(
"enable_mem_reuse"
,
MsCtxParam
::
MS_CTX_ENABLE_MEM_REUSE
)
.
value
(
"enable_pynative_hook"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_HOOK
)
.
value
(
"enable_pynative_infer"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_INFER
)
.
value
(
"enable_reduce_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_REDUCE_PRECISION
)
.
value
(
"enable_sparse"
,
MsCtxParam
::
MS_CTX_ENABLE_SPARSE
)
.
value
(
"enable_task_sink"
,
MsCtxParam
::
MS_CTX_ENABLE_TASK_SINK
)
.
value
(
"ir_fusion_flag"
,
MsCtxParam
::
MS_CTX_IR_FUSION_FLAG
)
.
value
(
"is_multi_graph_sink"
,
MsCtxParam
::
MS_CTX_IS_MULTI_GRAPH_SINK
)
.
value
(
"is_pynative_ge_init"
,
MsCtxParam
::
MS_CTX_IS_PYNATIVE_GE_INIT
)
.
value
(
"precompile_only"
,
MsCtxParam
::
MS_CTX_PRECOMPILE_ONLY
)
.
value
(
"enable_profiling"
,
MsCtxParam
::
MS_CTX_ENABLE_PROFILING
)
.
value
(
"save_graphs"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_FLAG
)
.
value
(
"max_device_memory"
,
MsCtxParam
::
MS_CTX_MAX_DEVICE_MEMORY
)
.
value
(
"mode"
,
MsCtxParam
::
MS_CTX_EXECUTION_MODE
)
.
value
(
"device_target"
,
MsCtxParam
::
MS_CTX_DEVICE_TARGET
)
.
value
(
"graph_memory_max_size"
,
MsCtxParam
::
MS_CTX_GRAPH_MEMORY_MAX_SIZE
)
.
value
(
"print_file_path"
,
MsCtxParam
::
MS_CTX_PRINT_FILE_PATH
)
.
value
(
"profiling_options"
,
MsCtxParam
::
MS_CTX_PROFILING_OPTIONS
)
.
value
(
"save_dump_path"
,
MsCtxParam
::
MS_CTX_SAVE_DUMP_PATH
)
.
value
(
"save_graphs_path"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_PATH
)
.
value
(
"variable_memory_max_size"
,
MsCtxParam
::
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
)
.
value
(
"device_id"
,
MsCtxParam
::
MS_CTX_DEVICE_ID
)
.
value
(
"ge_ref"
,
MsCtxParam
::
MS_CTX_GE_REF
)
.
value
(
"max_call_depth"
,
MsCtxParam
::
MS_CTX_MAX_CALL_DEPTH
)
.
value
(
"tsd_ref"
,
MsCtxParam
::
MS_CTX_TSD_REF
);
(
void
)
py
::
class_
<
mindspore
::
MsContext
,
std
::
shared_ptr
<
mindspore
::
MsContext
>>
(
*
m
,
"MSContext"
)
.
def_static
(
"get_instance"
,
&
mindspore
::
MsContext
::
GetInstance
,
"Get ms context instance."
)
.
def
(
"get_param"
,
&
mindspore
::
MsCtxGetParameter
,
"Get value of specified paramter."
)
.
def
(
"set_param"
,
&
mindspore
::
MsCtxSetParameter
,
"Set value for specified paramter."
)
.
def
(
"get_backend_policy"
,
&
mindspore
::
MsContext
::
backend_policy
,
"Get backend policy."
)
.
def
(
"set_backend_policy"
,
&
mindspore
::
MsContext
::
set_backend_policy
,
"Set backend policy."
);
}));
}
// namespace mindspore
mindspore/ccsrc/utils/context/context_extends.cc
浏览文件 @
0a858b38
...
@@ -225,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
...
@@ -225,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
}
}
// Enable auto mixed precision according to the context options
// Enable auto mixed precision according to the context options
if
(
ms_context_ptr
->
get_param
<
bool
>
(
MS_CTX_
AUTO_MIXED_PRECISION_FLAG
))
{
if
(
ms_context_ptr
->
get_param
<
bool
>
(
MS_CTX_
ENABLE_AUTO_MIXED_PRECISION
))
{
(
*
ge_options
)[
"ge.exec.precision_mode"
]
=
"allow_mix_precision"
;
(
*
ge_options
)[
"ge.exec.precision_mode"
]
=
"allow_mix_precision"
;
}
else
{
}
else
{
(
*
ge_options
)[
"ge.exec.precision_mode"
]
=
"allow_fp32_to_fp16"
;
(
*
ge_options
)[
"ge.exec.precision_mode"
]
=
"allow_fp32_to_fp16"
;
...
@@ -337,7 +337,7 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
...
@@ -337,7 +337,7 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
if
(
ge
::
GEFinalize
()
!=
ge
::
GRAPH_SUCCESS
)
{
if
(
ge
::
GEFinalize
()
!=
ge
::
GRAPH_SUCCESS
)
{
MS_LOG
(
WARNING
)
<<
"Finalize GE failed!"
;
MS_LOG
(
WARNING
)
<<
"Finalize GE failed!"
;
}
}
ms_context_ptr
->
set_p
ynative_ge_init
(
false
);
ms_context_ptr
->
set_p
aram
<
bool
>
(
MS_CTX_IS_PYNATIVE_GE_INIT
,
false
);
}
else
{
}
else
{
MS_LOG
(
INFO
)
<<
"Ge is used, no need to finalize, tsd reference = "
MS_LOG
(
INFO
)
<<
"Ge is used, no need to finalize, tsd reference = "
<<
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_GE_REF
)
<<
"."
;
<<
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_GE_REF
)
<<
"."
;
...
...
mindspore/context.py
浏览文件 @
0a858b38
...
@@ -22,7 +22,7 @@ import threading
...
@@ -22,7 +22,7 @@ import threading
from
collections
import
namedtuple
from
collections
import
namedtuple
from
types
import
FunctionType
from
types
import
FunctionType
from
mindspore
import
log
as
logger
from
mindspore
import
log
as
logger
from
mindspore._c_expression
import
MSContext
,
ms_ctx_param
,
ms_ctx_get_param
,
ms_ctx_set_param
from
mindspore._c_expression
import
MSContext
,
ms_ctx_param
from
mindspore._checkparam
import
args_type_check
from
mindspore._checkparam
import
args_type_check
from
mindspore.parallel._auto_parallel_context
import
_set_auto_parallel_context
,
_get_auto_parallel_context
,
\
from
mindspore.parallel._auto_parallel_context
import
_set_auto_parallel_context
,
_get_auto_parallel_context
,
\
_reset_auto_parallel_context
_reset_auto_parallel_context
...
@@ -158,17 +158,12 @@ class _Context:
...
@@ -158,17 +158,12 @@ class _Context:
return
value
return
value
def
get_param
(
self
,
param
):
def
get_param
(
self
,
param
):
return
ms_ctx_get_param
(
self
.
_context_handle
,
param
)
return
self
.
_context_handle
.
get_param
(
param
)
def
set_param
(
self
,
param
,
value
):
def
set_param
(
self
,
param
,
value
):
ms_ctx_set_param
(
self
.
_context_handle
,
param
,
value
)
self
.
_context_handle
.
set_param
(
param
,
value
)
@
property
def
set_mode
(
self
,
mode
):
def
mode
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
execution_mode
)
@
mode
.
setter
def
mode
(
self
,
mode
):
"""
"""
Switch between Graph mode and PyNative mode.
Switch between Graph mode and PyNative mode.
...
@@ -185,43 +180,17 @@ class _Context:
...
@@ -185,43 +180,17 @@ class _Context:
self
.
_context_switches
.
push
(
False
,
None
)
self
.
_context_switches
.
push
(
False
,
None
)
else
:
else
:
raise
ValueError
(
f
'The execution mode
{
mode
}
is invalid!'
)
raise
ValueError
(
f
'The execution mode
{
mode
}
is invalid!'
)
self
.
set_param
(
ms_ctx_param
.
execution_
mode
,
mode
)
self
.
set_param
(
ms_ctx_param
.
mode
,
mode
)
def
set_backend_policy
(
self
,
policy
):
def
set_backend_policy
(
self
,
policy
):
success
=
self
.
_context_handle
.
set_backend_policy
(
policy
)
success
=
self
.
_context_handle
.
set_backend_policy
(
policy
)
if
not
success
:
if
not
success
:
raise
RuntimeError
(
"Backend policy must be one of ge, vm, ms."
)
raise
RuntimeError
(
"Backend policy must be one of ge, vm, ms."
)
@
property
def
set_save_graphs_path
(
self
,
save_graphs_path
):
def
precompile_only
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
precompile_only
)
@
precompile_only
.
setter
def
precompile_only
(
self
,
precompile_only
):
self
.
set_param
(
ms_ctx_param
.
precompile_only
,
precompile_only
)
@
property
def
save_graphs
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
save_graphs_flag
)
@
save_graphs
.
setter
def
save_graphs
(
self
,
save_graphs_flag
):
self
.
set_param
(
ms_ctx_param
.
save_graphs_flag
,
save_graphs_flag
)
@
property
def
save_graphs_path
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
save_graphs_path
)
@
save_graphs_path
.
setter
def
save_graphs_path
(
self
,
save_graphs_path
):
self
.
set_param
(
ms_ctx_param
.
save_graphs_path
,
_make_directory
(
save_graphs_path
))
self
.
set_param
(
ms_ctx_param
.
save_graphs_path
,
_make_directory
(
save_graphs_path
))
@
property
def
set_device_target
(
self
,
target
):
def
device_target
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
device_target
)
@
device_target
.
setter
def
device_target
(
self
,
target
):
valid_targets
=
[
"CPU"
,
"GPU"
,
"Ascend"
,
"Davinci"
]
valid_targets
=
[
"CPU"
,
"GPU"
,
"Ascend"
,
"Davinci"
]
if
not
target
in
valid_targets
:
if
not
target
in
valid_targets
:
raise
ValueError
(
f
"Target device name
{
target
}
is invalid! It must be one of
{
valid_targets
}
"
)
raise
ValueError
(
f
"Target device name
{
target
}
is invalid! It must be one of
{
valid_targets
}
"
)
...
@@ -231,72 +200,17 @@ class _Context:
...
@@ -231,72 +200,17 @@ class _Context:
if
self
.
enable_debug_runtime
and
target
==
"CPU"
:
if
self
.
enable_debug_runtime
and
target
==
"CPU"
:
self
.
set_backend_policy
(
"vm"
)
self
.
set_backend_policy
(
"vm"
)
@
property
def
set_device_id
(
self
,
device_id
):
def
device_id
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
device_id
)
@
device_id
.
setter
def
device_id
(
self
,
device_id
):
if
device_id
<
0
or
device_id
>
4095
:
if
device_id
<
0
or
device_id
>
4095
:
raise
ValueError
(
f
"Device id must be in [0, 4095], but got
{
device_id
}
"
)
raise
ValueError
(
f
"Device id must be in [0, 4095], but got
{
device_id
}
"
)
self
.
set_param
(
ms_ctx_param
.
device_id
,
device_id
)
self
.
set_param
(
ms_ctx_param
.
device_id
,
device_id
)
@
property
def
set_max_call_depth
(
self
,
max_call_depth
):
def
max_call_depth
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
max_call_depth
)
@
max_call_depth
.
setter
def
max_call_depth
(
self
,
max_call_depth
):
if
max_call_depth
<=
0
:
if
max_call_depth
<=
0
:
raise
ValueError
(
f
"Max call depth must be greater than 0, but got
{
max_call_depth
}
"
)
raise
ValueError
(
f
"Max call depth must be greater than 0, but got
{
max_call_depth
}
"
)
self
.
set_param
(
ms_ctx_param
.
max_call_depth
,
max_call_depth
)
self
.
set_param
(
ms_ctx_param
.
max_call_depth
,
max_call_depth
)
@
property
def
set_profiling_options
(
self
,
option
):
def
enable_auto_mixed_precision
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
auto_mixed_precision_flag
)
@
enable_auto_mixed_precision
.
setter
def
enable_auto_mixed_precision
(
self
,
enable_auto_mixed_precision
):
self
.
set_param
(
ms_ctx_param
.
auto_mixed_precision_flag
,
enable_auto_mixed_precision
)
@
property
def
enable_reduce_precision
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
enable_reduce_precision_flag
)
@
enable_reduce_precision
.
setter
def
enable_reduce_precision
(
self
,
enable_reduce_precision
):
self
.
set_param
(
ms_ctx_param
.
enable_reduce_precision_flag
,
enable_reduce_precision
)
@
property
def
enable_dump
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
enable_dump
)
@
enable_dump
.
setter
def
enable_dump
(
self
,
enable_dump
):
self
.
set_param
(
ms_ctx_param
.
enable_dump
,
enable_dump
)
@
property
def
save_dump_path
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
save_dump_path
)
@
save_dump_path
.
setter
def
save_dump_path
(
self
,
save_dump_path
):
self
.
set_param
(
ms_ctx_param
.
save_dump_path
,
save_dump_path
)
@
property
def
enable_profiling
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
enable_profiling
)
@
enable_profiling
.
setter
def
enable_profiling
(
self
,
flag
):
self
.
set_param
(
ms_ctx_param
.
enable_profiling
,
flag
)
@
property
def
profiling_options
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
profiling_options
)
@
profiling_options
.
setter
def
profiling_options
(
self
,
option
):
options
=
[
"training_trace"
,
"task_trace"
,
options
=
[
"training_trace"
,
"task_trace"
,
"task_trace:training_trace"
,
"training_trace:task_trace"
,
"op_trace"
]
"task_trace:training_trace"
,
"training_trace:task_trace"
,
"op_trace"
]
if
option
not
in
options
:
if
option
not
in
options
:
...
@@ -304,30 +218,7 @@ class _Context:
...
@@ -304,30 +218,7 @@ class _Context:
"'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'."
)
"'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'."
)
self
.
set_param
(
ms_ctx_param
.
profiling_options
,
option
)
self
.
set_param
(
ms_ctx_param
.
profiling_options
,
option
)
@
property
def
set_variable_memory_max_size
(
self
,
variable_memory_max_size
):
def
enable_graph_kernel
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
enable_graph_kernel
)
@
enable_graph_kernel
.
setter
def
enable_graph_kernel
(
self
,
graph_kernel_switch_
):
self
.
set_param
(
ms_ctx_param
.
enable_graph_kernel
,
graph_kernel_switch_
)
@
property
def
reserve_class_name_in_scope
(
self
):
"""Gets whether to save the network class name in the scope."""
return
self
.
_thread_local_info
.
reserve_class_name_in_scope
@
reserve_class_name_in_scope
.
setter
def
reserve_class_name_in_scope
(
self
,
reserve_class_name_in_scope
):
"""Sets whether to save the network class name in the scope."""
self
.
_thread_local_info
.
reserve_class_name_in_scope
=
reserve_class_name_in_scope
@
property
def
variable_memory_max_size
(
self
):
return
None
@
variable_memory_max_size
.
setter
def
variable_memory_max_size
(
self
,
variable_memory_max_size
):
if
not
check_input_format
(
variable_memory_max_size
):
if
not
check_input_format
(
variable_memory_max_size
):
raise
ValueError
(
"Context param variable_memory_max_size should be in correct format! Such as
\"
5GB
\"
"
)
raise
ValueError
(
"Context param variable_memory_max_size should be in correct format! Such as
\"
5GB
\"
"
)
if
int
(
variable_memory_max_size
[:
-
2
])
>=
_DEVICE_APP_MEMORY_SIZE
:
if
int
(
variable_memory_max_size
[:
-
2
])
>=
_DEVICE_APP_MEMORY_SIZE
:
...
@@ -338,33 +229,7 @@ class _Context:
...
@@ -338,33 +229,7 @@ class _Context:
self
.
set_param
(
ms_ctx_param
.
variable_memory_max_size
,
variable_memory_max_size_
)
self
.
set_param
(
ms_ctx_param
.
variable_memory_max_size
,
variable_memory_max_size_
)
self
.
set_param
(
ms_ctx_param
.
graph_memory_max_size
,
graph_memory_max_size_
)
self
.
set_param
(
ms_ctx_param
.
graph_memory_max_size
,
graph_memory_max_size_
)
@
property
def
set_max_device_memory
(
self
,
max_device_memory
):
def
enable_ge
(
self
):
return
self
.
_context_handle
.
get_backend_policy
()
==
'ge'
@
property
def
enable_debug_runtime
(
self
):
return
self
.
_thread_local_info
.
debug_runtime
@
enable_debug_runtime
.
setter
def
enable_debug_runtime
(
self
,
enable
):
thread_info
=
self
.
_thread_local_info
thread_info
.
debug_runtime
=
enable
@
property
def
check_bprop
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
check_bprop_flag
)
@
check_bprop
.
setter
def
check_bprop
(
self
,
check_bprop_flag
):
self
.
set_param
(
ms_ctx_param
.
check_bprop_flag
,
check_bprop_flag
)
@
property
def
max_device_memory
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
max_device_memory
)
@
max_device_memory
.
setter
def
max_device_memory
(
self
,
max_device_memory
):
if
not
check_input_format
(
max_device_memory
):
if
not
check_input_format
(
max_device_memory
):
raise
ValueError
(
"Context param max_device_memory should be in correct format! Such as
\"
3.5GB
\"
"
)
raise
ValueError
(
"Context param max_device_memory should be in correct format! Such as
\"
3.5GB
\"
"
)
max_device_memory_value
=
float
(
max_device_memory
[:
-
2
])
max_device_memory_value
=
float
(
max_device_memory
[:
-
2
])
...
@@ -372,12 +237,7 @@ class _Context:
...
@@ -372,12 +237,7 @@ class _Context:
raise
ValueError
(
"Context param max_device_memory should be in correct format! Such as
\"
3.5GB
\"
"
)
raise
ValueError
(
"Context param max_device_memory should be in correct format! Such as
\"
3.5GB
\"
"
)
self
.
set_param
(
ms_ctx_param
.
max_device_memory
,
max_device_memory_value
)
self
.
set_param
(
ms_ctx_param
.
max_device_memory
,
max_device_memory_value
)
@
property
def
set_print_file_path
(
self
,
file_path
):
def
print_file_path
(
self
):
return
None
@
print_file_path
.
setter
def
print_file_path
(
self
,
file_path
):
"""Add timestamp suffix to file name. Sets print file path."""
"""Add timestamp suffix to file name. Sets print file path."""
print_file_path
=
os
.
path
.
realpath
(
file_path
)
print_file_path
=
os
.
path
.
realpath
(
file_path
)
if
os
.
path
.
isdir
(
print_file_path
):
if
os
.
path
.
isdir
(
print_file_path
):
...
@@ -392,13 +252,42 @@ class _Context:
...
@@ -392,13 +252,42 @@ class _Context:
full_file_name
=
print_file_path
full_file_name
=
print_file_path
self
.
set_param
(
ms_ctx_param
.
print_file_path
,
full_file_name
)
self
.
set_param
(
ms_ctx_param
.
print_file_path
,
full_file_name
)
setters
=
{
'mode'
:
set_mode
,
'backend_policy'
:
set_backend_policy
,
'save_graphs_path'
:
set_save_graphs_path
,
'device_target'
:
set_device_target
,
'device_id'
:
set_device_id
,
'max_call_depth'
:
set_max_call_depth
,
'profiling_options'
:
set_profiling_options
,
'variable_memory_max_size'
:
set_variable_memory_max_size
,
'max_device_memory'
:
set_max_device_memory
,
'print_file_path'
:
set_print_file_path
}
@
property
@
property
def
enable_sparse
(
self
):
def
reserve_class_name_in_scope
(
self
):
return
self
.
get_param
(
ms_ctx_param
.
enable_sparse
)
"""Gets whether to save the network class name in the scope."""
return
self
.
_thread_local_info
.
reserve_class_name_in_scope
@
reserve_class_name_in_scope
.
setter
def
reserve_class_name_in_scope
(
self
,
reserve_class_name_in_scope
):
"""Sets whether to save the network class name in the scope."""
self
.
_thread_local_info
.
reserve_class_name_in_scope
=
reserve_class_name_in_scope
@
property
def
enable_ge
(
self
):
return
self
.
_context_handle
.
get_backend_policy
()
==
'ge'
@
property
def
enable_debug_runtime
(
self
):
return
self
.
_thread_local_info
.
debug_runtime
@
enable_debug_runtime
.
setter
def
enable_debug_runtime
(
self
,
enable
):
thread_info
=
self
.
_thread_local_info
thread_info
.
debug_runtime
=
enable
@
enable_sparse
.
setter
def
enable_sparse
(
self
,
enable_sparse
):
self
.
set_param
(
ms_ctx_param
.
enable_sparse
,
enable_sparse
)
def
check_input_format
(
x
):
def
check_input_format
(
x
):
import
re
import
re
...
@@ -621,10 +510,18 @@ def set_context(**kwargs):
...
@@ -621,10 +510,18 @@ def set_context(**kwargs):
>>> context.set_context(print_file_path="print.pb")
>>> context.set_context(print_file_path="print.pb")
>>> context.set_context(max_call_depth=80)
>>> context.set_context(max_call_depth=80)
"""
"""
ctx
=
_context
()
for
key
,
value
in
kwargs
.
items
():
for
key
,
value
in
kwargs
.
items
():
if
not
hasattr
(
_context
(),
key
):
if
hasattr
(
ctx
,
key
):
raise
ValueError
(
"Set context keyword %s is not recognized!"
%
key
)
setattr
(
ctx
,
key
,
value
)
setattr
(
_context
(),
key
,
value
)
continue
if
key
in
ctx
.
setters
:
ctx
.
setters
[
key
](
ctx
,
value
)
continue
if
key
in
ms_ctx_param
.
__members__
:
ctx
.
set_param
(
ms_ctx_param
.
__members__
[
key
],
value
)
continue
raise
ValueError
(
"Set context keyword %s is not recognized!"
%
key
)
def
get_context
(
attr_key
):
def
get_context
(
attr_key
):
...
@@ -640,10 +537,13 @@ def get_context(attr_key):
...
@@ -640,10 +537,13 @@ def get_context(attr_key):
Raises:
Raises:
ValueError: If input key is not an attribute in context.
ValueError: If input key is not an attribute in context.
"""
"""
if
not
hasattr
(
_context
(),
attr_key
):
ctx
=
_context
()
raise
ValueError
(
if
hasattr
(
ctx
,
attr_key
):
"Get context keyword %s is not recognized!"
%
attr_key
)
return
getattr
(
ctx
,
attr_key
)
return
getattr
(
_context
(),
attr_key
)
if
attr_key
in
ms_ctx_param
.
__members__
:
return
ctx
.
get_param
(
ms_ctx_param
.
__members__
[
attr_key
])
raise
ValueError
(
"Get context keyword %s is not recognized!"
%
attr_key
)
class
ParallelMode
:
class
ParallelMode
:
"""
"""
...
...
mindspore/core/utils/ms_context.cc
浏览文件 @
0a858b38
...
@@ -60,7 +60,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
...
@@ -60,7 +60,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
#endif
#endif
set_param
<
bool
>
(
MS_CTX_ENABLE_GPU_SUMMARY
,
true
);
set_param
<
bool
>
(
MS_CTX_ENABLE_GPU_SUMMARY
,
true
);
set_param
<
bool
>
(
MS_CTX_PRECOMPILE_ONLY
,
false
);
set_param
<
bool
>
(
MS_CTX_PRECOMPILE_ONLY
,
false
);
set_param
<
bool
>
(
MS_CTX_
AUTO_MIXED_PRECISION_FLAG
,
false
);
set_param
<
bool
>
(
MS_CTX_
ENABLE_AUTO_MIXED_PRECISION
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_HOOK
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_HOOK
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
,
true
);
set_param
<
bool
>
(
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
,
true
);
...
...
mindspore/core/utils/ms_context.h
浏览文件 @
0a858b38
...
@@ -53,7 +53,7 @@ const float kDefaultMaxDeviceMemory = 1024;
...
@@ -53,7 +53,7 @@ const float kDefaultMaxDeviceMemory = 1024;
enum
MsCtxParam
:
unsigned
{
enum
MsCtxParam
:
unsigned
{
// paramater of type bool
// paramater of type bool
MS_CTX_TYPE_BOOL_BEGIN
,
MS_CTX_TYPE_BOOL_BEGIN
,
MS_CTX_
AUTO_MIXED_PRECISION_FLAG
=
MS_CTX_TYPE_BOOL_BEGIN
,
MS_CTX_
ENABLE_AUTO_MIXED_PRECISION
=
MS_CTX_TYPE_BOOL_BEGIN
,
MS_CTX_CHECK_BPROP_FLAG
,
MS_CTX_CHECK_BPROP_FLAG
,
MS_CTX_ENABLE_DUMP
,
MS_CTX_ENABLE_DUMP
,
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
,
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
,
...
@@ -132,22 +132,22 @@ class MsContext {
...
@@ -132,22 +132,22 @@ class MsContext {
template
<
typename
T
>
template
<
typename
T
>
void
set_param
(
MsCtxParam
param
,
const
T
&
value
)
{
void
set_param
(
MsCtxParam
param
,
const
T
&
value
)
{
MS_LOG
(
EXCEPTION
)
<<
"Need
impleme
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
MS_LOG
(
EXCEPTION
)
<<
"Need
to implemen
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
}
}
template
<
typename
T
>
template
<
typename
T
>
const
T
&
get_param
(
MsCtxParam
param
)
const
{
const
T
&
get_param
(
MsCtxParam
param
)
const
{
MS_LOG
(
EXCEPTION
)
<<
"Need
impleme
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
MS_LOG
(
EXCEPTION
)
<<
"Need
to implemen
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
increase_param
(
MsCtxParam
param
)
{
void
increase_param
(
MsCtxParam
param
)
{
MS_LOG
(
EXCEPTION
)
<<
"Need
impleme
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
MS_LOG
(
EXCEPTION
)
<<
"Need
to implemen
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
decrease_param
(
MsCtxParam
param
)
{
void
decrease_param
(
MsCtxParam
param
)
{
MS_LOG
(
EXCEPTION
)
<<
"Need
impleme
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
MS_LOG
(
EXCEPTION
)
<<
"Need
to implemen
t "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
}
}
private:
private:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录