Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
8d419314
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8d419314
编写于
8月 31, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 31, 2020
浏览文件
操作
浏览文件
下载
差异文件
!5352 refactor ms_context implementation
Merge pull request !5352 from fary86/refactor_context_interface
上级
4d963d96
fcbb3e0e
变更
77
隐藏空白更改
内联
并排
Showing
77 changed file
with
577 addition
and
549 deletion
+577
-549
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc
.../ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc
+1
-1
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc
...ore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc
+1
-1
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc
...ore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc
+1
-1
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc
...c/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc
+1
-1
mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc
mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc
+2
-1
mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc
mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc
+1
-1
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
...c/backend/optimizer/ascend/ascend_backend_optimization.cc
+24
-22
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc
...c/backend/optimizer/ascend/format_type/insert_trans_op.cc
+2
-1
mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc
...timizer/ascend/format_type/rectify_do_mask_kernel_info.cc
+1
-1
mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc
...c/backend/optimizer/common/common_backend_optimization.cc
+2
-2
mindspore/ccsrc/backend/optimizer/common/helper.cc
mindspore/ccsrc/backend/optimizer/common/helper.cc
+2
-1
mindspore/ccsrc/backend/optimizer/common/pass_manager.cc
mindspore/ccsrc/backend/optimizer/common/pass_manager.cc
+2
-2
mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc
...ackend/optimizer/pass/const_to_attr_strided_slice_grad.cc
+1
-1
mindspore/ccsrc/backend/session/ascend_control_parser.cc
mindspore/ccsrc/backend/session/ascend_control_parser.cc
+2
-2
mindspore/ccsrc/backend/session/ascend_session.cc
mindspore/ccsrc/backend/session/ascend_session.cc
+14
-14
mindspore/ccsrc/backend/session/gpu_session.cc
mindspore/ccsrc/backend/session/gpu_session.cc
+4
-4
mindspore/ccsrc/backend/session/infer_session.cc
mindspore/ccsrc/backend/session/infer_session.cc
+4
-4
mindspore/ccsrc/backend/session/session_basic.cc
mindspore/ccsrc/backend/session/session_basic.cc
+6
-5
mindspore/ccsrc/backend/session/session_basic.h
mindspore/ccsrc/backend/session/session_basic.h
+1
-1
mindspore/ccsrc/debug/data_dump_parser.cc
mindspore/ccsrc/debug/data_dump_parser.cc
+1
-1
mindspore/ccsrc/debug/debugger/debugger.cc
mindspore/ccsrc/debug/debugger/debugger.cc
+1
-1
mindspore/ccsrc/debug/dump_proto.cc
mindspore/ccsrc/debug/dump_proto.cc
+1
-1
mindspore/ccsrc/debug/e2e_dump.cc
mindspore/ccsrc/debug/e2e_dump.cc
+2
-2
mindspore/ccsrc/frontend/optimizer/ad/grad.cc
mindspore/ccsrc/frontend/optimizer/ad/grad.cc
+1
-1
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
+1
-1
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
...re/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
+3
-3
mindspore/ccsrc/frontend/optimizer/optimizer.h
mindspore/ccsrc/frontend/optimizer/optimizer.h
+1
-1
mindspore/ccsrc/frontend/optimizer/py_pass.cc
mindspore/ccsrc/frontend/optimizer/py_pass.cc
+2
-2
mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc
mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc
+1
-1
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+1
-1
mindspore/ccsrc/pipeline/jit/action.cc
mindspore/ccsrc/pipeline/jit/action.cc
+10
-10
mindspore/ccsrc/pipeline/jit/base.h
mindspore/ccsrc/pipeline/jit/base.h
+1
-1
mindspore/ccsrc/pipeline/jit/init.cc
mindspore/ccsrc/pipeline/jit/init.cc
+89
-44
mindspore/ccsrc/pipeline/jit/pass.cc
mindspore/ccsrc/pipeline/jit/pass.cc
+1
-1
mindspore/ccsrc/pipeline/jit/pipeline.cc
mindspore/ccsrc/pipeline/jit/pipeline.cc
+10
-9
mindspore/ccsrc/pipeline/jit/pipeline_ge.cc
mindspore/ccsrc/pipeline/jit/pipeline_ge.cc
+2
-2
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
+4
-3
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h
+1
-1
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+1
-1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+5
-5
mindspore/ccsrc/pybind_api/ir/primitive_py.cc
mindspore/ccsrc/pybind_api/ir/primitive_py.cc
+1
-1
mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
...pore/ccsrc/runtime/device/ascend/ascend_device_address.cc
+7
-5
mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
...pore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
+7
-7
mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc
...pore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc
+1
-1
mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc
...spore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc
+1
-1
mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc
mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc
+1
-1
mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
...spore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
+1
-1
mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc
...csrc/runtime/device/ascend/profiling/profiling_manager.cc
+1
-1
mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h
...ccsrc/runtime/device/ascend/profiling/profiling_manager.h
+1
-1
mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc
.../ccsrc/runtime/device/ascend/profiling/profiling_utils.cc
+3
-3
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
+3
-3
mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc
mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc
+2
-2
mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc
mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc
+2
-2
mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc
mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc
+1
-1
mindspore/ccsrc/runtime/device/kernel_adjust.cc
mindspore/ccsrc/runtime/device/kernel_adjust.cc
+2
-2
mindspore/ccsrc/runtime/device/kernel_runtime.cc
mindspore/ccsrc/runtime/device/kernel_runtime.cc
+7
-5
mindspore/ccsrc/runtime/device/memory_manager.cc
mindspore/ccsrc/runtime/device/memory_manager.cc
+1
-1
mindspore/ccsrc/transform/graph_ir/convert.cc
mindspore/ccsrc/transform/graph_ir/convert.cc
+1
-1
mindspore/ccsrc/utils/context/context_extends.cc
mindspore/ccsrc/utils/context/context_extends.cc
+44
-41
mindspore/ccsrc/utils/convert_utils_py.cc
mindspore/ccsrc/utils/convert_utils_py.cc
+1
-1
mindspore/ccsrc/utils/tensorprint_utils.cc
mindspore/ccsrc/utils/tensorprint_utils.cc
+1
-1
mindspore/ccsrc/vm/backend.cc
mindspore/ccsrc/vm/backend.cc
+2
-2
mindspore/ccsrc/vm/transform.cc
mindspore/ccsrc/vm/transform.cc
+9
-9
mindspore/context.py
mindspore/context.py
+62
-65
mindspore/core/abstract/prim_others.cc
mindspore/core/abstract/prim_others.cc
+1
-1
mindspore/core/ir/anf.cc
mindspore/core/ir/anf.cc
+2
-2
mindspore/core/ir/func_graph_cloner.cc
mindspore/core/ir/func_graph_cloner.cc
+1
-1
mindspore/core/ir/meta_func_graph.cc
mindspore/core/ir/meta_func_graph.cc
+1
-1
mindspore/core/utils/ms_context.cc
mindspore/core/utils/ms_context.cc
+37
-86
mindspore/core/utils/ms_context.h
mindspore/core/utils/ms_context.h
+153
-133
tests/ut/cpp/optimizer/lib_test.cc
tests/ut/cpp/optimizer/lib_test.cc
+1
-1
tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc
...p/pre_activate/ascend/format_type/insert_trans_op_test.cc
+1
-1
tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc
...ctivate/ascend/format_type/remove_internal_output_test.cc
+1
-1
tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc
...pp/pre_activate/ascend/ir_fission/transdata_split_test.cc
+1
-1
tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
...ivate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
+1
-1
tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc
...s/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc
+1
-1
tests/ut/cpp/pynative/pynative_execute_test.cc
tests/ut/cpp/pynative/pynative_execute_test.cc
+3
-3
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_broadcast.cc
浏览文件 @
8d419314
...
...
@@ -25,7 +25,7 @@ bool HcomAllBroadCastKernel::Launch(const std::vector<AddressPtr> &inputs,
const
std
::
vector
<
AddressPtr
>
&
/*outputs*/
,
void
*
stream_ptr
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
context_ptr
->
enable_task_sink
(
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
))
{
return
true
;
}
if
(
inputs
.
empty
()
||
hccl_data_type_list_
.
empty
())
{
...
...
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_gather.cc
浏览文件 @
8d419314
...
...
@@ -24,7 +24,7 @@ bool HcomAllGatherKernel::Launch(const std::vector<AddressPtr> &inputs, const st
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
context_ptr
->
enable_task_sink
(
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
))
{
return
true
;
}
if
(
inputs
.
empty
()
||
hccl_data_type_list_
.
empty
())
{
...
...
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce.cc
浏览文件 @
8d419314
...
...
@@ -24,7 +24,7 @@ bool HcomAllReduceKernel::Launch(const std::vector<AddressPtr> &inputs, const st
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
context_ptr
->
enable_task_sink
(
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
))
{
return
true
;
}
if
(
inputs
.
empty
()
||
outputs
.
empty
()
||
hccl_data_type_list_
.
empty
())
{
...
...
mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_all_reduce_scatter.cc
浏览文件 @
8d419314
...
...
@@ -25,7 +25,7 @@ bool HcomAllReduceScatterKernel::Launch(const std::vector<AddressPtr> &inputs,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
context_ptr
->
enable_task_sink
(
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
))
{
return
true
;
}
if
(
inputs
.
empty
()
||
outputs
.
empty
()
||
hccl_data_type_list_
.
empty
())
{
...
...
mindspore/ccsrc/backend/kernel_compiler/kernel_query.cc
浏览文件 @
8d419314
...
...
@@ -101,7 +101,8 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
context_ptr
->
enable_graph_kernel
()
&&
IsPrimitiveCNode
(
kernel_node
,
prim
::
kPrimBatchMatMul
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_GRAPH_KERNEL
)
&&
IsPrimitiveCNode
(
kernel_node
,
prim
::
kPrimBatchMatMul
))
{
kernel_type
=
KernelType
::
AKG_KERNEL
;
}
...
...
mindspore/ccsrc/backend/kernel_compiler/oplib/oplib.cc
浏览文件 @
8d419314
...
...
@@ -328,7 +328,7 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
}
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
bool
is_gpu
=
(
context
->
device_target
(
)
==
kGPUDevice
);
bool
is_gpu
=
(
context
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
)
==
kGPUDevice
);
if
(
is_gpu
&&
(
imply_type
==
kTBE
||
imply_type
==
kAICPU
))
{
MS_LOG
(
ERROR
)
<<
"FindOp failed: opname: "
<<
op_name
<<
", imply_type: "
<<
ImplTypeToStr
(
imply_type
)
<<
", current op num: "
<<
op_info_
.
size
();
...
...
mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc
浏览文件 @
8d419314
...
...
@@ -249,8 +249,8 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
void
AscendBackendIRFusionOptimization
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
@@ -262,7 +262,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
}
auto
optimizer
=
std
::
make_shared
<
GraphOptimizer
>
();
auto
ir_fusion_pm
=
std
::
make_shared
<
PassManager
>
(
"ir_fusion_pm"
);
if
(
context_ptr
->
execution_mode
(
)
==
kPynativeMode
)
{
if
(
context_ptr
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kPynativeMode
)
{
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
BnSplit
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
BnGradSplit
>
());
}
else
{
...
...
@@ -276,7 +276,8 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
AddAscendIRFusionRulesPass
(
ir_fusion_pm
.
get
());
AddAscendIRFusionPass
(
ir_fusion_pm
.
get
());
if
(
context_ptr
->
enable_task_sink
()
&&
context_ptr
->
loop_sink_flag
()
&&
ConfigManager
::
GetInstance
().
iter_num
()
>
1
)
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
)
&&
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_LOOP_SINK
)
&&
ConfigManager
::
GetInstance
().
iter_num
()
>
1
)
{
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
InsertMemcpyAsyncForGetNext
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
GetitemTuple
>
());
ir_fusion_pm
->
AddPass
(
std
::
make_shared
<
EraseVisitAttr
>
());
...
...
@@ -296,12 +297,12 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
void
RunOpAscendBackendIRFusionOptimization
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
!
context_ptr
->
ir_fusion_flag
(
))
{
if
(
!
context_ptr
->
get_param
<
bool
>
(
MS_CTX_IR_FUSION_FLAG
))
{
MS_LOG
(
INFO
)
<<
"IRFusion is not enable, skip"
;
return
;
}
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
@@ -331,8 +332,8 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
void
AscendBackendOptimization
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
@@ -367,7 +368,8 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
auto
other2_pm
=
std
::
make_shared
<
PassManager
>
(
"other2_pm"
);
other2_pm
->
AddPass
(
std
::
make_shared
<
GetitemTuple
>
());
other2_pm
->
AddPass
(
std
::
make_shared
<
CommonSubexpressionElimination
>
());
if
(
context_ptr
->
enable_task_sink
()
&&
context_ptr
->
loop_sink_flag
()
&&
ConfigManager
::
GetInstance
().
iter_num
()
>
1
)
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
)
&&
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_LOOP_SINK
)
&&
ConfigManager
::
GetInstance
().
iter_num
()
>
1
)
{
other2_pm
->
AddPass
(
std
::
make_shared
<
GetnextMemcpyElimination
>
());
}
other2_pm
->
AddPass
(
std
::
make_shared
<
CheckConsistency
>
());
...
...
@@ -388,11 +390,11 @@ void AscendBackendGraphKernelOpt(const std::shared_ptr<session::KernelGraph> &ke
bool
is_before_kernel_select
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
!
(
context_ptr
->
enable_graph_kernel
(
)))
{
if
(
!
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_GRAPH_KERNEL
)))
{
return
;
}
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
@@ -418,11 +420,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern
bool
is_before_kernel_select
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
!
(
context_ptr
->
enable_graph_kernel
(
)))
{
if
(
!
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_GRAPH_KERNEL
)))
{
return
;
}
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
@@ -447,11 +449,11 @@ void AscendBackendFuseBasicOpt(const std::shared_ptr<session::KernelGraph> &kern
void
AscendBackendAddAtomicClean
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
!
(
context_ptr
->
enable_graph_kernel
(
)))
{
if
(
!
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_GRAPH_KERNEL
)))
{
return
;
}
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
@@ -473,12 +475,12 @@ void AscendBackendAddAtomicClean(const std::shared_ptr<session::KernelGraph> &ke
void
AscendBackendUBFusionOptimization
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
!
context_ptr
->
ir_fusion_flag
(
))
{
if
(
!
context_ptr
->
get_param
<
bool
>
(
MS_CTX_IR_FUSION_FLAG
))
{
MS_LOG
(
INFO
)
<<
"UBFusion is not enable, skip"
;
return
;
}
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_trans_op.cc
浏览文件 @
8d419314
...
...
@@ -53,7 +53,8 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
}
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
execution_mode
()
==
kPynativeMode
&&
!
ms_context
->
enable_pynative_hook
())
{
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kPynativeMode
&&
!
ms_context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_HOOK
))
{
if
(
IsGraphOutput
(
node
,
AnfAlgo
::
GetAllOutput
(
func_graph
->
output
(),
{
prim
::
kPrimTupleGetItem
})))
{
return
new_node
;
}
...
...
mindspore/ccsrc/backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.cc
浏览文件 @
8d419314
...
...
@@ -44,7 +44,7 @@ const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, con
auto
cnode
=
node
->
cast
<
CNodePtr
>
();
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
execution_mode
(
)
==
kPynativeMode
)
{
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kPynativeMode
)
{
return
RectifyKernelInfoInPynativeProcess
(
node
);
}
if
(
AnfAlgo
::
GetCNodeName
(
cnode
)
!=
prim
::
kPrimDropoutGenMask
->
name
())
{
...
...
mindspore/ccsrc/backend/optimizer/common/common_backend_optimization.cc
浏览文件 @
8d419314
...
...
@@ -33,8 +33,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
MS_LOG
(
INFO
)
<<
"start common opt graph:"
<<
kernel_graph
->
graph_id
();
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
mindspore/ccsrc/backend/optimizer/common/helper.cc
浏览文件 @
8d419314
...
...
@@ -392,7 +392,8 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
bool
IsNopNode
(
const
AnfNodePtr
&
node
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
context_ptr
->
device_target
()
!=
kAscendDevice
&&
context_ptr
->
device_target
()
!=
kGPUDevice
)
{
if
(
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
)
!=
kAscendDevice
&&
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
)
!=
kGPUDevice
)
{
return
false
;
}
static
std
::
unordered_set
<
std
::
string
>
nop_nodes
=
{
prim
::
kPrimReshape
->
name
(),
kExpandDimsOpName
,
...
...
mindspore/ccsrc/backend/optimizer/common/pass_manager.cc
浏览文件 @
8d419314
...
...
@@ -40,8 +40,8 @@ bool PassManager::Run(const FuncGraphPtr &func_graph, const std::vector<PassPtr>
}
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
mindspore/ccsrc/backend/optimizer/pass/const_to_attr_strided_slice_grad.cc
浏览文件 @
8d419314
...
...
@@ -114,7 +114,7 @@ const AnfNodePtr ConstToAttrStridedSliceGradPass::Process(const FuncGraphPtr &gr
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
device_target
(
)
==
kAscendDevice
)
{
if
(
ms_context
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
)
==
kAscendDevice
)
{
if
(
!
CheckAttrs
(
strided_slice_grad
))
{
MS_LOG
(
INFO
)
<<
"Check strided_slice_grad's attrs failed, graph not changed"
;
return
nullptr
;
...
...
mindspore/ccsrc/backend/session/ascend_control_parser.cc
浏览文件 @
8d419314
...
...
@@ -359,11 +359,11 @@ void AscendControlParser::ExecutorValidate(NotNull<KernelGraphPtr> root_graph) {
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
if
(
context_ptr
->
save_graphs_flag
(
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
))
{
std
::
string
file_path
=
save_graphs_path
+
"/after_erase_label_and_parameter.ir"
;
DumpIR
(
file_path
,
root_graph
.
get
());
}
...
...
mindspore/ccsrc/backend/session/ascend_session.cc
浏览文件 @
8d419314
...
...
@@ -253,7 +253,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
debugger_
->
PreExecute
(
graph
);
}
#endif
if
(
ms_context
->
precompile_only
(
))
{
if
(
ms_context
->
get_param
<
bool
>
(
MS_CTX_PRECOMPILE_ONLY
))
{
MS_LOG
(
INFO
)
<<
"Precompile only, stop in build kernel step"
;
}
else
{
// alloc memory, including static memory and dynamic memory
...
...
@@ -278,8 +278,8 @@ void AscendSession::CompileChildGraph(const KernelGraphPtr &child_graph) {
child_graph
->
SetExecOrderByDefault
();
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
@@ -436,7 +436,7 @@ void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
}
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
execution_mode
(
)
==
kGraphMode
)
{
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kGraphMode
)
{
if
(
raise_precision_count
>
0
)
{
MS_LOG
(
WARNING
)
<<
"There has "
<<
raise_precision_count
<<
" node/nodes used raise precision to selected the kernel!"
;
...
...
@@ -481,8 +481,8 @@ void AscendSession::AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_grap
device
::
KernelAdjust
::
GetInstance
().
InsertSwitchLoop
(
kernel_graph
);
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
@@ -601,11 +601,11 @@ void AscendSession::DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs)
#ifdef ENABLE_DUMP_IR
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
if
(
!
save_graphs
)
{
return
;
}
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
@@ -733,7 +733,7 @@ void AscendSession::MergeGraphExecOrder() {
if
(
graph_order
.
size
()
>
1
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
!
context_ptr
->
enable_task_sink
(
))
{
if
(
!
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
))
{
MS_LOG
(
EXCEPTION
)
<<
"Control sink network should run with task-sink mode!"
;
}
}
...
...
@@ -920,8 +920,8 @@ void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<st
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs
)
{
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
...
...
@@ -947,7 +947,7 @@ void AscendSession::SelectKernel(NotNull<KernelGraphPtr> root_graph) {
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
execution_mode
(
)
==
kGraphMode
)
{
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kGraphMode
)
{
if
(
raise_precision_count
>
0
)
{
MS_LOG
(
WARNING
)
<<
"There are "
<<
raise_precision_count
<<
" node/nodes used raise precision to selected the kernel!"
;
...
...
@@ -992,8 +992,8 @@ void AscendSession::RecurseSelectKernelInfo(NotNull<KernelGraphPtr> graph,
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs
)
{
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
...
...
mindspore/ccsrc/backend/session/gpu_session.cc
浏览文件 @
8d419314
...
...
@@ -76,7 +76,7 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
pm
->
AddPass
(
std
::
make_shared
<
opt
::
ReplaceBNGradCastFusion
>
());
pm
->
AddPass
(
std
::
make_shared
<
opt
::
ReplaceMomentumCastFusion
>
());
pm
->
AddPass
(
std
::
make_shared
<
opt
::
ReplaceAddNFusion
>
());
if
(
!
CheckInModeBlackList
(
kernel_graph
)
&&
context_ptr
->
execution_mode
(
)
!=
kPynativeMode
)
{
if
(
!
CheckInModeBlackList
(
kernel_graph
)
&&
context_ptr
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kPynativeMode
)
{
pm
->
AddPass
(
std
::
make_shared
<
opt
::
BatchNormReluFusion
>
());
pm
->
AddPass
(
std
::
make_shared
<
opt
::
BatchNormReluGradFusion
>
());
pm
->
AddPass
(
std
::
make_shared
<
opt
::
BatchNormAddReluFusion
>
());
...
...
@@ -154,7 +154,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
auto
device_address
=
AnfAlgo
::
GetMutableOutputAddr
(
pk_node
,
0
);
auto
tensor_address
=
std
::
dynamic_pointer_cast
<
device
::
DeviceAddress
>
(
tensor
->
device_address
());
bool
need_sync
=
false
;
if
(
ms_context
->
enable_pynative_infer
(
))
{
if
(
ms_context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
))
{
if
(
tensor_address
==
nullptr
||
tensor_address
!=
device_address
)
{
need_sync
=
true
;
}
...
...
@@ -223,7 +223,7 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
// Prepare ms context info for dump .pb graph
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
// Optimize
Optimize
(
graph
);
// Select kernel build info
...
...
@@ -290,7 +290,7 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
// Summary
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
context_ptr
->
enable_gpu_summary
(
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_GPU_SUMMARY
))
{
Summary
(
kernel_graph
.
get
());
}
#ifdef ENABLE_DEBUGGER
...
...
mindspore/ccsrc/backend/session/infer_session.cc
浏览文件 @
8d419314
...
...
@@ -268,7 +268,7 @@ void MSInferSession::RegAllOp() {
return
;
}
Initialized
=
true
;
MsContext
::
GetInstance
()
->
set_
execution_mode
(
kGraphMode
);
MsContext
::
GetInstance
()
->
set_
param
<
int
>
(
MS_CTX_EXECUTION_MODE
,
kGraphMode
);
Py_Initialize
();
auto
c_expression
=
PyImport_ImportModule
(
"mindspore._c_expression"
);
if
(
c_expression
==
nullptr
)
{
...
...
@@ -357,13 +357,13 @@ Status MSInferSession::InitEnv(const std::string &device, uint32_t device_id) {
MS_LOG
(
ERROR
)
<<
"Get Context failed!"
;
return
FAILED
;
}
ms_context
->
set_
execution_mode
(
kGraphMode
);
ms_context
->
set_
device_id
(
device_id
);
ms_context
->
set_
param
<
int
>
(
MS_CTX_EXECUTION_MODE
,
kGraphMode
);
ms_context
->
set_
param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
,
device_id
);
auto
ajust_device
=
AjustTargetName
(
device
);
if
(
ajust_device
==
""
)
{
return
FAILED
;
}
ms_context
->
set_
device_target
(
device
);
ms_context
->
set_
param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
,
device
);
if
(
!
context
::
OpenTsd
(
ms_context
))
{
MS_LOG
(
ERROR
)
<<
"Session init OpenTsd failed!"
;
return
FAILED
;
...
...
mindspore/ccsrc/backend/session/session_basic.cc
浏览文件 @
8d419314
...
...
@@ -93,10 +93,11 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
// if in paynative mode,data only copyed to host when user want to print data
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
execution_mode
()
!=
kPynativeMode
&&
ms_context
->
device_target
()
!=
kGPUDevice
)
{
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kPynativeMode
&&
ms_context
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
)
!=
kGPUDevice
)
{
tensor
->
set_need_sync
(
true
);
}
if
(
ms_context
->
execution_mode
(
)
!=
kPynativeMode
)
{
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kPynativeMode
)
{
tensor
->
SetNeedWait
(
true
);
}
tensor
->
set_dirty
(
false
);
...
...
@@ -938,7 +939,7 @@ bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
auto
device_address
=
AnfAlgo
::
GetMutableOutputAddr
(
parameter
,
0
);
if
(
ms_context
->
enable_pynative_infer
(
))
{
if
(
ms_context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
))
{
return
tensor
->
device_address
().
get
()
==
nullptr
||
tensor
->
device_address
()
!=
device_address
;
}
if
(
tensor
->
is_dirty
())
{
...
...
@@ -979,7 +980,7 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
MS_EXCEPTION_IF_NULL
(
input_node
);
if
(
input_node
->
isa
<
Parameter
>
()
&&
AnfAlgo
::
OutputAddrExist
(
input_node
,
0
)
&&
TensorNeedSync
(
input_node
,
tensor
))
{
auto
device_address
=
AnfAlgo
::
GetMutableOutputAddr
(
input_node
,
0
);
if
(
ms_context
->
execution_mode
(
)
==
kPynativeMode
||
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kPynativeMode
||
AnfAlgo
::
IsParameterWeight
(
input_node
->
cast
<
ParameterPtr
>
()))
{
tensor
->
set_device_address
(
device_address
);
}
...
...
@@ -1177,7 +1178,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std:
if
(
backend_anf
!=
nullptr
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
context_ptr
->
execution_mode
(
)
==
kPynativeMode
)
{
if
(
context_ptr
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kPynativeMode
)
{
return
backend_anf
;
}
...
...
mindspore/ccsrc/backend/session/session_basic.h
浏览文件 @
8d419314
...
...
@@ -118,7 +118,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
debugger_
=
Debugger
::
GetInstance
();
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
debugger_
->
Init
(
device_id_
,
ms_context
->
device_target
(
));
debugger_
->
Init
(
device_id_
,
ms_context
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
));
}
#endif
...
...
mindspore/ccsrc/debug/data_dump_parser.cc
浏览文件 @
8d419314
...
...
@@ -53,7 +53,7 @@ bool DataDumpParser::DumpEnabled() const {
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
if
(
context
->
execution_mode
(
)
==
kPynativeMode
)
{
if
(
context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kPynativeMode
)
{
MS_LOG
(
EXCEPTION
)
<<
"[DataDump] PyNative mode not support data dump"
;
}
return
true
;
...
...
mindspore/ccsrc/debug/debugger/debugger.cc
浏览文件 @
8d419314
...
...
@@ -142,7 +142,7 @@ void Debugger::EnableDebugger() {
// switch memory reuse on or off
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
context_ptr
->
set_
enable_mem_reuse
(
partial_memory_
);
context_ptr
->
set_
param
<
bool
>
(
MS_CTX_ENABLE_MEM_REUSE
,
partial_memory_
);
// print some message about memory reuse to user
if
(
partial_memory_
)
{
MS_LOG
(
WARNING
)
<<
"Partial Memory Reuse is enabled. Note: 1. Please only set watchpoints before running the first "
...
...
mindspore/ccsrc/debug/dump_proto.cc
浏览文件 @
8d419314
...
...
@@ -530,7 +530,7 @@ void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) {
MS_LOG
(
ERROR
)
<<
"ms_context is nullptr"
;
return
;
}
auto
save_graphs_path
=
ms_context
->
save_graphs_path
(
);
auto
save_graphs_path
=
ms_context
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
mindspore/ccsrc/debug/e2e_dump.cc
浏览文件 @
8d419314
...
...
@@ -112,7 +112,7 @@ bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) {
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
// dump_enable_ is true, close mem reuse
context_ptr
->
set_
enable_mem_reuse
(
!
dump_enable_
);
context_ptr
->
set_
param
<
bool
>
(
MS_CTX_ENABLE_MEM_REUSE
,
!
dump_enable_
);
trans_flag_
=
trans_flag
;
dump_mode_
=
mode
;
dump_path_
=
path
;
...
...
@@ -135,7 +135,7 @@ bool Dump::SetDumpConfFromJsonFile() {
}
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
auto
id
=
context_ptr
->
device_id
(
);
auto
id
=
context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
);
char
real_path
[
PATH_MAX
]
=
{
0
};
if
(
nullptr
==
realpath
(
config_path_str
,
real_path
))
{
MS_LOG
(
ERROR
)
<<
"Env e2e dump path error, "
<<
config_path_str
;
...
...
mindspore/ccsrc/frontend/optimizer/ad/grad.cc
浏览文件 @
8d419314
...
...
@@ -34,7 +34,7 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt
manager_ptr
->
AddFuncGraph
(
func_graph
);
auto
multi_graph_sink
=
[
&
func_graph
](
const
FuncGraphPtr
&
f
)
{
if
(
MsContext
::
GetInstance
()
->
is_multi_graph_sink
(
))
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_IS_MULTI_GRAPH_SINK
))
{
if
(
func_graph
->
has_flag
(
FUNC_GRAPH_FLAG_IGNORE_VALUES
))
{
f
->
set_flag
(
FUNC_GRAPH_FLAG_IGNORE_VALUES
,
true
);
}
...
...
mindspore/ccsrc/frontend/optimizer/ad/kprim.cc
浏览文件 @
8d419314
...
...
@@ -182,7 +182,7 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp
void
KPrim
::
CheckBprop
(
const
FuncGraphPtr
&
bprop_fg
,
const
string
&
prim_to_check
)
{
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
bool
check_bprop_flag
=
context
->
check_bprop_flag
(
);
bool
check_bprop_flag
=
context
->
get_param
<
bool
>
(
MS_CTX_CHECK_BPROP_FLAG
);
// Skip checking if check_bprop not set
if
(
!
check_bprop_flag
)
{
return
;
...
...
mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc
浏览文件 @
8d419314
...
...
@@ -29,7 +29,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
PConstant
const_2
(
node
);
PConstant
any_const
(
node
);
if
(
MsContext
::
GetInstance
()
->
execution_mode
(
)
!=
kPynativeMode
)
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kPynativeMode
)
{
MATCH_REPLACE
(
node
,
x
+
zero_
,
x
);
// Add by zero
MATCH_REPLACE
(
node
,
x
+
zero_scalar_
,
x
);
// Add by zero
MATCH_REPLACE
(
node
,
PBinOperation
(
prim
::
kPrimScalarAdd
,
x
,
zero_scalar_
,
true
),
x
);
// Scalar Add by zero
...
...
@@ -41,7 +41,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
}
// Prim Eliminate (identity)
MATCH_REPLACE
(
node
,
PPrimitive
(
prim
::
kPrimIdentity
,
x
),
x
);
if
(
MsContext
::
GetInstance
()
->
execution_mode
(
)
==
kPynativeMode
)
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kPynativeMode
)
{
return
nullptr
;
}
...
...
@@ -75,7 +75,7 @@ AnfNodePtr ArithmeticSimplify::operator()(const OptimizerPtr &, const AnfNodePtr
}
AnfNodePtr
ArithmeticSimplify2
::
operator
()(
const
OptimizerPtr
&
,
const
AnfNodePtr
&
node
)
{
if
(
MsContext
::
GetInstance
()
->
execution_mode
(
)
==
kPynativeMode
)
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kPynativeMode
)
{
return
nullptr
;
}
PatternNode
x
,
y
;
...
...
mindspore/ccsrc/frontend/optimizer/optimizer.h
浏览文件 @
8d419314
...
...
@@ -181,7 +181,7 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
}
};
use_profile
?
(
WITH
(
MsProfile
::
GetProfile
()
->
Step
(
pass_names_
[
i
]))
opt_func
)
:
opt_func
();
if
(
is_on_debug_
&&
MsContext
::
GetInstance
()
->
save_graphs_flag
(
))
{
if
(
is_on_debug_
&&
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
))
{
MS_LOG
(
DEBUG
)
<<
"The opt "
<<
name_
<<
" round "
<<
counter
<<
" OptPass "
<<
pass_names_
[
i
]
<<
" end."
;
auto
fg_name
=
"opt_substep_"
+
name_
+
"_r"
+
std
::
to_string
(
counter
)
+
"_"
+
std
::
to_string
(
i
)
+
"_"
+
pass_names_
[
i
];
...
...
mindspore/ccsrc/frontend/optimizer/py_pass.cc
浏览文件 @
8d419314
...
...
@@ -217,8 +217,8 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
void
DrawNode
(
string
name
,
AnfNodePtr
node
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
bool
save_graphs
=
context_ptr
->
save_graphs_flag
(
);
auto
save_graphs_path
=
context_ptr
->
save_graphs_path
(
);
bool
save_graphs
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
);
auto
save_graphs_path
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
mindspore/ccsrc/frontend/parallel/graph_util/graph_info.cc
浏览文件 @
8d419314
...
...
@@ -44,7 +44,7 @@ std::vector<PrimitivePtr> FindPrimtive(const FuncGraphPtr &graph, const std::str
}
void
DumpGraph
(
const
FuncGraphPtr
&
root
,
const
std
::
string
&
name
)
{
if
(
MsContext
::
GetInstance
()
->
save_graphs_flag
(
))
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
))
{
draw
::
Draw
(
name
+
".dot"
,
root
);
DumpIR
(
name
+
".ir"
,
root
);
ExportIR
(
name
+
".dat"
,
"0"
,
root
);
...
...
mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
浏览文件 @
8d419314
...
...
@@ -69,7 +69,7 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
struct
timeval
start_time
,
end_time
;
(
void
)
gettimeofday
(
&
start_time
,
nullptr
);
if
(
MsContext
::
GetInstance
()
->
save_graphs_flag
(
))
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
))
{
draw
::
Draw
(
STEP_AUTO_PARALLEL_BEGIN
,
root
);
}
MS_LOG
(
INFO
)
<<
"Now entering step auto parallel"
;
...
...
mindspore/ccsrc/pipeline/jit/action.cc
浏览文件 @
8d419314
...
...
@@ -271,7 +271,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
if
(
!
result
)
{
MS_LOG
(
EXCEPTION
)
<<
"Pass running to end, failed in pass:"
<<
pass
.
first
;
}
if
(
MsContext
::
GetInstance
()
->
save_graphs_flag
(
)
&&
res
->
func_graph
()
!=
nullptr
)
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
)
&&
res
->
func_graph
()
!=
nullptr
)
{
auto
fg_name
=
"opt_pass_"
+
std
::
to_string
(
counter
)
+
"_"
+
pass
.
first
;
auto
func_graph
=
res
->
func_graph
();
MS_EXCEPTION_IF_NULL
(
func_graph
);
...
...
@@ -295,20 +295,20 @@ bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res,
static
bool
IsCtrlSink
()
{
auto
ms_ctx
=
MsContext
::
GetInstance
();
if
(
ms_ctx
->
execution_mode
(
)
!=
kGraphMode
)
{
if
(
ms_ctx
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kGraphMode
)
{
return
false
;
}
std
::
string
device_target
=
ms_ctx
->
device_target
(
);
std
::
string
device_target
=
ms_ctx
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
);
if
(
device_target
!=
kAscendDevice
)
{
return
false
;
}
if
(
!
ms_ctx
->
enable_task_sink
(
))
{
if
(
!
ms_ctx
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
))
{
return
false
;
}
if
(
!
ms_ctx
->
is_multi_graph_sink
(
))
{
if
(
!
ms_ctx
->
get_param
<
bool
>
(
MS_CTX_IS_MULTI_GRAPH_SINK
))
{
return
false
;
}
return
true
;
...
...
@@ -325,13 +325,13 @@ bool TaskEmitAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
CompileGraphs
::
ContainMixedTarget
(
func_graph
))
{
bc_ptr
->
set_is_multi_graph_sink
(
false
);
context_ptr
->
set_
is_multi_graph_sink
(
false
);
context_ptr
->
set_
loop_sink_flag
(
false
);
}
else
if
(
context_ptr
->
execution_mode
(
)
!=
kPynativeMode
)
{
std
::
string
device_target
=
context_ptr
->
device_target
(
);
context_ptr
->
set_
param
<
bool
>
(
MS_CTX_IS_MULTI_GRAPH_SINK
,
false
);
context_ptr
->
set_
param
<
bool
>
(
MS_CTX_ENABLE_LOOP_SINK
,
false
);
}
else
if
(
context_ptr
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kPynativeMode
)
{
std
::
string
device_target
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
);
if
(
device_target
==
kAscendDevice
&&
backend
!=
kMsVm
)
{
bc_ptr
->
set_is_multi_graph_sink
(
true
);
context_ptr
->
set_
is_multi_graph_sink
(
true
);
context_ptr
->
set_
param
<
bool
>
(
MS_CTX_IS_MULTI_GRAPH_SINK
,
true
);
}
}
...
...
mindspore/ccsrc/pipeline/jit/base.h
浏览文件 @
8d419314
...
...
@@ -49,7 +49,7 @@ inline std::string GetFilePathName(const std::string &file_name) {
if
(
ms_context
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"ms_context is nullptr"
;
}
auto
save_graphs_path
=
ms_context
->
save_graphs_path
(
);
auto
save_graphs_path
=
ms_context
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
mindspore/ccsrc/pipeline/jit/init.cc
浏览文件 @
8d419314
...
...
@@ -48,6 +48,56 @@ using OpLib = mindspore::kernel::OpLib;
using
OpInfoLoaderPy
=
mindspore
::
kernel
::
OpInfoLoaderPy
;
using
ParallelContext
=
mindspore
::
parallel
::
ParallelContext
;
using
CostModelContext
=
mindspore
::
parallel
::
CostModelContext
;
using
mindspore
::
MsCtxParam
;
namespace
mindspore
{
void
MsCtxSetParameter
(
std
::
shared_ptr
<
MsContext
>
ctx
,
MsCtxParam
param
,
const
py
::
object
&
value
)
{
MS_LOG
(
DEBUG
)
<<
"set param("
<<
param
<<
") with value '"
<<
py
::
str
(
value
)
<<
"' of type '"
<<
py
::
str
(
value
.
get_type
())
<<
"'."
;
if
(
param
>=
MS_CTX_TYPE_BOOL_BEGIN
&&
param
<
MS_CTX_TYPE_BOOL_END
&&
py
::
isinstance
<
py
::
bool_
>
(
value
))
{
ctx
->
set_param
<
bool
>
(
param
,
value
.
cast
<
bool
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_INT_BEGIN
&&
param
<
MS_CTX_TYPE_INT_END
&&
py
::
isinstance
<
py
::
int_
>
(
value
))
{
ctx
->
set_param
<
int
>
(
param
,
value
.
cast
<
int
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_UINT32_BEGIN
&&
param
<
MS_CTX_TYPE_UINT32_END
&&
py
::
isinstance
<
py
::
int_
>
(
value
))
{
ctx
->
set_param
<
uint32_t
>
(
param
,
value
.
cast
<
uint32_t
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_FLOAT_BEGIN
&&
param
<
MS_CTX_TYPE_FLOAT_END
&&
py
::
isinstance
<
py
::
float_
>
(
value
))
{
ctx
->
set_param
<
float
>
(
param
,
value
.
cast
<
float
>
());
return
;
}
if
(
param
>=
MS_CTX_TYPE_STRING_BEGIN
&&
param
<
MS_CTX_TYPE_STRING_END
&&
py
::
isinstance
<
py
::
str
>
(
value
))
{
ctx
->
set_param
<
std
::
string
>
(
param
,
value
.
cast
<
std
::
string
>
());
return
;
}
MS_LOG
(
EXCEPTION
)
<<
"Got illegal param "
<<
param
<<
" and value with type "
<<
py
::
str
(
value
.
get_type
());
}
py
::
object
MsCtxGetParameter
(
const
std
::
shared_ptr
<
MsContext
>
&
ctx
,
MsCtxParam
param
)
{
if
(
param
>=
MS_CTX_TYPE_BOOL_BEGIN
&&
param
<
MS_CTX_TYPE_BOOL_END
)
{
return
py
::
bool_
(
ctx
->
get_param
<
bool
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_INT_BEGIN
&&
param
<
MS_CTX_TYPE_INT_END
)
{
return
py
::
int_
(
ctx
->
get_param
<
int
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_UINT32_BEGIN
&&
param
<
MS_CTX_TYPE_UINT32_END
)
{
return
py
::
int_
(
ctx
->
get_param
<
uint32_t
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_FLOAT_BEGIN
&&
param
<
MS_CTX_TYPE_FLOAT_END
)
{
return
py
::
float_
(
ctx
->
get_param
<
float
>
(
param
));
}
if
(
param
>=
MS_CTX_TYPE_STRING_BEGIN
&&
param
<
MS_CTX_TYPE_STRING_END
)
{
return
py
::
str
(
ctx
->
get_param
<
std
::
string
>
(
param
));
}
MS_LOG
(
EXCEPTION
)
<<
"Got illegal param "
<<
param
<<
"."
;
}
}
// namespace mindspore
// Interface with python
PYBIND11_MODULE
(
_c_expression
,
m
)
{
...
...
@@ -101,53 +151,48 @@ PYBIND11_MODULE(_c_expression, m) {
(
void
)
m
.
def
(
"export_graph"
,
&
mindspore
::
pipeline
::
ExportGraph
,
"Export Graph."
);
(
void
)
m
.
def
(
"ms_ctx_get_param"
,
&
mindspore
::
MsCtxGetParameter
,
"Get value of specified paramter."
);
(
void
)
m
.
def
(
"ms_ctx_set_param"
,
&
mindspore
::
MsCtxSetParameter
,
"Set value for specified paramter."
);
(
void
)
py
::
enum_
<
MsCtxParam
>
(
*
m
,
"ms_ctx_param"
,
py
::
arithmetic
())
.
value
(
"auto_mixed_precision_flag"
,
MsCtxParam
::
MS_CTX_AUTO_MIXED_PRECISION_FLAG
)
.
value
(
"check_bprop_flag"
,
MsCtxParam
::
MS_CTX_CHECK_BPROP_FLAG
)
.
value
(
"enable_dump"
,
MsCtxParam
::
MS_CTX_ENABLE_DUMP
)
.
value
(
"enable_dynamic_mem_pool"
,
MsCtxParam
::
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
)
.
value
(
"enable_gpu_summary"
,
MsCtxParam
::
MS_CTX_ENABLE_GPU_SUMMARY
)
.
value
(
"enable_graph_kernel"
,
MsCtxParam
::
MS_CTX_ENABLE_GRAPH_KERNEL
)
.
value
(
"enable_hccl"
,
MsCtxParam
::
MS_CTX_ENABLE_HCCL
)
.
value
(
"enable_loop_sink"
,
MsCtxParam
::
MS_CTX_ENABLE_LOOP_SINK
)
.
value
(
"enable_mem_reuse"
,
MsCtxParam
::
MS_CTX_ENABLE_MEM_REUSE
)
.
value
(
"enable_pynative_hook"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_HOOK
)
.
value
(
"enable_pynative_infer"
,
MsCtxParam
::
MS_CTX_ENABLE_PYNATIVE_INFER
)
.
value
(
"enable_reduce_precision"
,
MsCtxParam
::
MS_CTX_ENABLE_REDUCE_PRECISION
)
.
value
(
"enable_sparse"
,
MsCtxParam
::
MS_CTX_ENABLE_SPARSE
)
.
value
(
"enable_task_sink"
,
MsCtxParam
::
MS_CTX_ENABLE_TASK_SINK
)
.
value
(
"ir_fusion_flag"
,
MsCtxParam
::
MS_CTX_IR_FUSION_FLAG
)
.
value
(
"is_multi_graph_sink"
,
MsCtxParam
::
MS_CTX_IS_MULTI_GRAPH_SINK
)
.
value
(
"is_pynative_ge_init"
,
MsCtxParam
::
MS_CTX_IS_PYNATIVE_GE_INIT
)
.
value
(
"precompile_only"
,
MsCtxParam
::
MS_CTX_PRECOMPILE_ONLY
)
.
value
(
"enable_profiling"
,
MsCtxParam
::
MS_CTX_ENABLE_PROFILING
)
.
value
(
"save_graphs_flag"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_FLAG
)
.
value
(
"max_device_memory"
,
MsCtxParam
::
MS_CTX_MAX_DEVICE_MEMORY
)
.
value
(
"execution_mode"
,
MsCtxParam
::
MS_CTX_EXECUTION_MODE
)
.
value
(
"device_target"
,
MsCtxParam
::
MS_CTX_DEVICE_TARGET
)
.
value
(
"graph_memory_max_size"
,
MsCtxParam
::
MS_CTX_GRAPH_MEMORY_MAX_SIZE
)
.
value
(
"print_file_path"
,
MsCtxParam
::
MS_CTX_PRINT_FILE_PATH
)
.
value
(
"profiling_options"
,
MsCtxParam
::
MS_CTX_PROFILING_OPTIONS
)
.
value
(
"save_dump_path"
,
MsCtxParam
::
MS_CTX_SAVE_DUMP_PATH
)
.
value
(
"save_graphs_path"
,
MsCtxParam
::
MS_CTX_SAVE_GRAPHS_PATH
)
.
value
(
"variable_memory_max_size"
,
MsCtxParam
::
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
)
.
value
(
"device_id"
,
MsCtxParam
::
MS_CTX_DEVICE_ID
)
.
value
(
"ge_ref"
,
MsCtxParam
::
MS_CTX_GE_REF
)
.
value
(
"max_call_depth"
,
MsCtxParam
::
MS_CTX_MAX_CALL_DEPTH
)
.
value
(
"tsd_ref"
,
MsCtxParam
::
MS_CTX_TSD_REF
);
(
void
)
py
::
class_
<
mindspore
::
MsContext
,
std
::
shared_ptr
<
mindspore
::
MsContext
>>
(
m
,
"MSContext"
)
.
def_static
(
"get_instance"
,
&
mindspore
::
MsContext
::
GetInstance
,
"Get ms context instance."
)
.
def
(
"get_backend_policy"
,
&
mindspore
::
MsContext
::
backend_policy
,
"Get backend policy."
)
.
def
(
"set_backend_policy"
,
&
mindspore
::
MsContext
::
set_backend_policy
,
"Set backend policy."
)
.
def
(
"get_execution_mode"
,
&
mindspore
::
MsContext
::
execution_mode
,
"Get execution mode."
)
.
def
(
"set_execution_mode"
,
&
mindspore
::
MsContext
::
set_execution_mode
,
"Set execution mode."
)
.
def
(
"set_precompile_only"
,
&
mindspore
::
MsContext
::
set_precompile_only
,
"Set enable precompile only."
)
.
def
(
"get_precompile_only"
,
&
mindspore
::
MsContext
::
precompile_only
,
"Get enable precompile only."
)
.
def
(
"get_device_target"
,
&
mindspore
::
MsContext
::
device_target
,
"Get device target."
)
.
def
(
"set_device_target"
,
&
mindspore
::
MsContext
::
set_device_target
,
"Set device target."
)
.
def
(
"get_device_id"
,
&
mindspore
::
MsContext
::
device_id
,
"Get device id."
)
.
def
(
"set_device_id"
,
&
mindspore
::
MsContext
::
set_device_id
,
"Set device id."
)
.
def
(
"get_max_call_depth"
,
&
mindspore
::
MsContext
::
max_call_depth
,
"Get max call depth."
)
.
def
(
"set_max_call_depth"
,
&
mindspore
::
MsContext
::
set_max_call_depth
,
"Set max call depth."
)
.
def
(
"get_save_graphs_flag"
,
&
mindspore
::
MsContext
::
save_graphs_flag
,
"Get whether to save graphs."
)
.
def
(
"set_save_graphs_flag"
,
&
mindspore
::
MsContext
::
set_save_graphs_flag
,
"Set whether to save graphs."
)
.
def
(
"get_auto_mixed_precision_flag"
,
&
mindspore
::
MsContext
::
auto_mixed_precision_flag
,
"Get whether to enable auto mixed precision."
)
.
def
(
"set_auto_mixed_precision_flag"
,
&
mindspore
::
MsContext
::
set_auto_mixed_precision_flag
,
"Set whether to enable auto mixed precision."
)
.
def
(
"get_enable_reduce_precision_flag"
,
&
mindspore
::
MsContext
::
enable_reduce_precision
,
"Get whether to enable reduce precision."
)
.
def
(
"set_enable_reduce_precision_flag"
,
&
mindspore
::
MsContext
::
set_enable_reduce_precision
,
"Set whether to enable reduce precision."
)
.
def
(
"get_save_graphs_path"
,
&
mindspore
::
MsContext
::
save_graphs_path
,
"Get save graphs path."
)
.
def
(
"set_save_graphs_path"
,
&
mindspore
::
MsContext
::
set_save_graphs_path
,
"Set save graphs path."
)
.
def
(
"get_enable_dump"
,
&
mindspore
::
MsContext
::
enable_dump
,
"Get whether to enable dump."
)
.
def
(
"set_enable_dump"
,
&
mindspore
::
MsContext
::
set_enable_dump
,
"Set whether to enable dump."
)
.
def
(
"get_save_dump_path"
,
&
mindspore
::
MsContext
::
save_dump_path
,
"Get path to dump."
)
.
def
(
"set_save_dump_path"
,
&
mindspore
::
MsContext
::
set_save_dump_path
,
"Set path to dump."
)
.
def
(
"set_graph_memory_max_size"
,
&
mindspore
::
MsContext
::
set_graph_memory_max_size
,
"set graph memory max size."
)
.
def
(
"set_variable_memory_max_size"
,
&
mindspore
::
MsContext
::
set_variable_memory_max_size
,
"set variable memory max size"
)
.
def
(
"get_enable_profiling"
,
&
mindspore
::
MsContext
::
enable_profiling
,
"Get whether to open profiling."
)
.
def
(
"set_enable_profiling"
,
&
mindspore
::
MsContext
::
set_enable_profiling
,
"Set whether to open profiling."
)
.
def
(
"get_profiling_options"
,
&
mindspore
::
MsContext
::
profiling_options
,
"Get options to profiling."
)
.
def
(
"set_profiling_options"
,
&
mindspore
::
MsContext
::
set_profiling_options
,
"Set options to profiling."
)
.
def
(
"get_check_bprop_flag"
,
&
mindspore
::
MsContext
::
check_bprop_flag
,
"Get whether to check bprop."
)
.
def
(
"set_check_bprop_flag"
,
&
mindspore
::
MsContext
::
set_check_bprop_flag
,
"Set whether to check bprop."
)
.
def
(
"get_max_device_memory"
,
&
mindspore
::
MsContext
::
max_device_memory
,
"Get deivce memory max size."
)
.
def
(
"set_max_device_memory"
,
&
mindspore
::
MsContext
::
set_max_device_memory
,
"Set deivce memory max size."
)
.
def
(
"set_print_file_path"
,
&
mindspore
::
MsContext
::
set_print_file_path
,
"Set path to print."
)
.
def
(
"set_enable_graph_kernel"
,
&
mindspore
::
MsContext
::
set_enable_graph_kernel
,
"Set the GraphKernel switch to on or off."
)
.
def
(
"get_enable_graph_kernel"
,
&
mindspore
::
MsContext
::
enable_graph_kernel
,
"Get the value of GraphKernel switch."
)
.
def
(
"get_enable_sparse"
,
&
mindspore
::
MsContext
::
enable_sparse
,
"Get whether to enable sparsity."
)
.
def
(
"set_enable_sparse"
,
&
mindspore
::
MsContext
::
set_enable_sparse
,
"Set whether to enable sparsity."
);
.
def
(
"set_backend_policy"
,
&
mindspore
::
MsContext
::
set_backend_policy
,
"Set backend policy."
);
(
void
)
py
::
class_
<
mindspore
::
MpiConfig
,
std
::
shared_ptr
<
mindspore
::
MpiConfig
>>
(
m
,
"MpiConfig"
)
.
def_static
(
"get_instance"
,
&
mindspore
::
MpiConfig
::
GetInstance
,
"Get mpi config instance."
)
...
...
mindspore/ccsrc/pipeline/jit/pass.cc
浏览文件 @
8d419314
...
...
@@ -271,7 +271,7 @@ void InitOpt(const ResourcePtr &res) {
g_pass_opts
[
"opt_prepare"
]
=
Optimizer
::
MakeOptimizer
(
"opt_prepare"
,
res
,
GetPreparePhases
(
irpass
));
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
!
(
context_ptr
->
enable_graph_kernel
(
)))
{
if
(
!
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_GRAPH_KERNEL
)))
{
g_pass_opts
[
"opt_graph_kernel_a"
]
->
set_enable
(
false
);
g_pass_opts
[
"opt_graph_kernel_b"
]
->
set_enable
(
false
);
}
...
...
mindspore/ccsrc/pipeline/jit/pipeline.cc
浏览文件 @
8d419314
...
...
@@ -88,7 +88,7 @@ std::string GetBaseNameForIR(int stage_idx, const std::string &action_name) {
if
(
ms_context
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"ms_context is nullptr"
;
}
auto
save_graphs_path
=
ms_context
->
save_graphs_path
(
);
auto
save_graphs_path
=
ms_context
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
);
if
(
save_graphs_path
.
empty
())
{
save_graphs_path
=
"."
;
}
...
...
@@ -646,7 +646,7 @@ void Pipeline::Run() {
if
(
!
result
)
{
MS_LOG
(
EXCEPTION
)
<<
"Pipeline running to end, failed in step:"
<<
action
.
first
;
}
if
(
MsContext
::
GetInstance
()
->
save_graphs_flag
(
)
&&
resource_
->
func_graph
()
!=
nullptr
)
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
)
&&
resource_
->
func_graph
()
!=
nullptr
)
{
auto
graph
=
resource_
->
func_graph
();
if
(
graph
!=
nullptr
)
{
user_graph
=
graph
;
...
...
@@ -688,7 +688,7 @@ void Pipeline::Run() {
MsProfile
::
Reset
();
#endif
if
(
MsContext
::
GetInstance
()
->
save_graphs_flag
(
)
&&
(
user_graph
!=
nullptr
))
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
)
&&
(
user_graph
!=
nullptr
))
{
std
::
string
user_graph_file
=
GetFilePathName
(
"ModelDigraph.dot"
);
MS_LOG
(
DEBUG
)
<<
"Save user graph to: "
<<
user_graph_file
;
draw
::
DrawUserFuncGraph
(
user_graph_file
,
user_graph
);
...
...
@@ -710,7 +710,7 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef
if
(
!
succ
)
{
MS_LOG
(
EXCEPTION
)
<<
"The "
<<
i
<<
"th arg convert failed."
;
}
if
(
MsContext
::
GetInstance
()
->
execution_mode
(
)
==
0
&&
!
converted
->
isa
<
tensor
::
Tensor
>
())
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
0
&&
!
converted
->
isa
<
tensor
::
Tensor
>
())
{
MS_EXCEPTION
(
TypeError
)
<<
"For 'graph mode', the "
<<
i
<<
"th arg: "
<<
converted
->
ToString
()
<<
" is not tensor."
;
}
...
...
@@ -891,7 +891,7 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc
// Convert CNodeList to LinConvertResult.
ConfigManager
::
GetInstance
().
set_iter_num
(
1
);
auto
runner
=
convert_fn
({
app_init
},
""
);
if
(
MsContext
::
GetInstance
()
->
execution_mode
(
)
!=
kPynativeMode
)
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kPynativeMode
)
{
backend
->
Link
(
runner
.
graph_id
);
}
ConfigManager
::
GetInstance
().
set_iter_num
(
size
);
...
...
@@ -965,10 +965,11 @@ void InitHccl() {
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
(
void
)
context
::
OpenTsd
(
ms_context
);
uint32_t
device_id
=
ms_context
->
device_id
();
std
::
string
device_name
=
ms_context
->
device_target
();
ms_context
->
set_enable_hccl
(
true
);
if
(
ms_context
->
backend_policy
()
==
"ms"
&&
ms_context
->
device_target
()
==
kAscendDevice
)
{
uint32_t
device_id
=
ms_context
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
);
std
::
string
device_name
=
ms_context
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
);
ms_context
->
set_param
<
bool
>
(
MS_CTX_ENABLE_HCCL
,
true
);
if
(
ms_context
->
backend_policy
()
==
"ms"
&&
ms_context
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
)
==
kAscendDevice
)
{
auto
runtime_instance
=
device
::
KernelRuntimeManager
::
Instance
().
GetKernelRuntime
(
device_name
,
device_id
);
MS_EXCEPTION_IF_NULL
(
runtime_instance
);
if
(
!
runtime_instance
->
Init
())
{
...
...
mindspore/ccsrc/pipeline/jit/pipeline_ge.cc
浏览文件 @
8d419314
...
...
@@ -214,7 +214,7 @@ bool AddDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, const py::di
return
false
;
}
if
(
MsContext
::
GetInstance
()
->
save_graphs_flag
(
))
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
))
{
convertor
.
DrawComputeGraph
(
GetFilePathName
(
"ge_graph.dot"
));
// for debug
convertor
.
DrawInitGraph
(
GetFilePathName
(
"init_graph.dot"
));
// for debug
convertor
.
DrawSaveCheckpointGraph
(
GetFilePathName
(
"save_checkpoint_graph.dot"
));
// for debug
...
...
@@ -244,7 +244,7 @@ FuncGraphPtr BuildDFGraph(const std::map<std::string, ExecutorInfoPtr> &info, co
}
FuncGraphPtr
anf_graph
=
info
.
at
(
phase
)
->
func_graph
;
if
(
MsContext
::
GetInstance
()
->
save_graphs_flag
(
))
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
))
{
draw
::
Draw
(
GetFilePathName
(
"anf_graph.dot"
),
anf_graph
);
// for debug
DumpIR
(
GetFilePathName
(
"anf_graph.ir"
),
anf_graph
,
true
);
}
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
浏览文件 @
8d419314
...
...
@@ -118,8 +118,9 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
<<
", current function call depth: "
<<
engine
->
function_call_depth
();
AbstractBasePtr
ret_base
=
nullptr
;
engine
->
IncreaseFunctionCallDepth
();
if
(
engine
->
function_call_depth
()
>
MsContext
::
GetInstance
()
->
max_call_depth
())
{
MS_LOG
(
EXCEPTION
)
<<
"Exceed function call depth limit "
<<
MsContext
::
GetInstance
()
->
max_call_depth
()
<<
"."
;
if
(
engine
->
function_call_depth
()
>
MsContext
::
GetInstance
()
->
get_param
<
uint32_t
>
(
MS_CTX_MAX_CALL_DEPTH
))
{
MS_LOG
(
EXCEPTION
)
<<
"Exceed function call depth limit "
<<
MsContext
::
GetInstance
()
->
get_param
<
uint32_t
>
(
MS_CTX_MAX_CALL_DEPTH
)
<<
"."
;
}
std
::
vector
<
AnfNodePtr
>
nodes
=
FastShadowSort
(
func_node
);
for
(
auto
it
=
nodes
.
crbegin
();
it
!=
nodes
.
crend
();
it
++
)
{
...
...
@@ -409,7 +410,7 @@ EvalResultPtr JEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &arg
bparams
.
push_back
(
SensitivityTransform
(
orig_func_
));
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
bool
enable_sparse
=
context
->
enable_sparse
(
);
bool
enable_sparse
=
context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_SPARSE
);
(
void
)
std
::
transform
(
args_spec_list
.
begin
(),
args_spec_list
.
end
(),
std
::
back_inserter
(
bparams
),
[
&
enable_sparse
](
const
AbstractBasePtr
&
arg_spec
)
->
AbstractBasePtr
{
if
(
enable_sparse
&&
arg_spec
->
isa
<
AbstractTensor
>
())
{
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.h
浏览文件 @
8d419314
...
...
@@ -62,7 +62,7 @@ class Evaluator : public Base {
virtual
EvalResultPtr
AbstractEval
(
const
AbstractBasePtrList
&
args_spec_list
)
{
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
bool
enable_sparse
=
context
->
enable_sparse
(
);
bool
enable_sparse
=
context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_SPARSE
);
if
(
!
enable_sparse
)
{
return
nullptr
;
}
...
...
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
浏览文件 @
8d419314
...
...
@@ -290,7 +290,7 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
if
(
abs_base
->
isa
<
AbstractTensor
>
())
{
auto
arg_tensor
=
dyn_cast
<
AbstractTensor
>
(
abs_base
);
dic
[
"shape"
]
=
arg_tensor
->
shape
()
->
shape
();
if
(
MsContext
::
GetInstance
()
->
execution_mode
(
)
==
kGraphMode
)
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kGraphMode
)
{
const
auto
&
min_shape
=
arg_tensor
->
shape
()
->
min_shape
();
const
auto
&
max_shape
=
arg_tensor
->
shape
()
->
max_shape
();
if
(
!
min_shape
.
empty
()
&&
!
max_shape
.
empty
())
{
...
...
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
浏览文件 @
8d419314
...
...
@@ -558,8 +558,8 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
MS_EXCEPTION_IF_NULL
(
op_exec_info
);
MS_LOG
(
INFO
)
<<
"Start run op["
<<
op_exec_info
->
op_name
<<
"] with backend policy ms"
;
auto
ms_context
=
MsContext
::
GetInstance
();
ms_context
->
set_
enable_pynative_infer
(
true
);
std
::
string
device_target
=
ms_context
->
device_target
(
);
ms_context
->
set_
param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
,
true
);
std
::
string
device_target
=
ms_context
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
);
if
(
device_target
!=
kAscendDevice
&&
device_target
!=
kGPUDevice
)
{
MS_EXCEPTION
(
ArgumentError
)
<<
"Device target ["
<<
device_target
<<
"] is not supported in Pynative mode"
;
}
...
...
@@ -567,7 +567,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
if
(
session
==
nullptr
)
{
session
=
session
::
SessionFactory
::
Get
().
Create
(
device_target
);
MS_EXCEPTION_IF_NULL
(
session
);
session
->
Init
(
ms_context
->
device_id
(
));
session
->
Init
(
ms_context
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
));
}
std
::
vector
<
tensor
::
TensorPtr
>
input_tensors
;
...
...
@@ -578,7 +578,7 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
session
->
BuildOpAsync
(
op_exec_info
.
get
(),
graph_info
,
input_tensors
,
tensors_mask
);
EraseValueNodeTensor
(
tensors_mask
,
&
input_tensors
);
py
::
tuple
result
=
session
->
RunOpAsync
(
op_exec_info
.
get
(),
graph_info
,
input_tensors
);
ms_context
->
set_
enable_pynative_infer
(
false
);
ms_context
->
set_
param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
,
false
);
*
status
=
PYNATIVE_SUCCESS
;
MS_LOG
(
INFO
)
<<
"End run op["
<<
op_exec_info
->
op_name
<<
"] with backend policy ms"
;
return
result
;
...
...
@@ -1308,7 +1308,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
// Maybe exit in the pynative runing op, so need reset pynative flag.
auto
ms_context
=
MsContext
::
GetInstance
();
if
(
ms_context
!=
nullptr
)
{
ms_context
->
set_
enable_pynative_infer
(
false
);
ms_context
->
set_
param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
,
false
);
}
ConfigManager
::
GetInstance
().
ResetIterNum
();
return
;
...
...
mindspore/ccsrc/pybind_api/ir/primitive_py.cc
浏览文件 @
8d419314
...
...
@@ -89,7 +89,7 @@ py::tuple check_bprop_out(const py::object &grads_obj, const py::tuple &py_args)
MS_EXCEPTION
(
ValueError
)
<<
"For user define net bprop, the gradients number: "
<<
grads
.
size
()
<<
" is not equal to the args number: "
<<
py_args
.
size
()
-
2
<<
"."
;
}
if
(
MsContext
::
GetInstance
()
->
check_bprop_flag
(
))
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_CHECK_BPROP_FLAG
))
{
for
(
size_t
i
=
0
;
i
<
grads
.
size
();
i
++
)
{
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
py_args
[
i
]))
{
if
(
!
py
::
isinstance
<
tensor
::
Tensor
>
(
grads
[
i
]))
{
...
...
mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
浏览文件 @
8d419314
...
...
@@ -154,7 +154,7 @@ bool SyncDeviceToHostAndFloatToFloat64(void *dst, size_t dst_size, const void *s
DeviceAddressPtr
AssignLaunchMemory
(
size_t
size
,
const
std
::
string
&
format
,
TypeId
type
)
{
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
auto
device_id
=
ms_context
->
device_id
(
);
auto
device_id
=
ms_context
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
);
auto
runtime_instance
=
device
::
KernelRuntimeManager
::
Instance
().
GetKernelRuntime
(
kAscendDevice
,
device_id
);
MS_EXCEPTION_IF_NULL
(
runtime_instance
);
auto
address_ptr
=
runtime_instance
->
AssignSingleOpLaunchMemory
(
size
,
format
,
type
);
...
...
@@ -261,11 +261,12 @@ void AscendDeviceAddress::SyncStream() const {
MS_LOG
(
INFO
)
<<
"Start!"
;
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
execution_mode
()
!=
kPynativeMode
&&
!
ms_context
->
enable_pynative_infer
())
{
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kPynativeMode
&&
!
ms_context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
))
{
MS_LOG
(
INFO
)
<<
"Finish!"
;
return
;
}
auto
device_id
=
ms_context
->
device_id
(
);
auto
device_id
=
ms_context
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
);
auto
runtime_instance
=
device
::
KernelRuntimeManager
::
Instance
().
GetKernelRuntime
(
kAscendDevice
,
device_id
);
MS_EXCEPTION_IF_NULL
(
runtime_instance
);
auto
ret
=
runtime_instance
->
SyncStream
();
...
...
@@ -348,7 +349,7 @@ void AscendDeviceAddress::LaunchTransData(kernel::KernelModPtr kernel_mod_ptr, v
}
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
auto
device_id
=
ms_context
->
device_id
(
);
auto
device_id
=
ms_context
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
);
auto
runtime_instance
=
device
::
KernelRuntimeManager
::
Instance
().
GetKernelRuntime
(
kAscendDevice
,
device_id
);
MS_EXCEPTION_IF_NULL
(
runtime_instance
);
auto
ret
=
...
...
@@ -475,7 +476,8 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &sh
std
::
vector
<
size_t
>
device_shape
=
GetDeviceShape
(
&
host_shape
);
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
execution_mode
()
!=
kGraphMode
&&
ms_context
->
execution_mode
()
!=
kPynativeMode
&&
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kGraphMode
&&
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
!=
kPynativeMode
&&
type_id_name_map
.
find
(
type_id_
)
!=
type_id_name_map
.
end
())
{
std
::
pair
<
std
::
string
,
std
::
string
>
type_format
=
std
::
make_pair
(
type_id_name_map
.
at
(
type_id_
),
format_
);
if
(
use_trans_data
.
find
(
type_format
)
!=
use_trans_data
.
end
())
{
...
...
mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc
浏览文件 @
8d419314
...
...
@@ -158,7 +158,7 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std
bool
AscendKernelRuntime
::
NeedDestroyHccl
()
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
!
context_ptr
->
enable_hccl
(
))
{
if
(
!
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_HCCL
))
{
MS_LOG
(
INFO
)
<<
"Hccl is not enabled"
;
return
false
;
}
...
...
@@ -177,7 +177,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
auto
ret
=
rtSetDevice
(
context_ptr
->
device_id
(
));
auto
ret
=
rtSetDevice
(
context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
));
if
(
ret
!=
RT_ERROR_NONE
)
{
MS_EXCEPTION
(
DeviceProcessError
)
<<
"Call rtSetDevice, ret["
<<
static_cast
<
int
>
(
ret
)
<<
"]"
;
}
...
...
@@ -461,12 +461,12 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
MS_LOG
(
INFO
)
<<
"GenTask start. GraphId:"
<<
graph
->
graph_id
();
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
is_task_sink
=
context_ptr
->
enable_task_sink
(
);
bool
is_task_sink
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
);
if
(
!
is_task_sink
)
{
return
true
;
}
#ifdef MEM_REUSE_DEBUG
if
(
!
context_ptr
->
enable_mem_reuse
(
))
{
if
(
!
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_MEM_REUSE
))
{
// Get normal graph ir for memreuse
mindspore
::
memreuse
::
MemReuseChecker
::
GetInstance
().
CheckNormalIR
(
graph
);
}
...
...
@@ -518,7 +518,7 @@ bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
MS_LOG
(
INFO
)
<<
"LoadTask start. GraphId:"
<<
graph
->
graph_id
();
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
is_task_sink
=
context_ptr
->
enable_task_sink
(
);
bool
is_task_sink
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
);
if
(
!
is_task_sink
)
{
return
true
;
}
...
...
@@ -658,7 +658,7 @@ bool AscendKernelRuntime::InitDevice() {
MS_LOG
(
ERROR
)
<<
"Get MsContext instance failed"
;
return
false
;
}
if
(
context_ptr
->
enable_hccl
(
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_HCCL
))
{
if
(
!
HcclInit
())
{
MS_LOG
(
ERROR
)
<<
"HcclInit init failed"
;
return
false
;
...
...
@@ -746,7 +746,7 @@ bool AscendKernelRuntime::DestroyHccl() {
return
false
;
}
MS_LOG
(
INFO
)
<<
"Hccl destroy successful, status = "
<<
res
<<
"."
;
context_ptr
->
set_
enable_hccl
(
false
);
context_ptr
->
set_
param
<
bool
>
(
MS_CTX_ENABLE_HCCL
,
false
);
return
true
;
}
...
...
mindspore/ccsrc/runtime/device/ascend/ascend_memory_manager.cc
浏览文件 @
8d419314
...
...
@@ -43,7 +43,7 @@ void AscendMemoryManager::MallocDeviceMemory() {
uint64_t
AscendMemoryManager
::
GetDeviceMemSizeFromContext
()
{
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
auto
variable_memory_max_size
=
context
->
variable_memory_max_size
(
);
auto
variable_memory_max_size
=
context
->
get_param
<
std
::
string
>
(
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
);
if
(
variable_memory_max_size
==
"0"
)
{
return
0
;
}
...
...
mindspore/ccsrc/runtime/device/ascend/ascend_stream_assign.cc
浏览文件 @
8d419314
...
...
@@ -1373,7 +1373,7 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it
bool
AscendStreamAssign
::
IsTaskSink
()
{
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
!
ms_context
->
enable_task_sink
(
))
{
if
(
!
ms_context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
))
{
MS_LOG
(
INFO
)
<<
"Task sink mode is not enable"
;
return
false
;
}
else
{
...
...
mindspore/ccsrc/runtime/device/ascend/dump/data_dumper.cc
浏览文件 @
8d419314
...
...
@@ -117,7 +117,7 @@ void DataDumper::SetOpMappingInfo(NotNull<aicpu::dump::OpMappingInfo *> dump_inf
if
(
!
dump_path
.
has_value
())
{
MS_LOG
(
EXCEPTION
)
<<
"Dump path invalid"
;
}
auto
device_id
=
context_ptr
->
device_id
(
);
auto
device_id
=
context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
);
dump_info
->
set_dump_path
(
"/"
+
dump_path
.
value
()
+
"_"
+
std
::
to_string
(
device_id
)
+
"/"
);
MS_LOG
(
INFO
)
<<
"[DataDump] dump_path:"
<<
dump_path
.
value
();
...
...
mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
浏览文件 @
8d419314
...
...
@@ -363,7 +363,7 @@ void PrecisionReduce(const std::vector<int> &node_mix_precision_datatype_index,
*
precision_reduce
=
false
;
return
;
}
if
(
context_ptr
->
enable_reduce_precision
(
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_REDUCE_PRECISION
))
{
selected_ret
=
RaiseOrReduceDataTypePrecisionSelect
(
node_mix_precision_datatype_index
,
node_mix_precision_datatype
,
kernel_support_datatype
,
&
kernel_match_datatype_idx_copy
);
}
...
...
mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.cc
浏览文件 @
8d419314
...
...
@@ -117,7 +117,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
}
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
const
string
prof_options_str
=
context
->
profiling_options
(
);
const
string
prof_options_str
=
context
->
get_param
<
std
::
string
>
(
MS_CTX_PROFILING_OPTIONS
);
std
::
vector
<
string
>
opts
=
Split
(
prof_options_str
,
':'
);
if
(
opts
.
empty
())
{
MS_LOG
(
WARNING
)
<<
"Profiling is enabled, but profiling option is not set!"
;
...
...
mindspore/ccsrc/runtime/device/ascend/profiling/profiling_manager.h
浏览文件 @
8d419314
...
...
@@ -41,7 +41,7 @@ class ProfilingManager {
inline
bool
IsProfiling
()
const
{
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
return
context
->
enable_profiling
(
);
return
context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_PROFILING
);
}
protected:
...
...
mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc
浏览文件 @
8d419314
...
...
@@ -342,12 +342,12 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
TaskDescReporter
task_reporter
(
context
->
device_id
(
),
"vm.task_desc_info"
,
ret
->
second
);
TaskDescReporter
task_reporter
(
context
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
),
"vm.task_desc_info"
,
ret
->
second
);
task_reporter
.
set_task_ids
(
task_ids
);
task_reporter
.
set_stream_ids
(
stream_ids
);
task_reporter
.
ReportData
();
GraphDescReporter
graph_reporter
(
context
->
device_id
(
),
"vm.graph_desc_info"
,
ret
->
second
);
GraphDescReporter
graph_reporter
(
context
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
),
"vm.graph_desc_info"
,
ret
->
second
);
graph_profiling_cnode_
.
erase
(
ret
);
graph_reporter
.
ReportData
();
...
...
@@ -357,7 +357,7 @@ void ProfilingUtils::ReportProfilingData(const std::vector<uint32_t> &task_ids,
MS_LOG
(
ERROR
)
<<
"Graph id not found in graph_point"
;
return
;
}
PointReporter
point_reporter
(
context
->
device_id
(
),
"vm.point"
);
PointReporter
point_reporter
(
context
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
),
"vm.point"
);
for
(
const
auto
&
point
:
point_iter
->
second
)
{
point_reporter
.
AddReportData
(
point
);
}
...
...
mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
浏览文件 @
8d419314
...
...
@@ -416,7 +416,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
mem_manager_
->
ResetDynamicMemory
();
AssignStaticMemoryInput
(
graph
);
AssignStaticMemoryValueNode
(
graph
);
bool
is_enable_dynamic_mem
=
context_ptr
->
enable_dynamic_mem_pool
(
);
bool
is_enable_dynamic_mem
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
);
if
(
is_enable_dynamic_mem
)
{
// Use the dynamic memory pool.
InitKernelRefCount
(
graph
);
...
...
@@ -435,8 +435,8 @@ bool GPUKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
bool
ret
=
true
;
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
is_enable_dynamic_mem
=
context_ptr
->
enable_dynamic_mem_pool
(
);
bool
is_enable_pynative_infer
=
context_ptr
->
enable_pynative_infer
(
);
bool
is_enable_dynamic_mem
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
);
bool
is_enable_pynative_infer
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
);
if
(
is_enable_dynamic_mem
&&
!
is_enable_pynative_infer
)
{
auto
graph_id
=
graph
->
graph_id
();
auto
iter
=
mem_swap_map_
.
find
(
graph_id
);
...
...
mindspore/ccsrc/runtime/device/gpu/gpu_memory_allocator.cc
浏览文件 @
8d419314
...
...
@@ -29,7 +29,7 @@ bool GPUMemoryAllocator::Init() {
size_t
free_size
=
CudaDriver
::
free_mem_size
();
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
limited_device_memory_
=
context_ptr
->
max_device_memory
(
);
limited_device_memory_
=
context_ptr
->
get_param
<
float
>
(
MS_CTX_MAX_DEVICE_MEMORY
);
available_device_memory_
=
FloatToSize
(
limited_device_memory_
*
1024
*
1024
*
1024
);
if
(
total_size
>
0
&&
free_size
>
0
&&
available_device_memory_
>
0
)
{
MS_LOG
(
INFO
)
<<
"GPU device total memory size "
<<
total_size
<<
", current free memory size "
<<
free_size
...
...
@@ -44,7 +44,7 @@ bool GPUMemoryAllocator::Init() {
void
GPUMemoryAllocator
::
CheckMaxDeviceMemory
()
const
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
auto
max_device_memory
=
context_ptr
->
max_device_memory
(
);
auto
max_device_memory
=
context_ptr
->
get_param
<
float
>
(
MS_CTX_MAX_DEVICE_MEMORY
);
// Currently not support modifying the max device memory.
if
(
limited_device_memory_
!=
max_device_memory
)
{
MS_LOG
(
EXCEPTION
)
...
...
mindspore/ccsrc/runtime/device/gpu/gpu_memory_manager.cc
浏览文件 @
8d419314
...
...
@@ -37,7 +37,7 @@ void GPUMemoryManager::MallocDeviceMemory() {
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
// If use the dynamic memory pool, then alloc the first memory block to init.
if
(
context_ptr
->
enable_dynamic_mem_pool
(
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
))
{
auto
device_addr
=
MallocMemFromMemPool
(
1
);
if
(
!
device_addr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Dynamic memory pool init error."
;
...
...
@@ -65,7 +65,7 @@ void GPUMemoryManager::FreeDeviceMemory() {
uint8_t
*
GPUMemoryManager
::
MallocStaticMem
(
size_t
size
,
bool
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
context_ptr
->
enable_dynamic_mem_pool
(
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
))
{
auto
device_ptr
=
MallocMemFromMemPool
(
size
);
MS_EXCEPTION_IF_NULL
(
device_ptr
);
return
AddressOffset
(
device_ptr
,
0
);
...
...
mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc
浏览文件 @
8d419314
...
...
@@ -162,7 +162,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
bool
IsNeedProcessFormatInfo
(
const
CNodePtr
&
kernel_node
,
const
std
::
vector
<
TypeId
>
&
inputs_type
)
{
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
ms_context
->
execution_mode
(
)
==
kPynativeMode
)
{
if
(
ms_context
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kPynativeMode
)
{
return
false
;
}
if
(
!
AnfAlgo
::
IsRealCNodeKernel
(
kernel_node
))
{
...
...
mindspore/ccsrc/runtime/device/kernel_adjust.cc
浏览文件 @
8d419314
...
...
@@ -60,8 +60,8 @@ void KernelAdjust::ReorderGetNext(const std::shared_ptr<session::KernelGraph> &k
bool
KernelAdjust
::
NeedInsertSwitch
()
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
return
(
context_ptr
->
enable_task_sink
()
&&
context_ptr
->
loop_sink_flag
(
)
&&
ConfigManager
::
GetInstance
().
iter_num
()
>
1
);
return
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
)
&&
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_LOOP_SINK
)
&&
ConfigManager
::
GetInstance
().
iter_num
()
>
1
);
}
CNodePtr
KernelAdjust
::
CreateSendApplyKernel
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
graph_ptr
,
...
...
mindspore/ccsrc/runtime/device/kernel_runtime.cc
浏览文件 @
8d419314
...
...
@@ -50,7 +50,7 @@ bool KernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
struct
timeval
start_time
,
end_time
;
(
void
)
gettimeofday
(
&
start_time
,
nullptr
);
#endif
bool
is_task_sink
=
context_ptr
->
enable_task_sink
(
);
bool
is_task_sink
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
);
if
(
is_task_sink
)
{
ret
=
RunTask
(
graph
);
}
else
{
...
...
@@ -502,7 +502,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
MS_LOG
(
INFO
)
<<
"communication op addr exist"
;
continue
;
}
if
(
context_ptr
->
enable_hccl
(
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_HCCL
))
{
mem_size
=
mem_manager_
->
GetCommonAlignSize
(
mem_size
);
}
total_size
+=
mem_size
;
...
...
@@ -646,7 +646,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
DeviceAddressPtr
address
=
nullptr
;
address
=
CreateDeviceAddress
(
nullptr
,
node_size
,
output_format
,
output_type_id
);
MS_EXCEPTION_IF_NULL
(
address
);
if
(
ms_context
->
enable_pynative_infer
()
&&
!
mem_manager_
->
MallocMemFromMemPool
(
address
,
node_size
))
{
if
(
ms_context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
)
&&
!
mem_manager_
->
MallocMemFromMemPool
(
address
,
node_size
))
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address from memory pool when tensor size is: "
<<
node_size
;
}
else
if
(
mem_manager_
->
MallocMem
(
kStaticMem
,
node_size
,
address
)
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address when flag is: "
<<
kStaticMem
<<
", tensor size is: "
<<
node_size
;
...
...
@@ -682,7 +683,8 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
DeviceAddressPtr
address
=
nullptr
;
address
=
CreateDeviceAddress
(
nullptr
,
tensor_size
,
kOpFormat_DEFAULT
,
kNumberTypeUInt8
);
MS_EXCEPTION_IF_NULL
(
address
);
if
(
ms_context
->
enable_pynative_infer
()
&&
!
mem_manager_
->
MallocMemFromMemPool
(
address
,
tensor_size
))
{
if
(
ms_context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
)
&&
!
mem_manager_
->
MallocMemFromMemPool
(
address
,
tensor_size
))
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address from memory pool when tensor size is: "
<<
tensor_size
;
}
else
if
(
mem_manager_
->
MallocMem
(
kStaticMem
,
tensor_size
,
address
)
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address when flag is: "
<<
kStaticMem
<<
", tensor size is: "
<<
tensor_size
;
...
...
@@ -701,7 +703,7 @@ void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL
(
mem_manager_
);
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
bool
is_enable_mem_reuse
=
context_ptr
->
enable_mem_reuse
(
);
bool
is_enable_mem_reuse
=
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_MEM_REUSE
);
auto
mem_type
=
kDynamicMem
;
if
(
is_enable_mem_reuse
)
{
mem_manager_
->
MallocReusedDynamicMem
(
graph
);
...
...
mindspore/ccsrc/runtime/device/memory_manager.cc
浏览文件 @
8d419314
...
...
@@ -54,7 +54,7 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, Me
uint8_t
*
ptr
=
nullptr
;
if
(
AnfAlgo
::
IsCommunicationOp
(
node
))
{
bool
communication_mem
=
false
;
if
(
context_ptr
->
enable_hccl
(
))
{
if
(
context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_HCCL
))
{
communication_mem
=
true
;
}
if
(
type
==
kStaticMem
)
{
...
...
mindspore/ccsrc/transform/graph_ir/convert.cc
浏览文件 @
8d419314
...
...
@@ -1070,7 +1070,7 @@ void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vector<AnfNod
convertor
.
inputs_
=
inputs
;
(
void
)
convertor
.
ConvertAllNode
().
BuildGraph
();
std
::
string
name
=
graph_node
->
ToString
()
+
"_ge_graph.dot"
;
if
(
MsContext
::
GetInstance
()
->
save_graphs_flag
(
))
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
))
{
convertor
.
DrawComputeGraph
(
name
);
}
branches_map_
[
node
.
get
()]
=
*
(
convertor
.
df_graph_
);
...
...
mindspore/ccsrc/utils/context/context_extends.cc
浏览文件 @
8d419314
...
...
@@ -41,13 +41,13 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
MS_LOG
(
EXCEPTION
)
<<
"nullptr"
;
}
if
(
ms_context_ptr
->
is_pynative_ge_init
(
))
{
if
(
ms_context_ptr
->
get_param
<
bool
>
(
MS_CTX_IS_PYNATIVE_GE_INIT
))
{
return
true
;
}
if
(
ms_context_ptr
->
tsd_ref
(
))
{
if
(
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_TSD_REF
))
{
MS_LOG
(
DEBUG
)
<<
"TDT Dataset client is already opened."
;
ms_context_ptr
->
set_tsd_ref
(
"++"
);
ms_context_ptr
->
increase_param
<
uint32_t
>
(
MS_CTX_TSD_REF
);
return
true
;
}
...
...
@@ -59,7 +59,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
unsigned
int
device_id
;
unsigned
int
rank_size
=
1
;
device_id
=
ms_context_ptr
->
device_id
(
);
device_id
=
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
);
auto
rank_size_env
=
common
::
GetEnv
(
"RANK_SIZE"
);
if
(
rank_size_env
.
empty
())
{
...
...
@@ -79,7 +79,7 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
MS_LOG
(
EXCEPTION
)
<<
"Device "
<<
device_id
<<
" is occupied, open tsd failed, status = "
<<
status
<<
"."
;
return
false
;
}
ms_context_ptr
->
set_tsd_ref
(
"++"
);
ms_context_ptr
->
increase_param
<
uint32_t
>
(
MS_CTX_TSD_REF
);
#ifdef ENABLE_TDTQUE
int32_t
initStatus
=
tdt
::
TdtHostInit
(
device_id
);
if
(
initStatus
!=
TDT_OK_CODE
)
{
...
...
@@ -88,7 +88,8 @@ bool OpenTsd(const std::shared_ptr<MsContext> &ms_context_ptr) {
}
ms_context_ptr
->
tdt_print_
=
std
::
thread
(
TensorPrint
());
#endif
MS_LOG
(
INFO
)
<<
"Open and init tsd successful, tsd reference = "
<<
ms_context_ptr
->
tsd_ref
()
<<
"."
;
MS_LOG
(
INFO
)
<<
"Open and init tsd successful, tsd reference = "
<<
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_TSD_REF
)
<<
"."
;
return
true
;
}
...
...
@@ -96,12 +97,12 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
if
(
ms_context_ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"nullptr"
;
}
if
(
ms_context_ptr
->
tsd_ref
(
)
==
0
)
{
if
(
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_TSD_REF
)
==
0
)
{
return
true
;
}
ms_context_ptr
->
set_tsd_ref
(
"--"
);
if
(
force
||
ms_context_ptr
->
tsd_ref
(
)
==
0
)
{
ms_context_ptr
->
set_
tsd_ref
(
" "
);
ms_context_ptr
->
decrease_param
<
uint32_t
>
(
MS_CTX_TSD_REF
);
if
(
force
||
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_TSD_REF
)
==
0
)
{
ms_context_ptr
->
set_
param
<
uint32_t
>
(
MS_CTX_TSD_REF
,
0
);
#ifdef ENABLE_TDTQUE
int32_t
stopStatus
=
tdt
::
TdtHostStop
(
KNpuLog
);
if
(
stopStatus
!=
TDT_OK_CODE
)
{
...
...
@@ -123,17 +124,17 @@ bool CloseTsd(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
MS_LOG
(
ERROR
)
<<
"tdt thread join failed: "
<<
e
.
what
();
}
#endif
auto
device_id
=
ms_context_ptr
->
device_id
(
);
auto
device_id
=
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
);
TDT_StatusT
status
=
TsdClose
(
device_id
);
if
(
status
!=
TDT_OK
)
{
MS_LOG
(
EXCEPTION
)
<<
"Close tsd failed, status = "
<<
status
<<
"."
;
return
false
;
}
ms_context_ptr
->
set_p
ynative_ge_init
(
false
);
ms_context_ptr
->
set_p
aram
<
bool
>
(
MS_CTX_IS_PYNATIVE_GE_INIT
,
false
);
MS_LOG
(
INFO
)
<<
"Destroy and close tsd successful, status = "
<<
status
<<
"."
;
}
else
{
MS_LOG
(
DEBUG
)
<<
"TDT Dataset client is used, no need to close, tsd reference = "
<<
ms_context_ptr
->
tsd_ref
()
<<
"."
;
MS_LOG
(
DEBUG
)
<<
"TDT Dataset client is used, no need to close, tsd reference = "
<<
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_TSD_REF
)
<<
"."
;
}
return
true
;
...
...
@@ -159,14 +160,14 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
}
#ifdef ENABLE_GE
(
*
ge_options
)[
"device_id"
]
=
"0"
;
(
*
ge_options
)[
"ge.exec.enableDump"
]
=
std
::
to_string
(
ms_context_ptr
->
enable_dump
(
));
(
*
ge_options
)[
"ge.exec.dumpPath"
]
=
ms_context_ptr
->
save_dump_path
(
);
(
*
ge_options
)[
"ge.exec.enableDump"
]
=
std
::
to_string
(
ms_context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_DUMP
));
(
*
ge_options
)[
"ge.exec.dumpPath"
]
=
ms_context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_DUMP_PATH
);
(
*
ge_options
)[
"ge.exec.dumpMode"
]
=
"output"
;
MS_LOG
(
INFO
)
<<
"The enable dump state is "
<<
std
::
to_string
(
ms_context_ptr
->
enable_dump
(
))
<<
" and save dump path is "
<<
ms_context_ptr
->
save_dump_path
(
)
<<
"."
;
(
*
ge_options
)[
"ge.exec.profilingMode"
]
=
std
::
to_string
(
ms_context_ptr
->
enable_profiling
(
));
if
(
ms_context_ptr
->
enable_profiling
(
))
{
(
*
ge_options
)[
"ge.exec.profilingOptions"
]
=
ms_context_ptr
->
profiling_options
(
);
MS_LOG
(
INFO
)
<<
"The enable dump state is "
<<
std
::
to_string
(
ms_context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_DUMP
))
<<
" and save dump path is "
<<
ms_context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_SAVE_DUMP_PATH
)
<<
"."
;
(
*
ge_options
)[
"ge.exec.profilingMode"
]
=
std
::
to_string
(
ms_context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_PROFILING
));
if
(
ms_context_ptr
->
get_param
<
bool
>
(
MS_CTX_ENABLE_PROFILING
))
{
(
*
ge_options
)[
"ge.exec.profilingOptions"
]
=
ms_context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_PROFILING_OPTIONS
);
}
(
*
ge_options
)[
"rank_table_file"
]
=
""
;
...
...
@@ -178,12 +179,12 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
}
(
*
ge_options
)[
"graphType"
]
=
"1"
;
if
(
ms_context_ptr
->
g
raph_memory_max_size
(
)
!=
"0"
)
{
(
*
ge_options
)[
"ge.graphMemoryMaxSize"
]
=
ms_context_ptr
->
g
raph_memory_max_size
(
);
if
(
ms_context_ptr
->
g
et_param
<
std
::
string
>
(
MS_CTX_GRAPH_MEMORY_MAX_SIZE
)
!=
"0"
)
{
(
*
ge_options
)[
"ge.graphMemoryMaxSize"
]
=
ms_context_ptr
->
g
et_param
<
std
::
string
>
(
MS_CTX_GRAPH_MEMORY_MAX_SIZE
);
}
if
(
ms_context_ptr
->
variable_memory_max_size
(
)
!=
"0"
)
{
(
*
ge_options
)[
"ge.variableMemoryMaxSize"
]
=
ms_context_ptr
->
variable_memory_max_size
(
);
if
(
ms_context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
)
!=
"0"
)
{
(
*
ge_options
)[
"ge.variableMemoryMaxSize"
]
=
ms_context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
);
}
#if ENABLE_TRAIN == 1
...
...
@@ -224,7 +225,7 @@ void GetGeOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<std
}
// Enable auto mixed precision according to the context options
if
(
ms_context_ptr
->
auto_mixed_precision_flag
(
))
{
if
(
ms_context_ptr
->
get_param
<
bool
>
(
MS_CTX_AUTO_MIXED_PRECISION_FLAG
))
{
(
*
ge_options
)[
"ge.exec.precision_mode"
]
=
"allow_mix_precision"
;
}
else
{
(
*
ge_options
)[
"ge.exec.precision_mode"
]
=
"allow_fp32_to_fp16"
;
...
...
@@ -240,7 +241,7 @@ void SetHcclOptions(const std::shared_ptr<MsContext> &ms_context_ptr, std::map<s
}
auto
env_table_file
=
common
::
GetEnv
(
"RANK_TABLE_FILE"
);
auto
env_rank_id
=
common
::
GetEnv
(
"RANK_ID"
);
auto
env_device_id
=
std
::
to_string
(
ms_context_ptr
->
device_id
(
));
auto
env_device_id
=
std
::
to_string
(
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
));
if
(
!
(
env_table_file
.
empty
()
||
env_rank_id
.
empty
()))
{
MS_LOG
(
INFO
)
<<
"Initialize Ge for distribute parameter"
;
MS_LOG
(
INFO
)
<<
"Use hccl, make sure hccl lib is set in OPTION_EXEC_EXTERN_PLUGIN_PATH."
;
...
...
@@ -275,12 +276,12 @@ bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
MS_LOG
(
EXCEPTION
)
<<
"nullptr"
;
}
#ifdef ENABLE_GE
if
(
ms_context_ptr
->
is_pynative_ge_init
(
))
{
if
(
ms_context_ptr
->
get_param
<
bool
>
(
MS_CTX_IS_PYNATIVE_GE_INIT
))
{
return
true
;
}
if
(
ms_context_ptr
->
ge
_ref
(
))
{
ms_context_ptr
->
set_ge_ref
(
"++"
);
if
(
ms_context_ptr
->
ge
t_param
<
uint32_t
>
(
MS_CTX_GE_REF
))
{
ms_context_ptr
->
increase_param
<
uint32_t
>
(
MS_CTX_GE_REF
);
return
true
;
}
...
...
@@ -293,8 +294,8 @@ bool InitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
MS_LOG
(
EXCEPTION
)
<<
"Initialize GE failed!"
;
}
}
ms_context_ptr
->
set_ge_ref
(
"++"
);
MS_LOG
(
INFO
)
<<
"Init ge successful, ge reference = "
<<
ms_context_ptr
->
ge
_ref
(
)
<<
"."
;
ms_context_ptr
->
increase_param
<
uint32_t
>
(
MS_CTX_GE_REF
);
MS_LOG
(
INFO
)
<<
"Init ge successful, ge reference = "
<<
ms_context_ptr
->
ge
t_param
<
uint32_t
>
(
MS_CTX_GE_REF
)
<<
"."
;
#endif
return
true
;
}
...
...
@@ -303,12 +304,13 @@ bool PynativeInitGe(const std::shared_ptr<MsContext> &ms_context_ptr) {
if
(
ms_context_ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"nullptr"
;
}
if
(
ms_context_ptr
->
is_pynative_ge_init
()
||
ms_context_ptr
->
ge_ref
()
||
ms_context_ptr
->
tsd_ref
())
{
if
(
ms_context_ptr
->
get_param
<
bool
>
(
MS_CTX_IS_PYNATIVE_GE_INIT
)
||
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_GE_REF
)
||
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_TSD_REF
))
{
return
true
;
}
(
void
)
OpenTsd
(
ms_context_ptr
);
(
void
)
InitGe
(
ms_context_ptr
);
ms_context_ptr
->
set_p
ynative_ge_init
(
true
);
ms_context_ptr
->
set_p
aram
(
MS_CTX_IS_PYNATIVE_GE_INIT
,
true
);
return
true
;
}
...
...
@@ -317,12 +319,12 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
MS_LOG
(
EXCEPTION
)
<<
"nullptr"
;
}
#ifdef ENABLE_GE
if
(
ms_context_ptr
->
ge
_ref
(
)
==
0
)
{
if
(
ms_context_ptr
->
ge
t_param
<
uint32_t
>
(
MS_CTX_GE_REF
)
==
0
)
{
return
true
;
}
ms_context_ptr
->
set_ge_ref
(
"--"
);
if
(
force
||
ms_context_ptr
->
ge
_ref
(
)
==
0
)
{
ms_context_ptr
->
set_
ge_ref
(
" "
);
ms_context_ptr
->
decrease_param
<
uint32_t
>
(
MS_CTX_GE_REF
);
if
(
force
||
ms_context_ptr
->
ge
t_param
<
uint32_t
>
(
MS_CTX_GE_REF
)
==
0
)
{
ms_context_ptr
->
set_
param
<
uint32_t
>
(
MS_CTX_GE_REF
,
0
);
try
{
DfGraphManager
::
GetInstance
().
DeleteGraphRunner
();
DfGraphManager
::
GetInstance
().
DeleteGeSession
();
...
...
@@ -337,7 +339,8 @@ bool FinalizeGe(const std::shared_ptr<MsContext> &ms_context_ptr, bool force) {
}
ms_context_ptr
->
set_pynative_ge_init
(
false
);
}
else
{
MS_LOG
(
INFO
)
<<
"Ge is used, no need to finalize, tsd reference = "
<<
ms_context_ptr
->
ge_ref
()
<<
"."
;
MS_LOG
(
INFO
)
<<
"Ge is used, no need to finalize, tsd reference = "
<<
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_GE_REF
)
<<
"."
;
}
#endif
return
true
;
...
...
@@ -347,14 +350,14 @@ bool IsTsdOpened(const std::shared_ptr<MsContext> &ms_context_ptr) {
if
(
ms_context_ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"nullptr"
;
}
return
ms_context_ptr
->
IsTsdOpened
()
;
return
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_TSD_REF
)
>
0
;
}
bool
IsGeInited
(
const
std
::
shared_ptr
<
MsContext
>
&
ms_context_ptr
)
{
if
(
ms_context_ptr
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"nullptr"
;
}
return
ms_context_ptr
->
IsGeInited
()
;
return
ms_context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_GE_REF
)
>
0
;
}
// Register for device type.
...
...
mindspore/ccsrc/utils/convert_utils_py.cc
浏览文件 @
8d419314
...
...
@@ -353,7 +353,7 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py
// When sparse enabled, the undetermined might be raised and eliminated in opt passes
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
bool
enable_sparse
=
context
->
enable_sparse
(
);
bool
enable_sparse
=
context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_SPARSE
);
if
(
enable_sparse
)
{
return
std
::
make_shared
<
abstract
::
AbstractUndetermined
>
();
}
...
...
mindspore/ccsrc/utils/tensorprint_utils.cc
浏览文件 @
8d419314
...
...
@@ -273,7 +273,7 @@ void TensorPrint::operator()() {
prntpb
::
Print
print
;
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
std
::
string
print_file_path
=
ms_context
->
print_file_path
(
);
std
::
string
print_file_path
=
ms_context
->
get_param
<
std
::
string
>
(
MS_CTX_PRINT_FILE_PATH
);
if
(
print_file_path
==
""
)
{
while
(
true
)
{
std
::
vector
<
tdt
::
DataItem
>
bundle
;
...
...
mindspore/ccsrc/vm/backend.cc
浏览文件 @
8d419314
...
...
@@ -59,7 +59,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri
graph_id
=
target_sess_
->
CompileGraphAsync
(
lst
,
outputs
);
}
if
(
MsContext
::
GetInstance
()
->
precompile_only
(
))
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_PRECOMPILE_ONLY
))
{
MS_LOG
(
INFO
)
<<
"PrecompileOnly, stop run graph"
;
return
result
;
}
...
...
@@ -180,7 +180,7 @@ void MsBackend::CreateOtherSession(const std::string &target) {
}
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
uint32_t
device_id
=
context_ptr
->
device_id
(
);
uint32_t
device_id
=
context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
);
other_sess_
->
Init
(
device_id
);
other_sess_
->
RegisterSummaryCallBackFunc
(
callbacks
::
SummarySaveCallback
);
other_device_
=
target
;
...
...
mindspore/ccsrc/vm/transform.cc
浏览文件 @
8d419314
...
...
@@ -56,7 +56,7 @@ namespace {
bool
ContainMultiTarget
(
const
std
::
vector
<
AnfNodePtr
>
&
nodes
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
std
::
string
last_target
=
context_ptr
->
device_target
(
);
std
::
string
last_target
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
);
for
(
auto
&
node
:
nodes
)
{
if
(
node
->
isa
<
CNode
>
())
{
std
::
string
cur_target
=
GetCNodeTarget
(
node
);
...
...
@@ -348,7 +348,7 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
if
(
prim
->
name
()
==
prim
::
kPrimBpropCut
->
name
())
{
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
ms_context
->
set_
enable_pynative_hook
(
true
);
ms_context
->
set_
param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_HOOK
,
true
);
}
if
(
backend_
->
name
()
==
kMsConvert
&&
prim
->
name
()
==
prim
::
kPrimMakeTuple
->
name
())
{
...
...
@@ -412,7 +412,7 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) {
if
(
ContainMultiTarget
(
nodes
))
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
std
::
string
default_target
=
context_ptr
->
device_target
(
);
std
::
string
default_target
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
);
nodes
=
SplitSort
(
graph
,
default_target
);
return
SplitNodesWithTarget
(
nodes
,
graph
);
}
...
...
@@ -920,17 +920,17 @@ BackendPtr CreateBackend() {
}
if
(
name
==
kMsConvert
)
{
std
::
string
target
=
context_ptr
->
device_target
(
);
uint32_t
device_id
=
context_ptr
->
device_id
(
);
std
::
string
target
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
);
uint32_t
device_id
=
context_ptr
->
get_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
);
auto
backend
=
std
::
make_shared
<
MsBackend
>
(
name
,
target
,
device_id
);
std
::
string
device_target
=
MsContext
::
GetInstance
()
->
device_target
(
);
std
::
string
device_target
=
MsContext
::
GetInstance
()
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
);
if
(
device_target
==
kAscendDevice
)
{
if
(
MsContext
::
GetInstance
()
->
execution_mode
(
)
==
kPynativeMode
)
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
int
>
(
MS_CTX_EXECUTION_MODE
)
==
kPynativeMode
)
{
backend
->
set_is_multi_graph_sink
(
false
);
context_ptr
->
set_
is_multi_graph_sink
(
false
);
context_ptr
->
set_
param
<
bool
>
(
MS_CTX_IS_MULTI_GRAPH_SINK
,
false
);
}
else
{
backend
->
set_is_multi_graph_sink
(
true
);
context_ptr
->
set_
is_multi_graph_sink
(
true
);
context_ptr
->
set_
param
<
bool
>
(
MS_CTX_IS_MULTI_GRAPH_SINK
,
true
);
}
}
return
backend
;
...
...
mindspore/context.py
浏览文件 @
8d419314
...
...
@@ -22,7 +22,7 @@ import threading
from
collections
import
namedtuple
from
types
import
FunctionType
from
mindspore
import
log
as
logger
from
mindspore._c_expression
import
MSContext
from
mindspore._c_expression
import
MSContext
,
ms_ctx_param
,
ms_ctx_get_param
,
ms_ctx_set_param
from
mindspore._checkparam
import
args_type_check
from
mindspore.parallel._auto_parallel_context
import
_set_auto_parallel_context
,
_get_auto_parallel_context
,
\
_reset_auto_parallel_context
...
...
@@ -157,9 +157,15 @@ class _Context:
raise
ValueError
(
"Context handle is none in context!!!"
)
return
value
def
get_param
(
self
,
param
):
return
ms_ctx_get_param
(
self
.
_context_handle
,
param
)
def
set_param
(
self
,
param
,
value
):
ms_ctx_set_param
(
self
.
_context_handle
,
param
,
value
)
@
property
def
mode
(
self
):
return
self
.
_context_handle
.
get_execution_mode
(
)
return
self
.
get_param
(
ms_ctx_param
.
execution_mode
)
@
mode
.
setter
def
mode
(
self
,
mode
):
...
...
@@ -169,15 +175,17 @@ class _Context:
Args:
mode (int): GRAPH_MODE or PYNATIVE_MODE.
"""
self
.
_context_handle
.
set_execution_mode
(
mode
)
if
mode
==
PYNATIVE_MODE
:
if
self
.
enable_debug_runtime
:
self
.
set_backend_policy
(
"vm"
)
self
.
_context_switches
.
push
(
True
,
None
)
el
se
:
el
if
mode
==
GRAPH_MODE
:
if
self
.
enable_debug_runtime
:
self
.
set_backend_policy
(
"ge"
)
self
.
_context_switches
.
push
(
False
,
None
)
else
:
raise
ValueError
(
f
'The execution mode
{
mode
}
is invalid!'
)
self
.
set_param
(
ms_ctx_param
.
execution_mode
,
mode
)
def
set_backend_policy
(
self
,
policy
):
success
=
self
.
_context_handle
.
set_backend_policy
(
policy
)
...
...
@@ -186,110 +194,106 @@ class _Context:
@
property
def
precompile_only
(
self
):
return
self
.
_context_handle
.
get_precompile_only
(
)
return
self
.
get_param
(
ms_ctx_param
.
precompile_only
)
@
precompile_only
.
setter
def
precompile_only
(
self
,
precompile_only
):
self
.
_context_handle
.
set_precompile_only
(
precompile_only
)
self
.
set_param
(
ms_ctx_param
.
precompile_only
,
precompile_only
)
@
property
def
save_graphs
(
self
):
return
self
.
_context_handle
.
get_save_graphs_flag
(
)
return
self
.
get_param
(
ms_ctx_param
.
save_graphs_flag
)
@
save_graphs
.
setter
def
save_graphs
(
self
,
save_graphs_flag
):
self
.
_context_handle
.
set_save_graphs_flag
(
save_graphs_flag
)
self
.
set_param
(
ms_ctx_param
.
save_graphs_flag
,
save_graphs_flag
)
@
property
def
save_graphs_path
(
self
):
return
self
.
_context_handle
.
get_save_graphs_path
(
)
return
self
.
get_param
(
ms_ctx_param
.
save_graphs_path
)
@
save_graphs_path
.
setter
def
save_graphs_path
(
self
,
save_graphs_path
):
self
.
_context_handle
.
set_save_graphs_path
(
_make_directory
(
save_graphs_path
))
self
.
set_param
(
ms_ctx_param
.
save_graphs_path
,
_make_directory
(
save_graphs_path
))
@
property
def
device_target
(
self
):
return
self
.
_context_handle
.
get_device_target
(
)
return
self
.
get_param
(
ms_ctx_param
.
device_target
)
@
device_target
.
setter
def
device_target
(
self
,
target
):
success
=
self
.
_context_handle
.
set_device_target
(
target
)
if
not
success
:
raise
ValueError
(
"Target device name is invalid!!!"
)
if
self
.
enable_debug_runtime
and
self
.
device_target
==
"CPU"
:
valid_targets
=
[
"CPU"
,
"GPU"
,
"Ascend"
,
"Davinci"
]
if
not
target
in
valid_targets
:
raise
ValueError
(
f
"Target device name
{
target
}
is invalid! It must be one of
{
valid_targets
}
"
)
if
target
==
"Davinci"
:
target
=
"Ascend"
self
.
set_param
(
ms_ctx_param
.
device_target
,
target
)
if
self
.
enable_debug_runtime
and
target
==
"CPU"
:
self
.
set_backend_policy
(
"vm"
)
@
property
def
device_id
(
self
):
return
self
.
_context_handle
.
get_device_id
(
)
return
self
.
get_param
(
ms_ctx_param
.
device_id
)
@
device_id
.
setter
def
device_id
(
self
,
device_id
):
if
device_id
<
0
or
device_id
>
4095
:
raise
ValueError
(
"Device id must be in [0, 4095], but got {}"
.
format
(
device_id
))
success
=
self
.
_context_handle
.
set_device_id
(
device_id
)
if
not
success
:
raise
RuntimeError
(
"Device id set failed!!!"
)
raise
ValueError
(
f
"Device id must be in [0, 4095], but got
{
device_id
}
"
)
self
.
set_param
(
ms_ctx_param
.
device_id
,
device_id
)
@
property
def
max_call_depth
(
self
):
return
self
.
_context_handle
.
get_max_call_depth
(
)
return
self
.
get_param
(
ms_ctx_param
.
max_call_depth
)
@
max_call_depth
.
setter
def
max_call_depth
(
self
,
max_call_depth
):
if
max_call_depth
<=
0
:
raise
ValueError
(
"Max call depth must be greater than 0, but got {}"
.
format
(
max_call_depth
))
self
.
_context_handle
.
set_max_call_depth
(
max_call_depth
)
raise
ValueError
(
f
"Max call depth must be greater than 0, but got
{
max_call_depth
}
"
)
self
.
set_param
(
ms_ctx_param
.
max_call_depth
,
max_call_depth
)
@
property
def
enable_auto_mixed_precision
(
self
):
return
self
.
_context_handle
.
get_auto_mixed_precision_flag
(
)
return
self
.
get_param
(
ms_ctx_param
.
auto_mixed_precision_flag
)
@
enable_auto_mixed_precision
.
setter
def
enable_auto_mixed_precision
(
self
,
enable_auto_mixed_precision
):
self
.
_context_handle
.
set_auto_mixed_precision_flag
(
enable_auto_mixed_precision
)
self
.
set_param
(
ms_ctx_param
.
auto_mixed_precision_flag
,
enable_auto_mixed_precision
)
@
property
def
enable_reduce_precision
(
self
):
return
self
.
_context_handle
.
get_enable_reduce_precision_flag
(
)
return
self
.
get_param
(
ms_ctx_param
.
enable_reduce_precision_flag
)
@
enable_reduce_precision
.
setter
def
enable_reduce_precision
(
self
,
enable_reduce_precision
):
self
.
_context_handle
.
set_enable_reduce_precision_flag
(
enable_reduce_precision
)
self
.
set_param
(
ms_ctx_param
.
enable_reduce_precision_flag
,
enable_reduce_precision
)
@
property
def
enable_dump
(
self
):
return
self
.
_context_handle
.
get_enable_dump
(
)
return
self
.
get_param
(
ms_ctx_param
.
enable_dump
)
@
enable_dump
.
setter
def
enable_dump
(
self
,
enable_dump
):
self
.
_context_handle
.
set_enable_dump
(
enable_dump
)
self
.
set_param
(
ms_ctx_param
.
enable_dump
,
enable_dump
)
@
property
def
save_dump_path
(
self
):
return
self
.
_context_handle
.
get_save_dump_path
(
)
return
self
.
get_param
(
ms_ctx_param
.
save_dump_path
)
@
save_dump_path
.
setter
def
save_dump_path
(
self
,
save_dump_path
):
self
.
_context_handle
.
set_save_dump_path
(
save_dump_path
)
self
.
set_param
(
ms_ctx_param
.
save_dump_path
,
save_dump_path
)
@
property
def
enable_profiling
(
self
):
return
self
.
_context_handle
.
get_enable_profiling
(
)
return
self
.
get_param
(
ms_ctx_param
.
enable_profiling
)
@
enable_profiling
.
setter
def
enable_profiling
(
self
,
flag
):
self
.
_context_handle
.
set_enable_profiling
(
flag
)
self
.
set_param
(
ms_ctx_param
.
enable_profiling
,
flag
)
@
property
def
profiling_options
(
self
):
return
self
.
_context_handle
.
get_profiling_options
(
)
return
self
.
get_param
(
ms_ctx_param
.
profiling_options
)
@
profiling_options
.
setter
def
profiling_options
(
self
,
option
):
...
...
@@ -298,15 +302,15 @@ class _Context:
if
option
not
in
options
:
raise
ValueError
(
"Profiling options must be in 'training_trace' 'task_trace' "
"'task_trace:training_trace' 'training_trace:task_trace' or 'op_trace'."
)
self
.
_context_handle
.
set_profiling_options
(
option
)
self
.
set_param
(
ms_ctx_param
.
profiling_options
,
option
)
@
property
def
enable_graph_kernel
(
self
):
return
self
.
_context_handle
.
get_enable_graph_kernel
(
)
return
self
.
get_param
(
ms_ctx_param
.
enable_graph_kernel
)
@
enable_graph_kernel
.
setter
def
enable_graph_kernel
(
self
,
graph_kernel_switch_
):
self
.
_context_handle
.
set_enable_graph_kernel
(
graph_kernel_switch_
)
self
.
set_param
(
ms_ctx_param
.
enable_graph_kernel
,
graph_kernel_switch_
)
@
property
def
reserve_class_name_in_scope
(
self
):
...
...
@@ -325,20 +329,14 @@ class _Context:
@
variable_memory_max_size
.
setter
def
variable_memory_max_size
(
self
,
variable_memory_max_size
):
if
not
check_input_format
(
variable_memory_max_size
):
raise
ValueError
(
"Context param variable_memory_max_size should be in correct format! Such as
\"
5GB
\"
"
)
raise
ValueError
(
"Context param variable_memory_max_size should be in correct format! Such as
\"
5GB
\"
"
)
if
int
(
variable_memory_max_size
[:
-
2
])
>=
_DEVICE_APP_MEMORY_SIZE
:
raise
ValueError
(
"Context param variable_memory_max_size should be less than 31GB."
)
variable_memory_max_size_
=
variable_memory_max_size
[:
-
2
]
+
" * 1024 * 1024 * 1024"
graph_memory_max_size
=
_DEVICE_APP_MEMORY_SIZE
-
\
int
(
variable_memory_max_size
[:
-
2
])
graph_memory_max_size_
=
str
(
graph_memory_max_size
)
+
" * 1024 * 1024 * 1024"
self
.
_context_handle
.
set_variable_memory_max_size
(
variable_memory_max_size_
)
self
.
_context_handle
.
set_graph_memory_max_size
(
graph_memory_max_size_
)
raise
ValueError
(
"Context param variable_memory_max_size should be less than 31GB."
)
variable_memory_max_size_
=
variable_memory_max_size
[:
-
2
]
+
" * 1024 * 1024 * 1024"
graph_memory_max_size
=
_DEVICE_APP_MEMORY_SIZE
-
int
(
variable_memory_max_size
[:
-
2
])
graph_memory_max_size_
=
str
(
graph_memory_max_size
)
+
" * 1024 * 1024 * 1024"
self
.
set_param
(
ms_ctx_param
.
variable_memory_max_size
,
variable_memory_max_size_
)
self
.
set_param
(
ms_ctx_param
.
graph_memory_max_size
,
graph_memory_max_size_
)
@
property
def
enable_ge
(
self
):
...
...
@@ -355,15 +353,15 @@ class _Context:
@
property
def
check_bprop
(
self
):
return
self
.
_context_handle
.
get_check_bprop_flag
(
)
return
self
.
get_param
(
ms_ctx_param
.
check_bprop_flag
)
@
check_bprop
.
setter
def
check_bprop
(
self
,
check_bprop_flag
):
self
.
_context_handle
.
set_check_bprop_flag
(
check_bprop_flag
)
self
.
set_param
(
ms_ctx_param
.
check_bprop_flag
,
check_bprop_flag
)
@
property
def
max_device_memory
(
self
):
return
self
.
_context_handle
.
get_max_device_memory
(
)
return
self
.
get_param
(
ms_ctx_param
.
max_device_memory
)
@
max_device_memory
.
setter
def
max_device_memory
(
self
,
max_device_memory
):
...
...
@@ -372,7 +370,7 @@ class _Context:
max_device_memory_value
=
float
(
max_device_memory
[:
-
2
])
if
max_device_memory_value
==
0
:
raise
ValueError
(
"Context param max_device_memory should be in correct format! Such as
\"
3.5GB
\"
"
)
self
.
_context_handle
.
set_max_device_memory
(
max_device_memory_value
)
self
.
set_param
(
ms_ctx_param
.
max_device_memory
,
max_device_memory_value
)
@
property
def
print_file_path
(
self
):
...
...
@@ -392,15 +390,15 @@ class _Context:
full_file_name
=
os
.
path
.
join
(
path
,
file_name
)
else
:
full_file_name
=
print_file_path
self
.
_context_handle
.
set_print_file_path
(
full_file_name
)
self
.
set_param
(
ms_ctx_param
.
print_file_path
,
full_file_name
)
@
property
def
enable_sparse
(
self
):
return
self
.
_context_handle
.
get_enable_sparse
(
)
return
self
.
get_param
(
ms_ctx_param
.
enable_sparse
)
@
enable_sparse
.
setter
def
enable_sparse
(
self
,
enable_sparse
):
self
.
_context_handle
.
set_enable_sparse
(
enable_sparse
)
self
.
set_param
(
ms_ctx_param
.
enable_sparse
,
enable_sparse
)
def
check_input_format
(
x
):
import
re
...
...
@@ -486,8 +484,6 @@ def set_auto_parallel_context(**kwargs):
full_batch (bool): Whether to load the whole batch on each device. Default: False.
enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in
data parallel training in the benefit of time and memory saving.
max_call_depth(int): Specify the function call depth limit. Default: 1000.
Raises:
ValueError: If input key is not attribute in auto parallel context.
...
...
@@ -501,7 +497,6 @@ def set_auto_parallel_context(**kwargs):
>>> context.set_auto_parallel_context(parameter_broadcast=False)
>>> context.set_auto_parallel_context(strategy_ckpt_load_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(strategy_ckpt_save_file="./strategy_stage1.ckpt")
>>> context.set_auto_parallel_context(max_call_depth=80)
"""
_set_auto_parallel_context
(
**
kwargs
)
...
...
@@ -603,6 +598,7 @@ def set_context(**kwargs):
a file by default, and turn off printing to the screen. If the file already exists, add a timestamp
suffix to the file.
enable_sparse (bool): Whether to enable sparsity feature. Default: False.
max_call_depth(int): Specify the function call depth limit. Default: 1000.
Raises:
ValueError: If input key is not an attribute in context.
...
...
@@ -623,6 +619,7 @@ def set_context(**kwargs):
>>> context.set_context(enable_profiling=True, profiling_options="training_trace")
>>> context.set_context(max_device_memory="3.5GB")
>>> context.set_context(print_file_path="print.pb")
>>> context.set_context(max_call_depth=80)
"""
for
key
,
value
in
kwargs
.
items
():
if
not
hasattr
(
_context
(),
key
):
...
...
mindspore/core/abstract/prim_others.cc
浏览文件 @
8d419314
...
...
@@ -51,7 +51,7 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
bool
enable_sparse
=
context
->
enable_sparse
(
);
bool
enable_sparse
=
context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_SPARSE
);
if
(
enable_sparse
&&
dflt
->
isa
<
AbstractTensor
>
())
{
auto
dflt_tensor
=
dflt
->
cast
<
AbstractTensorPtr
>
();
return
std
::
make_shared
<
AbstractUndetermined
>
(
dflt_tensor
->
element
()
->
Clone
(),
dflt_tensor
->
shape
()
->
Clone
());
...
...
mindspore/core/ir/anf.cc
浏览文件 @
8d419314
...
...
@@ -232,7 +232,7 @@ std::string GetMaketupleNodeTarget(const CNodePtr &cnode) {
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
std
::
string
default_target
=
context_ptr
->
device_target
(
);
std
::
string
default_target
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
);
return
default_target
;
}
...
...
@@ -248,7 +248,7 @@ std::string GetTupleGetItemTarget(const CNodePtr &cnode, const PrimitivePtr &pri
std
::
string
GetCNodeTarget
(
const
AnfNodePtr
&
node
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
std
::
string
default_target
=
context_ptr
->
device_target
(
);
std
::
string
default_target
=
context_ptr
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
);
if
(
!
node
->
isa
<
CNode
>
())
{
return
default_target
;
}
...
...
mindspore/core/ir/func_graph_cloner.cc
浏览文件 @
8d419314
...
...
@@ -652,7 +652,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
new_func_graph
->
set_param_default_value
(
item
.
first
,
cloner
[
item
.
second
]);
}
if
(
MsContext
::
GetInstance
()
->
is_multi_graph_sink
(
))
{
if
(
MsContext
::
GetInstance
()
->
get_param
<
bool
>
(
MS_CTX_IS_MULTI_GRAPH_SINK
))
{
if
(
func_graph
->
has_flag
(
FUNC_GRAPH_FLAG_IGNORE_VALUES
))
{
new_func_graph
->
set_flag
(
FUNC_GRAPH_FLAG_IGNORE_VALUES
,
true
);
}
...
...
mindspore/core/ir/meta_func_graph.cc
浏览文件 @
8d419314
...
...
@@ -30,7 +30,7 @@ abstract::AbstractBasePtr MetaFuncGraph::ToAbstract() {
FuncGraphPtr
MetaFuncGraph
::
GenerateStubFunc
(
const
TypePtrList
&
types
)
{
auto
context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context
);
bool
enable_sparse
=
context
->
enable_sparse
(
);
bool
enable_sparse
=
context
->
get_param
<
bool
>
(
MS_CTX_ENABLE_SPARSE
);
if
(
!
enable_sparse
)
{
return
nullptr
;
}
...
...
mindspore/core/utils/ms_context.cc
浏览文件 @
8d419314
...
...
@@ -32,49 +32,50 @@ std::map<std::string, MsBackendPolicy> MsContext::policy_map_ = {{"ge", kMsBacke
{
"vm_prior"
,
kMsBackendVmPrior
}};
MsContext
::
MsContext
(
const
std
::
string
&
policy
,
const
std
::
string
&
target
)
{
save_graphs_flag_
=
false
;
save_graphs_path_
=
"."
;
enable_dump_
=
false
;
save_dump_path_
=
"."
;
tsd_ref_
=
0
;
ge_ref_
=
0
;
is_multi_graph_sink_
=
false
;
is_pynative_ge_init_
=
false
;
enable_reduce_precision_
=
true
;
set_param
<
bool
>
(
MS_CTX_SAVE_GRAPHS_FLAG
,
false
);
set_param
<
std
::
string
>
(
MS_CTX_SAVE_GRAPHS_PATH
,
"."
);
set_param
<
std
::
string
>
(
MS_CTX_SAVE_DUMP_PATH
,
"."
);
set_param
<
uint32_t
>
(
MS_CTX_TSD_REF
,
0
);
set_param
<
uint32_t
>
(
MS_CTX_GE_REF
,
0
);
set_param
<
bool
>
(
MS_CTX_IS_MULTI_GRAPH_SINK
,
false
);
set_param
<
bool
>
(
MS_CTX_IS_PYNATIVE_GE_INIT
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_REDUCE_PRECISION
,
true
);
auto
env_device
=
common
::
GetEnv
(
"DEVICE_ID"
);
if
(
!
env_device
.
empty
())
{
device_id_
=
UlongToUint
(
std
::
stoul
(
env_device
.
c_str
()));
uint32_t
device_id
=
UlongToUint
(
std
::
stoul
(
env_device
.
c_str
()));
set_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
,
device_id
);
}
else
{
device_id_
=
0
;
set_param
<
uint32_t
>
(
MS_CTX_DEVICE_ID
,
0
)
;
}
max_call_depth_
=
MAX_CALL_DEPTH_DEFAULT
;
backend_policy_
=
policy_map_
[
policy
];
device_target_
=
target
;
execution_mode_
=
kPynativeMode
;
enable_task_sink_
=
true
;
ir_fusion_flag_
=
true
;
enable_hccl_
=
false
;
set_param
<
uint32_t
>
(
MS_CTX_MAX_CALL_DEPTH
,
MAX_CALL_DEPTH_DEFAULT
);
set_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
,
target
);
set_param
<
int
>
(
MS_CTX_EXECUTION_MODE
,
kPynativeMode
);
set_param
<
bool
>
(
MS_CTX_ENABLE_TASK_SINK
,
true
);
set_param
<
bool
>
(
MS_CTX_IR_FUSION_FLAG
,
true
);
set_param
<
bool
>
(
MS_CTX_ENABLE_HCCL
,
false
);
#ifdef ENABLE_DEBUGGER
enable_mem_reuse_
=
false
;
set_param
<
bool
>
(
MS_CTX_ENABLE_MEM_REUSE
,
false
)
;
#else
enable_mem_reuse_
=
true
;
set_param
<
bool
>
(
MS_CTX_ENABLE_MEM_REUSE
,
true
)
;
#endif
enable_gpu_summary_
=
true
;
precompile_only_
=
false
;
auto_mixed_precision_flag_
=
false
;
enable_pynative_infer_
=
false
;
enable_pynative_hook_
=
false
;
enable_dynamic_mem_pool_
=
true
;
graph_memory_max_size_
=
"0"
;
variable_memory_max_size_
=
"0"
;
enable_loop_sink_
=
target
==
kAscendDevice
||
target
==
kDavinciDevice
;
profiling_mode_
=
false
;
profiling_options_
=
"training_trace"
;
check_bprop_flag_
=
false
;
max_device_memory_
=
kDefaultMaxDeviceMemory
;
print_file_path_
=
""
;
enable_graph_kernel_
=
false
;
enable_sparse_
=
false
;
set_param
<
bool
>
(
MS_CTX_ENABLE_GPU_SUMMARY
,
true
);
set_param
<
bool
>
(
MS_CTX_PRECOMPILE_ONLY
,
false
);
set_param
<
bool
>
(
MS_CTX_AUTO_MIXED_PRECISION_FLAG
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_INFER
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_PYNATIVE_HOOK
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
,
true
);
set_param
<
std
::
string
>
(
MS_CTX_GRAPH_MEMORY_MAX_SIZE
,
"0"
);
set_param
<
std
::
string
>
(
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
,
"0"
);
set_param
<
bool
>
(
MS_CTX_ENABLE_LOOP_SINK
,
target
==
kAscendDevice
||
target
==
kDavinciDevice
);
set_param
<
bool
>
(
MS_CTX_ENABLE_PROFILING
,
false
);
set_param
<
std
::
string
>
(
MS_CTX_PROFILING_OPTIONS
,
"training_trace"
);
set_param
<
bool
>
(
MS_CTX_CHECK_BPROP_FLAG
,
false
);
set_param
<
float
>
(
MS_CTX_MAX_DEVICE_MEMORY
,
kDefaultMaxDeviceMemory
);
set_param
<
std
::
string
>
(
MS_CTX_PRINT_FILE_PATH
,
""
);
set_param
<
bool
>
(
MS_CTX_ENABLE_GRAPH_KERNEL
,
false
);
set_param
<
bool
>
(
MS_CTX_ENABLE_SPARSE
,
false
);
backend_policy_
=
policy_map_
[
policy
];
}
std
::
shared_ptr
<
MsContext
>
MsContext
::
GetInstance
()
{
...
...
@@ -106,54 +107,4 @@ std::string MsContext::backend_policy() const {
}
return
"unknown"
;
}
void
MsContext
::
set_execution_mode
(
int
execution_mode
)
{
if
(
execution_mode
!=
kGraphMode
&&
execution_mode
!=
kPynativeMode
)
{
MS_LOG
(
EXCEPTION
)
<<
"The execution mode is invalid!"
;
}
execution_mode_
=
execution_mode
;
}
bool
MsContext
::
set_device_target
(
const
std
::
string
&
target
)
{
if
(
kTargetSet
.
find
(
target
)
==
kTargetSet
.
end
())
{
MS_LOG
(
ERROR
)
<<
"invalid device target name: "
<<
target
;
return
false
;
}
if
(
target
==
kDavinciDevice
)
{
device_target_
=
kAscendDevice
;
}
else
{
device_target_
=
target
;
}
if
(
seter_
)
{
seter_
(
device_target_
);
}
MS_LOG
(
INFO
)
<<
"ms set context device target:"
<<
target
;
return
true
;
}
bool
MsContext
::
set_device_id
(
uint32_t
device_id
)
{
device_id_
=
device_id
;
MS_LOG
(
INFO
)
<<
"ms set context device id:"
<<
device_id
;
return
true
;
}
void
MsContext
::
set_tsd_ref
(
const
std
::
string
&
op
)
{
if
(
op
==
"--"
)
{
tsd_ref_
--
;
}
else
if
(
op
==
"++"
)
{
tsd_ref_
++
;
}
else
{
tsd_ref_
=
0
;
}
}
void
MsContext
::
set_ge_ref
(
const
std
::
string
&
op
)
{
if
(
op
==
"--"
)
{
ge_ref_
--
;
}
else
if
(
op
==
"++"
)
{
ge_ref_
++
;
}
else
{
ge_ref_
=
0
;
}
}
}
// namespace mindspore
mindspore/core/utils/ms_context.h
浏览文件 @
8d419314
...
...
@@ -49,6 +49,69 @@ const std::set<std::string> kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice,
// The default max available device memory is 1024GB.
const
float
kDefaultMaxDeviceMemory
=
1024
;
// enum definition for MindSpore Context Parameter
enum
MsCtxParam
:
unsigned
{
// paramater of type bool
MS_CTX_TYPE_BOOL_BEGIN
,
MS_CTX_AUTO_MIXED_PRECISION_FLAG
=
MS_CTX_TYPE_BOOL_BEGIN
,
MS_CTX_CHECK_BPROP_FLAG
,
MS_CTX_ENABLE_DUMP
,
MS_CTX_ENABLE_DYNAMIC_MEM_POOL
,
MS_CTX_ENABLE_GPU_SUMMARY
,
MS_CTX_ENABLE_GRAPH_KERNEL
,
MS_CTX_ENABLE_HCCL
,
MS_CTX_ENABLE_LOOP_SINK
,
MS_CTX_ENABLE_MEM_REUSE
,
MS_CTX_ENABLE_PYNATIVE_HOOK
,
MS_CTX_ENABLE_PYNATIVE_INFER
,
MS_CTX_ENABLE_REDUCE_PRECISION
,
MS_CTX_ENABLE_SPARSE
,
MS_CTX_ENABLE_TASK_SINK
,
MS_CTX_IR_FUSION_FLAG
,
MS_CTX_IS_MULTI_GRAPH_SINK
,
MS_CTX_IS_PYNATIVE_GE_INIT
,
MS_CTX_PRECOMPILE_ONLY
,
MS_CTX_ENABLE_PROFILING
,
MS_CTX_SAVE_GRAPHS_FLAG
,
MS_CTX_TYPE_BOOL_END
,
// paramater of type int
MS_CTX_TYPE_INT_BEGIN
=
MS_CTX_TYPE_BOOL_END
,
MS_CTX_EXECUTION_MODE
=
MS_CTX_TYPE_INT_BEGIN
,
MS_CTX_TYPE_INT_END
,
// paramater of type uint32
MS_CTX_TYPE_UINT32_BEGIN
=
MS_CTX_TYPE_INT_END
,
MS_CTX_DEVICE_ID
=
MS_CTX_TYPE_UINT32_BEGIN
,
MS_CTX_GE_REF
,
MS_CTX_MAX_CALL_DEPTH
,
MS_CTX_TSD_REF
,
MS_CTX_TYPE_UINT32_END
,
// paramater of type float
MS_CTX_TYPE_FLOAT_BEGIN
=
MS_CTX_TYPE_UINT32_END
,
MS_CTX_MAX_DEVICE_MEMORY
=
MS_CTX_TYPE_FLOAT_BEGIN
,
MS_CTX_TYPE_FLOAT_END
,
// paramater of type string
MS_CTX_TYPE_STRING_BEGIN
=
MS_CTX_TYPE_FLOAT_END
,
MS_CTX_DEVICE_TARGET
=
MS_CTX_TYPE_STRING_BEGIN
,
MS_CTX_GRAPH_MEMORY_MAX_SIZE
,
MS_CTX_PRINT_FILE_PATH
,
MS_CTX_PROFILING_OPTIONS
,
MS_CTX_SAVE_DUMP_PATH
,
MS_CTX_SAVE_GRAPHS_PATH
,
MS_CTX_VARIABLE_MEMORY_MAX_SIZE
,
MS_CTX_TYPE_STRING_END
,
// parameter numbers of each type
NUM_BOOL_PARAMS
=
MS_CTX_TYPE_BOOL_END
-
MS_CTX_TYPE_BOOL_BEGIN
,
NUM_INT_PARAMS
=
MS_CTX_TYPE_INT_END
-
MS_CTX_TYPE_INT_BEGIN
,
NUM_UINT32_PARAMS
=
MS_CTX_TYPE_UINT32_END
-
MS_CTX_TYPE_UINT32_BEGIN
,
NUM_FLOAT_PARAMS
=
MS_CTX_TYPE_FLOAT_END
-
MS_CTX_TYPE_FLOAT_BEGIN
,
NUM_STRING_PARAMS
=
MS_CTX_TYPE_STRING_END
-
MS_CTX_TYPE_STRING_BEGIN
};
class
MsContext
{
public:
MsContext
(
const
std
::
string
&
backend_policy
,
const
std
::
string
&
target
);
...
...
@@ -62,156 +125,113 @@ class MsContext {
std
::
string
backend_policy
()
const
;
bool
set_backend_policy
(
const
std
::
string
&
policy
);
int
execution_mode
()
const
{
return
execution_mode_
;
}
void
set_execution_mode
(
int
execution_mode
);
bool
enable_pynative_infer
()
const
{
return
enable_pynative_infer_
;
}
void
set_enable_pynative_infer
(
bool
enable_pynative_infer
)
{
enable_pynative_infer_
=
enable_pynative_infer
;
}
bool
enable_pynative_hook
()
const
{
return
enable_pynative_hook_
;
}
void
set_enable_pynative_hook
(
bool
enable_pynative_hook
)
{
enable_pynative_hook_
=
enable_pynative_hook
;
}
bool
enable_task_sink
()
const
{
return
enable_task_sink_
;
}
void
set_precompile_only
(
bool
precompile_only
)
{
precompile_only_
=
precompile_only
;
}
bool
precompile_only
()
const
{
return
precompile_only_
;
}
std
::
string
device_target
()
const
{
return
device_target_
;
}
bool
set_device_target
(
const
std
::
string
&
target
);
static
void
device_seter
(
DeviceSeter
device
)
{
seter_
=
device
;
}
static
void
device_type_seter
(
DeviceTypeSeter
device_type
)
{
device_type_seter_
=
device_type
;
}
uint32_t
device_id
()
const
{
return
device_id_
;
}
bool
set_device_id
(
uint32_t
device_id
);
std
::
thread
tdt_print_
;
// uint32_t max_call_depth_
uint32_t
max_call_depth
()
const
{
return
max_call_depth_
;
}
inline
bool
set_max_call_depth
(
uint32_t
max_call_depth
)
{
max_call_depth_
=
max_call_depth
;
return
true
;
template
<
typename
T
>
void
set_param
(
MsCtxParam
param
,
const
T
&
value
)
{
MS_LOG
(
EXCEPTION
)
<<
"Need implemet "
<<
__FUNCTION__
<<
" for type "
<<
typeid
(
T
).
name
()
<<
"."
;
}
bool
save_graphs_flag
()
const
{
return
save_graphs_flag_
;
}
void
set_save_graphs_flag
(
bool
save_graphs_flag
)
{
save_graphs_flag_
=
save_graphs_flag
;
}
std
::
string
save_graphs_path
()
const
{
return
save_graphs_path_
;
}
void
set_save_graphs_path
(
const
std
::
string
&
save_paths
)
{
save_graphs_path_
=
save_paths
;
}
bool
IsGeInited
()
{
return
ge_ref_
>
0
;
}
void
set_enable_hccl
(
bool
enable_hccl
)
{
enable_hccl_
=
enable_hccl
;
}
bool
enable_hccl
()
const
{
return
enable_hccl_
;
}
bool
ir_fusion_flag
()
const
{
return
ir_fusion_flag_
;
}
bool
loop_sink_flag
()
const
{
return
enable_loop_sink_
;
}
void
set_loop_sink_flag
(
bool
enable_loop_sink
)
{
enable_loop_sink_
=
enable_loop_sink
;
}
void
set_enable_mem_reuse
(
bool
enable_mem_reuse
)
{
enable_mem_reuse_
=
enable_mem_reuse
;
}
bool
enable_mem_reuse
()
const
{
return
enable_mem_reuse_
;
}
void
set_enable_gpu_summary
(
bool
enable_gpu_summary
)
{
enable_gpu_summary_
=
enable_gpu_summary
;
}
bool
enable_gpu_summary
()
const
{
return
enable_gpu_summary_
;
}
void
set_auto_mixed_precision_flag
(
bool
auto_mixed_precision_flag
)
{
auto_mixed_precision_flag_
=
auto_mixed_precision_flag
;
// Generic fallback for get_param: only the explicit specializations below are
// supported; any other type aborts via MS_LOG(EXCEPTION) (which throws, so the
// missing return statement is unreachable).
// Fixes the "Need implemet" typo in the exception message.
template <typename T>
const T &get_param(MsCtxParam param) const {
  MS_LOG(EXCEPTION) << "Need implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
// --- precision / dump switches -------------------------------------------
bool auto_mixed_precision_flag() const { return auto_mixed_precision_flag_; }
void set_enable_reduce_precision(bool flag) { enable_reduce_precision_ = flag; }
bool enable_reduce_precision() const { return enable_reduce_precision_; }
void set_enable_dump(bool flag) { enable_dump_ = flag; }
bool enable_dump() const { return enable_dump_; }
void set_save_dump_path(const std::string &path) { save_dump_path_ = path; }
std::string save_dump_path() const { return save_dump_path_; }

// --- TSD / GE reference counting -----------------------------------------
// TSD is considered open while its reference count is positive.
bool IsTsdOpened() const { return tsd_ref_ > 0; }
// Adjusts tsd_ref_ according to `op`; defined out of line.
void set_tsd_ref(const std::string &op);
uint32_t tsd_ref() const { return tsd_ref_; }
// Adjusts ge_ref_ according to `op`; defined out of line.
void set_ge_ref(const std::string &op);
uint32_t ge_ref() const { return ge_ref_; }

// --- pynative / multi-graph sink flags -----------------------------------
bool is_pynative_ge_init() { return is_pynative_ge_init_; }
void set_pynative_ge_init(bool flag) { is_pynative_ge_init_ = flag; }
bool is_multi_graph_sink() const { return is_multi_graph_sink_; }
void set_is_multi_graph_sink(bool flag) { is_multi_graph_sink_ = flag; }

// --- dynamic memory pool --------------------------------------------------
void set_enable_dynamic_mem_pool(bool enable_dynamic_mem_pool) { enable_dynamic_mem_pool_ = enable_dynamic_mem_pool; }
bool enable_dynamic_mem_pool() const { return enable_dynamic_mem_pool_; }
// Stores the configured maximum graph memory size (kept as a string, as
// supplied by the user-facing context API).
// NOTE(review): the closing brace of this function was missing (merge
// artifact); restored here.
void set_graph_memory_max_size(const std::string &graph_memory_max_size) {
  graph_memory_max_size_ = graph_memory_max_size;
}
// Generic fallback for increase_param: only the uint32_t specialization below
// is supported; any other type aborts via MS_LOG(EXCEPTION).
// Fixes the "Need implemet" typo in the exception message.
template <typename T>
void increase_param(MsCtxParam param) {
  MS_LOG(EXCEPTION) << "Need implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
// Stores the configured maximum variable memory size (kept as a string, as
// supplied by the user-facing context API).
// NOTE(review): the closing brace of this function was missing (merge
// artifact); restored here.
void set_variable_memory_max_size(const std::string &variable_memory_max_size) {
  variable_memory_max_size_ = variable_memory_max_size;
}
// Generic fallback for decrease_param: only the uint32_t specialization below
// is supported; any other type aborts via MS_LOG(EXCEPTION).
// Fixes the "Need implemet" typo in the exception message.
template <typename T>
void decrease_param(MsCtxParam param) {
  MS_LOG(EXCEPTION) << "Need implement " << __FUNCTION__ << " for type " << typeid(T).name() << ".";
}
// --- memory-size getters (string form, as configured by the user) ---------
const std::string &variable_memory_max_size() const { return variable_memory_max_size_; }
const std::string &graph_memory_max_size() const { return graph_memory_max_size_; }

// --- profiling -------------------------------------------------------------
void set_enable_profiling(bool flag) { profiling_mode_ = flag; }
bool enable_profiling() const { return profiling_mode_; }
void set_profiling_options(const std::string &options) { profiling_options_ = options; }
std::string profiling_options() const { return profiling_options_; }

// --- bprop check / print redirection ---------------------------------------
bool check_bprop_flag() const { return check_bprop_flag_; }
void set_check_bprop_flag(bool check_bprop_flag) { check_bprop_flag_ = check_bprop_flag; }
void set_print_file_path(const std::string &file) { print_file_path_ = file; }
const std::string &print_file_path() const { return print_file_path_; }

// --- device-memory cap / graph kernel / sparse -----------------------------
float max_device_memory() const { return max_device_memory_; }
void set_max_device_memory(float max_device_memory) { max_device_memory_ = max_device_memory; }
void set_enable_graph_kernel(bool enable_graph_kernel) { enable_graph_kernel_ = enable_graph_kernel; }
bool enable_graph_kernel() const { return enable_graph_kernel_; }
bool enable_sparse() const { return enable_sparse_; }
void set_enable_sparse(bool enable_sparse) { enable_sparse_ = enable_sparse; }
// NOTE(review): duplicate definitions of device_seter(), device_type_seter()
// and the tdt_print_ member were removed here (merge artifact); the identical
// originals are declared earlier in this class. Redefining a member function
// or data member inside the same class is ill-formed.
 private:
  // Hook invoked when the device target changes (see set_param<std::string>).
  inline static DeviceSeter seter_ = nullptr;
  // Device-type hook; consumer lives in ms_context.cc.
  inline static DeviceTypeSeter device_type_seter_ = nullptr;
  // Process-wide singleton instance.
  static std::shared_ptr<MsContext> inst_context_;
  // Maps backend-policy names to their enum values.
  static std::map<std::string, MsBackendPolicy> policy_map_;
  // New parameter storage: one flat array per value type, indexed by
  // MsCtxParam offset from the per-type BEGIN marker (see the get_param /
  // set_param specializations below).
  bool bool_params_[MsCtxParam::NUM_BOOL_PARAMS];
  int int_params_[MsCtxParam::NUM_INT_PARAMS];
  uint32_t uint32_params_[MsCtxParam::NUM_UINT32_PARAMS];
  float float_params_[MsCtxParam::NUM_FLOAT_PARAMS];
  std::string string_params_[MsCtxParam::NUM_STRING_PARAMS];
  // Legacy individual fields (mid-refactor; the named accessors above still
  // read/write these rather than the arrays).
  MsBackendPolicy backend_policy_;
  std::string device_target_;
  uint32_t device_id_;
  uint32_t max_call_depth_;
  int execution_mode_;
  bool enable_pynative_infer_;
  bool enable_pynative_hook_;
  bool save_graphs_flag_;
  std::string save_graphs_path_;
  // Reference counts for the TSD and GE runtimes.
  uint32_t tsd_ref_;
  uint32_t ge_ref_;
  bool enable_task_sink_;
  bool enable_hccl_;
  bool precompile_only_;
  bool ir_fusion_flag_;
  bool auto_mixed_precision_flag_;
  bool enable_reduce_precision_;
  bool enable_loop_sink_;
  bool enable_mem_reuse_;
  bool enable_gpu_summary_;
  bool enable_dump_;
  std::string save_dump_path_;
  bool is_multi_graph_sink_;
  bool is_pynative_ge_init_;
  bool enable_dynamic_mem_pool_;
  // Memory limits kept in the string form supplied by the user API.
  std::string graph_memory_max_size_;
  std::string variable_memory_max_size_;
  bool profiling_mode_;
  std::string profiling_options_;
  bool check_bprop_flag_;
  float max_device_memory_;
  std::string print_file_path_;
  bool enable_graph_kernel_;
  bool enable_sparse_;
};
// set method implementation for type bool/int/uint32_t/float/std::string.
// Each MsCtxParam enum value is offset by its type-range BEGIN marker to
// index the corresponding per-type storage array.
template <>
inline void MsContext::set_param<bool>(MsCtxParam param, const bool &value) {
  bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN] = value;
}

template <>
inline void MsContext::set_param<int>(MsCtxParam param, const int &value) {
  int_params_[param - MS_CTX_TYPE_INT_BEGIN] = value;
}

template <>
inline void MsContext::set_param<uint32_t>(MsCtxParam param, const uint32_t &value) {
  uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN] = value;
}

template <>
inline void MsContext::set_param<float>(MsCtxParam param, const float &value) {
  float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN] = value;
}

// The string specialization additionally notifies the registered device
// setter hook when the device target changes, before storing the new value.
template <>
inline void MsContext::set_param<std::string>(MsCtxParam param, const std::string &value) {
  if (seter_ != nullptr && param == MS_CTX_DEVICE_TARGET) {
    MS_LOG(INFO) << "ms set context device target:" << value;
    seter_(value);
  }
  string_params_[param - MS_CTX_TYPE_STRING_BEGIN] = value;
}
// get method implementation for type bool/int/uint32_t/float/std::string.
// Mirrors the set_param specializations: the enum value minus the per-type
// BEGIN marker indexes the matching storage array.
template <>
inline const bool &MsContext::get_param<bool>(MsCtxParam param) const {
  return bool_params_[param - MS_CTX_TYPE_BOOL_BEGIN];
}

template <>
inline const int &MsContext::get_param<int>(MsCtxParam param) const {
  return int_params_[param - MS_CTX_TYPE_INT_BEGIN];
}

template <>
inline const uint32_t &MsContext::get_param<uint32_t>(MsCtxParam param) const {
  return uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN];
}

template <>
inline const float &MsContext::get_param<float>(MsCtxParam param) const {
  return float_params_[param - MS_CTX_TYPE_FLOAT_BEGIN];
}

template <>
inline const std::string &MsContext::get_param<std::string>(MsCtxParam param) const {
  return string_params_[param - MS_CTX_TYPE_STRING_BEGIN];
}
// increase method implementation for type uint32_t (used for reference
// counters such as tsd_ref_/ge_ref_ stored in the uint32 array).
template <>
inline void MsContext::increase_param<uint32_t>(MsCtxParam param) {
  uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]++;
}

// decrease method implementation for type uint32_t.
// NOTE(review): no underflow guard — decrementing a zero counter wraps around;
// callers are presumably expected to balance increase/decrease calls.
template <>
inline void MsContext::decrease_param<uint32_t>(MsCtxParam param) {
  uint32_params_[param - MS_CTX_TYPE_UINT32_BEGIN]--;
}
}
// namespace mindspore
#endif // MINDSPORE_CORE_UTILS_MS_CONTEXT_H_
tests/ut/cpp/optimizer/lib_test.cc
浏览文件 @
8d419314
...
...
@@ -42,7 +42,7 @@ class TestOptLib : public UT::Common {
parse
::
data_converter
::
ClearObjectCache
();
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
ms_context
->
set_
execution_mode
(
kGraphMode
);
ms_context
->
set_
param
<
int
>
(
MS_CTX_EXECUTION_MODE
,
kGraphMode
);
}
FuncGraphPtr
RunTransform
(
FuncGraphPtr
gbefore
,
const
SubstitutionList
&
transform
)
{
equiv_node
.
clear
();
...
...
tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc
浏览文件 @
8d419314
...
...
@@ -112,7 +112,7 @@ TEST_F(TestHWInsertTransOp, test_insert_trans_op_for_single_output) {
*/
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
ms_context
->
set_
execution_mode
(
kGraphMode
);
ms_context
->
set_
param
<
int
>
(
MS_CTX_EXECUTION_MODE
,
kGraphMode
);
auto
fg
=
GetSingleOutputGraph
(
"test_insert_trans_op_for_single_output"
,
"before"
,
"NC1HWC0"
);
// Do insert_trans_op_ pass of hardware opt
auto
graph_optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
...
...
tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc
浏览文件 @
8d419314
...
...
@@ -112,7 +112,7 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
TEST_F
(
TestHWRemoveInternalOutput
,
test_remove_internal_output_trans_op_for_single_output
)
{
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
ms_context
->
set_
execution_mode
(
kGraphMode
);
ms_context
->
set_
param
<
int
>
(
MS_CTX_EXECUTION_MODE
,
kGraphMode
);
auto
kg
=
GetSingleOutputGraph
(
"test_remove_internal_output_trans_op_for_single_output"
,
"before"
);
// insert trans op for output
auto
graph_optimizer
=
std
::
make_shared
<
opt
::
GraphOptimizer
>
();
...
...
tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc
浏览文件 @
8d419314
...
...
@@ -104,7 +104,7 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
*/
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
ms_context
->
set_
execution_mode
(
kGraphMode
);
ms_context
->
set_
param
<
int
>
(
MS_CTX_EXECUTION_MODE
,
kGraphMode
);
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_transdata_split_fraz_nchw"
,
"before"
);
std
::
vector
<
int
>
shp
{
2
,
4
,
8
,
16
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp
);
...
...
tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
浏览文件 @
8d419314
...
...
@@ -83,7 +83,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
*/
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
ms_context
->
set_
execution_mode
(
kGraphMode
);
ms_context
->
set_
param
<
int
>
(
MS_CTX_EXECUTION_MODE
,
kGraphMode
);
FuncGraphPtr
g
=
get_py_fun_
.
CallAndParseRet
(
"test_transpose_transdata_fusion"
,
"before"
);
std
::
vector
<
int
>
shp
{
2
,
4
,
8
,
16
};
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shp
);
...
...
tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc
浏览文件 @
8d419314
...
...
@@ -76,7 +76,7 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
*/
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
ms_context
->
set_
execution_mode
(
kGraphMode
);
ms_context
->
set_
param
<
int
>
(
MS_CTX_EXECUTION_MODE
,
kGraphMode
);
FuncGraphPtr
g
=
getPyFun_
.
CallAndParseRet
(
"test_eliminate_5to4_4to5"
,
"before"
);
// Renormalize func_graph to infer and set shape and type information.
std
::
vector
<
int
>
shp
{
2
,
32
,
224
,
224
};
...
...
tests/ut/cpp/pynative/pynative_execute_test.cc
浏览文件 @
8d419314
...
...
@@ -71,15 +71,15 @@ OpExecInfoPtr ConstructOpExecInfo() {
TEST_F
(
TestPynativeExecute
,
TestCreateContext
)
{
auto
ctx3
=
MsContext
::
GetInstance
();
ASSERT_EQ
(
ctx3
->
backend_policy
(),
"vm"
);
ASSERT_EQ
(
ctx3
->
device_target
(
),
"CPU"
);
ASSERT_EQ
(
ctx3
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
),
"CPU"
);
ctx3
->
set_backend_policy
(
"ge_only"
);
ctx3
->
set_
device_target
(
"GPU"
);
ctx3
->
set_
param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
,
"GPU"
);
auto
ctx4
=
MsContext
::
GetInstance
();
ASSERT_EQ
(
ctx3
.
get
(),
ctx4
.
get
());
ASSERT_EQ
(
ctx4
->
backend_policy
(),
"ge_only"
);
ASSERT_EQ
(
ctx4
->
device_target
(
),
"GPU"
);
ASSERT_EQ
(
ctx4
->
get_param
<
std
::
string
>
(
MS_CTX_DEVICE_TARGET
),
"GPU"
);
}
TEST_F
(
TestPynativeExecute
,
TestDefaultContext
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录