Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
7951b318
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7951b318
编写于
6月 30, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
6月 30, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2697 format device ascend code
Merge pull request !2697 from kisnwang/format-device-ascend-code
上级
2291213b
89c9f46a
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
122 addition
and
102 deletion
+122
-102
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
+34
-34
mindspore/ccsrc/device/ascend/ascend_stream_assign.cc
mindspore/ccsrc/device/ascend/ascend_stream_assign.cc
+43
-43
mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc
mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc
+33
-23
mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h
mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h
+4
-0
mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.cc
mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.cc
+7
-1
mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.h
mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.h
+1
-1
未找到文件。
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
浏览文件 @
7951b318
...
...
@@ -68,9 +68,9 @@ std::string GetRankId() {
int
rank_offset
=
std
::
stoi
(
offset
);
rank_id
+=
rank_offset
;
}
catch
(
std
::
invalid_argument
)
{
MS_LOG
(
EXCEPTION
)
<<
"stoi invalid argument:"
<<
offset
;
MS_LOG
(
EXCEPTION
)
<<
"
Call
stoi invalid argument:"
<<
offset
;
}
catch
(
std
::
out_of_range
)
{
MS_LOG
(
EXCEPTION
)
<<
"stoi out_of_range:"
<<
offset
;
MS_LOG
(
EXCEPTION
)
<<
"
Call
stoi out_of_range:"
<<
offset
;
}
}
rank_id_str
=
std
::
to_string
(
rank_id
);
...
...
@@ -81,7 +81,7 @@ std::string GetRankId() {
rank_id_str
=
std
::
getenv
(
"RANK_ID"
);
#endif
if
(
rank_id_str
.
empty
())
{
MS_LOG
(
ERROR
)
<<
"
g
et hccl rankid failed, please set env RANK_ID"
;
MS_LOG
(
ERROR
)
<<
"
G
et hccl rankid failed, please set env RANK_ID"
;
}
return
rank_id_str
;
}
...
...
@@ -100,7 +100,7 @@ void AscendKernelRuntime::ClearGraphModelMap() {
}
void
AscendKernelRuntime
::
ClearGraphRuntimeResource
(
uint32_t
graph_id
)
{
MS_LOG
(
DEBUG
)
<<
"
c
lear graph:"
<<
graph_id
<<
" runtime resource"
;
MS_LOG
(
DEBUG
)
<<
"
C
lear graph:"
<<
graph_id
<<
" runtime resource"
;
auto
iter
=
graph_model_map_
.
find
(
graph_id
);
if
(
iter
==
graph_model_map_
.
end
())
{
MS_LOG
(
DEBUG
)
<<
"GraphId:"
<<
graph_id
<<
" not found"
;
...
...
@@ -118,7 +118,7 @@ bool AscendKernelRuntime::NeedDestroyHccl() {
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
!
context_ptr
->
enable_hccl
())
{
MS_LOG
(
INFO
)
<<
"
h
ccl is not enabled"
;
MS_LOG
(
INFO
)
<<
"
H
ccl is not enabled"
;
return
false
;
}
// Note: make sure hcom_connectivity_detection api never be used.
...
...
@@ -126,7 +126,7 @@ bool AscendKernelRuntime::NeedDestroyHccl() {
}
void
AscendKernelRuntime
::
ReleaseDeviceRes
()
{
MS_LOG
(
INFO
)
<<
"
a
scend finalize start"
;
MS_LOG
(
INFO
)
<<
"
A
scend finalize start"
;
// release ge runtime
ClearGraphModelMap
();
...
...
@@ -134,7 +134,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
MS_EXCEPTION_IF_NULL
(
context_ptr
);
auto
ret
=
rtSetDevice
(
context_ptr
->
device_id
());
if
(
ret
!=
RT_ERROR_NONE
)
{
MS_EXCEPTION
(
DeviceProcessError
)
<<
"rtSetDevice, ret["
<<
static_cast
<
int
>
(
ret
)
<<
"]"
;
MS_EXCEPTION
(
DeviceProcessError
)
<<
"
Call
rtSetDevice, ret["
<<
static_cast
<
int
>
(
ret
)
<<
"]"
;
}
if
(
mem_manager_
!=
nullptr
)
{
...
...
@@ -144,7 +144,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
(
void
)
DestroyHccl
();
(
void
)
ResetDevice
();
(
void
)
ProfilingManager
::
GetInstance
().
StopProfiling
();
MS_LOG
(
INFO
)
<<
"
a
scend finalize end"
;
MS_LOG
(
INFO
)
<<
"
A
scend finalize end"
;
}
bool
AscendKernelRuntime
::
Init
()
{
...
...
@@ -155,7 +155,7 @@ bool AscendKernelRuntime::Init() {
#ifdef ENABLE_DUMP_E2E
ret
=
SetDumpConf
();
if
(
!
ret
)
{
MS_LOG
(
INFO
)
<<
"
n
o dump conf to set!"
;
MS_LOG
(
INFO
)
<<
"
N
o dump conf to set!"
;
}
#endif
...
...
@@ -263,13 +263,13 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p
bool
AscendKernelRuntime
::
DumpData
(
mindspore
::
session
::
KernelGraph
*
graph
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
#ifdef ENABLE_DUMP_E2E
MS_LOG
(
INFO
)
<<
"
s
tart dump step"
;
MS_LOG
(
INFO
)
<<
"
S
tart dump step"
;
DumpConfPtr
dump_conf
=
GetDumpConf
();
MS_EXCEPTION_IF_NULL
(
dump_conf
);
dump_conf
->
UpdataCurIter
();
bool
dump_flag
=
dump_conf
->
dump_enable
();
if
(
!
dump_flag
)
{
MS_LOG
(
INFO
)
<<
"
d
ump flag is disable, pass dump step"
;
MS_LOG
(
INFO
)
<<
"
D
ump flag is disable, pass dump step"
;
return
true
;
}
uint32_t
cur_iter
=
dump_conf
->
cur_iter
();
...
...
@@ -278,7 +278,7 @@ bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
return
true
;
}
}
MS_LOG
(
INFO
)
<<
"
c
ur iter is "
<<
cur_iter
;
MS_LOG
(
INFO
)
<<
"
C
ur iter is "
<<
cur_iter
;
std
::
string
net_name
=
dump_conf
->
dump_net_name
();
std
::
string
iterator
=
to_string
(
cur_iter
);
std
::
string
dump_path
=
dump_conf
->
dump_path
();
...
...
@@ -369,9 +369,9 @@ void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger)
bool
AscendKernelRuntime
::
LoadData
(
mindspore
::
session
::
KernelGraph
*
graph
,
Debugger
*
debugger
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
#ifdef ENABLE_DEBUGGER
MS_LOG
(
INFO
)
<<
"
s
tart load step"
;
MS_LOG
(
INFO
)
<<
"
S
tart load step"
;
uint32_t
cur_iter
=
0
;
MS_LOG
(
INFO
)
<<
"
c
ur iter is "
<<
cur_iter
;
MS_LOG
(
INFO
)
<<
"
C
ur iter is "
<<
cur_iter
;
// load output
LoadOutput
(
graph
,
debugger
);
// load parameters
...
...
@@ -421,7 +421,7 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
}
// Graph may have no compute node, such TensorAddGrad.
if
(
task_info_list
.
empty
())
{
MS_LOG
(
WARNING
)
<<
"
g
raph "
<<
graph
->
graph_id
()
<<
" have no compute node"
;
MS_LOG
(
WARNING
)
<<
"
G
raph "
<<
graph
->
graph_id
()
<<
" have no compute node"
;
return
true
;
}
AscendStreamAssign
&
assign_instance
=
AscendStreamAssign
::
GetInstance
();
...
...
@@ -432,7 +432,7 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
assign_instance
.
GetWaitStreams
(
&
wait_active_stream_list
);
std
::
vector
<
uint32_t
>
force_copy_stream_list
;
assign_instance
.
GetHcomStreams
(
&
force_copy_stream_list
);
MS_LOG
(
INFO
)
<<
"
c
all DavinciModel total stream num:"
<<
resource_manager
.
get_cur_stream_num
()
MS_LOG
(
INFO
)
<<
"
C
all DavinciModel total stream num:"
<<
resource_manager
.
get_cur_stream_num
()
<<
", total event num:"
<<
resource_manager
.
get_cur_event_num
()
<<
", total label num:"
<<
label_assign_instance
.
GetLabelNum
(
NOT_NULL
(
graph
))
<<
", wait_active_stream_list size:"
<<
wait_active_stream_list
.
size
()
...
...
@@ -524,7 +524,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
bool
status
=
ge
::
model_runner
::
ModelRunner
::
Instance
().
RunModel
(
graph
->
graph_id
(),
input_tensors
,
output_tensors
);
if
(
!
status
)
{
MS_LOG
(
ERROR
)
<<
"
r
un task failed"
;
MS_LOG
(
ERROR
)
<<
"
R
un task failed"
;
DebugTaskIdName
(
graph
->
graph_id
());
return
false
;
}
...
...
@@ -543,18 +543,18 @@ bool AscendKernelRuntime::InitDevice() {
int
device_count
=
0
;
auto
ret
=
rtGetDeviceCount
(
&
device_count
);
if
(
ret
!=
RT_ERROR_NONE
)
{
MS_EXCEPTION
(
DeviceProcessError
)
<<
"rtGetDeviceCount, ret["
<<
static_cast
<
int
>
(
ret
)
<<
"]"
;
MS_EXCEPTION
(
DeviceProcessError
)
<<
"
Call
rtGetDeviceCount, ret["
<<
static_cast
<
int
>
(
ret
)
<<
"]"
;
}
ret
=
rtSetDevice
(
device_id_
);
if
(
ret
!=
RT_ERROR_NONE
)
{
MS_EXCEPTION
(
DeviceProcessError
)
<<
"rtSetDevice, ret["
<<
static_cast
<
int
>
(
ret
)
<<
"]"
;
MS_EXCEPTION
(
DeviceProcessError
)
<<
"
Call
rtSetDevice, ret["
<<
static_cast
<
int
>
(
ret
)
<<
"]"
;
}
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
context_ptr
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"
g
et MsContext instance failed"
;
MS_LOG
(
ERROR
)
<<
"
G
et MsContext instance failed"
;
return
false
;
}
if
(
context_ptr
->
enable_hccl
())
{
...
...
@@ -566,17 +566,17 @@ bool AscendKernelRuntime::InitDevice() {
ret
=
rtCtxCreate
(
&
rt_context_
,
0
,
device_id_
);
if
(
ret
!=
RT_ERROR_NONE
)
{
MS_EXCEPTION
(
DeviceProcessError
)
<<
"rtCtxCreate, ret["
<<
static_cast
<
int
>
(
ret
)
<<
"]"
;
MS_EXCEPTION
(
DeviceProcessError
)
<<
"
Call
rtCtxCreate, ret["
<<
static_cast
<
int
>
(
ret
)
<<
"]"
;
}
ret
=
rtCtxSetCurrent
(
rt_context_
);
if
(
ret
!=
RT_ERROR_NONE
)
{
MS_EXCEPTION
(
DeviceProcessError
)
<<
"rtCtxSetCurrent, ret["
<<
ret
<<
"]"
;
MS_EXCEPTION
(
DeviceProcessError
)
<<
"
Call
rtCtxSetCurrent, ret["
<<
ret
<<
"]"
;
}
ret
=
rtStreamCreate
(
&
stream_
,
0
);
if
(
ret
!=
RT_ERROR_NONE
)
{
MS_LOG
(
EXCEPTION
)
<<
"rtStreamCreate, ret["
<<
ret
<<
"]"
;
MS_LOG
(
EXCEPTION
)
<<
"
Call
rtStreamCreate, ret["
<<
ret
<<
"]"
;
}
return
true
;
...
...
@@ -585,14 +585,14 @@ bool AscendKernelRuntime::InitDevice() {
bool
AscendKernelRuntime
::
ResetDevice
()
{
auto
ret
=
rtCtxSetCurrent
(
rt_context_
);
if
(
ret
!=
RT_ERROR_NONE
)
{
MS_LOG
(
ERROR
)
<<
"
c
all rtCtxSetCurrent failed"
;
MS_LOG
(
ERROR
)
<<
"
C
all rtCtxSetCurrent failed"
;
return
false
;
}
if
(
stream_
!=
nullptr
)
{
ret
=
rtStreamDestroy
(
stream_
);
if
(
ret
!=
RT_ERROR_NONE
)
{
MS_LOG
(
EXCEPTION
)
<<
"rtStreamDestroy, ret["
<<
ret
<<
"]"
;
MS_LOG
(
EXCEPTION
)
<<
"
Call
rtStreamDestroy, ret["
<<
ret
<<
"]"
;
}
stream_
=
nullptr
;
}
...
...
@@ -600,7 +600,7 @@ bool AscendKernelRuntime::ResetDevice() {
if
(
rt_context_
!=
nullptr
)
{
ret
=
rtCtxDestroy
(
rt_context_
);
if
(
ret
!=
RT_ERROR_NONE
)
{
MS_EXCEPTION
(
DeviceProcessError
)
<<
"rtCtxDestroy, ret["
<<
ret
<<
"]"
;
MS_EXCEPTION
(
DeviceProcessError
)
<<
"
Call
rtCtxDestroy, ret["
<<
ret
<<
"]"
;
}
rt_context_
=
nullptr
;
}
...
...
@@ -613,30 +613,30 @@ bool AscendKernelRuntime::HcclInit() {
if
(
!
context_ptr
->
IsTsdOpened
())
{
MS_LOG
(
EXCEPTION
)
<<
"Hccl dependent tsd is not open"
;
}
MS_LOG
(
INFO
)
<<
"
d
o hcom init"
;
MS_LOG
(
INFO
)
<<
"
D
o hcom init"
;
auto
config_path_str
=
std
::
getenv
(
"MINDSPORE_HCCL_CONFIG_PATH"
);
if
(
config_path_str
==
nullptr
)
{
config_path_str
=
std
::
getenv
(
"RANK_TABLE_FILE"
);
if
(
config_path_str
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"
g
et hccl json config failed, please set env MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE"
;
MS_LOG
(
ERROR
)
<<
"
G
et hccl json config failed, please set env MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE"
;
return
false
;
}
}
if
(
strlen
(
config_path_str
)
>
PATH_MAX
)
{
MS_LOG
(
ERROR
)
<<
"
f
ile path oversize"
;
MS_LOG
(
ERROR
)
<<
"
F
ile path oversize"
;
return
false
;
}
std
::
string
rank_id_str
=
GetRankId
();
auto
full_path
=
realpath
(
config_path_str
,
nullptr
);
if
(
full_path
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"
f
ile path "
<<
config_path_str
<<
" does not exist"
;
MS_LOG
(
ERROR
)
<<
"
F
ile path "
<<
config_path_str
<<
" does not exist"
;
return
false
;
}
MS_LOG
(
INFO
)
<<
"MINDSPORE_HCCL_CONFIG_PATH : "
<<
full_path
<<
", RANK_ID: "
<<
rank_id_str
;
hcclResult_t
res
=
hcom_init
(
full_path
,
rank_id_str
.
c_str
());
free
(
full_path
);
if
(
res
!=
HCCL_SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"
h
com init failed, res is "
<<
static_cast
<
int
>
(
res
);
MS_LOG
(
ERROR
)
<<
"
H
com init failed, res is "
<<
static_cast
<
int
>
(
res
);
return
false
;
}
return
true
;
...
...
@@ -646,15 +646,15 @@ bool AscendKernelRuntime::DestroyHccl() {
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
if
(
!
NeedDestroyHccl
())
{
MS_LOG
(
INFO
)
<<
"
h
ccl is not enable, no need to close."
;
MS_LOG
(
INFO
)
<<
"
H
ccl is not enable, no need to close."
;
return
true
;
}
hcclResult_t
res
=
hcom_destroy
();
if
(
res
!=
HCCL_SUCCESS
)
{
MS_LOG
(
ERROR
)
<<
"
h
ccl destroy failed"
;
MS_LOG
(
ERROR
)
<<
"
H
ccl destroy failed"
;
return
false
;
}
MS_LOG
(
INFO
)
<<
"
h
ccl destroy successful, status = "
<<
res
<<
"."
;
MS_LOG
(
INFO
)
<<
"
H
ccl destroy successful, status = "
<<
res
<<
"."
;
context_ptr
->
set_enable_hccl
(
false
);
return
true
;
}
...
...
mindspore/ccsrc/device/ascend/ascend_stream_assign.cc
浏览文件 @
7951b318
...
...
@@ -46,7 +46,7 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
GetNeedActiveStreams
(
graph_ptr
);
graph_ptr
->
PrintGraphExecuteOrder
();
CheckResourceAssign
(
graph_ptr
);
MS_LOG
(
INFO
)
<<
"
a
fter finish stream assign"
;
MS_LOG
(
INFO
)
<<
"
A
fter finish stream assign"
;
// Get info for D Model
AscendResourceMng
&
resource_manager
=
AscendResourceMng
::
GetInstance
();
...
...
@@ -64,7 +64,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr>
std
::
vector
<
CNodePtr
>
others
;
auto
cnode_ptr_list
=
graph_ptr
->
execution_order
();
MS_LOG
(
INFO
)
<<
"
b
efore reorder, graph orders size:"
<<
cnode_ptr_list
.
size
();
MS_LOG
(
INFO
)
<<
"
B
efore reorder, graph orders size:"
<<
cnode_ptr_list
.
size
();
for
(
size_t
i
=
0
;
i
<
cnode_ptr_list
.
size
();
++
i
)
{
auto
cur_cnode_ptr
=
cnode_ptr_list
[
i
];
MS_EXCEPTION_IF_NULL
(
cur_cnode_ptr
);
...
...
@@ -76,7 +76,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr>
}
if
(
others
.
empty
()
||
independents
.
empty
())
{
MS_LOG
(
INFO
)
<<
"
i
ndependent or others is empty, no need reorder"
;
MS_LOG
(
INFO
)
<<
"
I
ndependent or others is empty, no need reorder"
;
return
;
}
...
...
@@ -107,9 +107,9 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr>
}
}
MS_LOG
(
INFO
)
<<
"
a
fter reorder, graph orders size:"
<<
exe_orders
.
size
();
MS_LOG
(
INFO
)
<<
"
A
fter reorder, graph orders size:"
<<
exe_orders
.
size
();
if
(
processed
.
size
()
!=
independents
.
size
())
{
MS_LOG
(
WARNING
)
<<
"
p
rocessed independent nodes size is not equal to exiting independent nodes size"
;
MS_LOG
(
WARNING
)
<<
"
P
rocessed independent nodes size is not equal to exiting independent nodes size"
;
return
;
}
...
...
@@ -142,7 +142,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
AssignCommonStreamId
(
cur_cnode_ptr
);
}
MS_LOG
(
INFO
)
<<
"
c
ommon start from 0, common stream nums:"
<<
resource_manager
.
get_cur_stream_num
();
MS_LOG
(
INFO
)
<<
"
C
ommon start from 0, common stream nums:"
<<
resource_manager
.
get_cur_stream_num
();
if
(
exit_hcom
)
{
uint32_t
first_hcom_stream_id
=
resource_manager
.
ApplyNewStream
();
...
...
@@ -157,7 +157,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
AssignHcomStreamId
(
cur_cnode_ptr
);
}
}
MS_LOG
(
INFO
)
<<
"
h
com start from :"
<<
first_hcom_stream_id
<<
", hcom stream nums:"
<<
hcom_stream_map_
.
size
();
MS_LOG
(
INFO
)
<<
"
H
com start from :"
<<
first_hcom_stream_id
<<
", hcom stream nums:"
<<
hcom_stream_map_
.
size
();
}
if
(
exit_independent
)
{
...
...
@@ -171,10 +171,10 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
AssignIndependentStreamId
(
cur_cnode_ptr
);
}
}
MS_LOG
(
INFO
)
<<
"
i
ndepend start from:"
<<
first_independ
<<
", stream nums:"
<<
independent_stream_map_
.
size
();
MS_LOG
(
INFO
)
<<
"
I
ndepend start from:"
<<
first_independ
<<
", stream nums:"
<<
independent_stream_map_
.
size
();
}
MS_LOG
(
INFO
)
<<
"
a
fter stream assign, total stream nums:"
<<
resource_manager
.
get_cur_stream_num
();
MS_LOG
(
INFO
)
<<
"
A
fter stream assign, total stream nums:"
<<
resource_manager
.
get_cur_stream_num
();
}
void
AscendStreamAssign
::
AssignCommonStreamId
(
const
CNodePtr
&
cur_cnode_ptr
)
{
...
...
@@ -257,7 +257,7 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) {
uint32_t
input_nums
=
AnfAlgo
::
GetInputTensorNum
(
node_ptr
);
if
(
input_nums
==
0
)
{
MS_LOG
(
INFO
)
<<
"
n
ode "
<<
node_ptr
->
fullname_with_scope
()
<<
" is independent, as inputs nums is zero"
;
MS_LOG
(
INFO
)
<<
"
N
ode "
<<
node_ptr
->
fullname_with_scope
()
<<
" is independent, as inputs nums is zero"
;
return
true
;
}
...
...
@@ -267,13 +267,13 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) {
return
false
;
}
}
MS_LOG
(
INFO
)
<<
"
n
ode "
<<
node_ptr
->
fullname_with_scope
()
<<
" is independent, as inputs is all value node"
;
MS_LOG
(
INFO
)
<<
"
N
ode "
<<
node_ptr
->
fullname_with_scope
()
<<
" is independent, as inputs is all value node"
;
return
true
;
}
// section 3:
void
AscendStreamAssign
::
UpdateAtomicAddrCleanStreamId
(
const
NotNull
<
KernelGraphPtr
>
&
graph_ptr
)
{
MS_LOG
(
INFO
)
<<
"
s
tart"
;
MS_LOG
(
INFO
)
<<
"
S
tart"
;
auto
cnode_ptr_list
=
graph_ptr
->
execution_order
();
for
(
size_t
i
=
0
;
i
<
cnode_ptr_list
.
size
();
++
i
)
{
CNodePtr
cur_cnode_ptr
=
cnode_ptr_list
[
i
];
...
...
@@ -283,12 +283,12 @@ void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraph
AnfAlgo
::
SetStreamId
(
AnfAlgo
::
GetStreamId
(
cur_cnode_ptr
),
cnode_ptr_list
[
i
-
1
].
get
());
}
}
MS_LOG
(
INFO
)
<<
"
e
nd"
;
MS_LOG
(
INFO
)
<<
"
E
nd"
;
}
// section 4
void
AscendStreamAssign
::
InsertStreamActive
(
const
NotNull
<
KernelGraphPtr
>
&
graph_ptr
)
{
MS_LOG
(
INFO
)
<<
"
s
tart"
;
MS_LOG
(
INFO
)
<<
"
S
tart"
;
GetProcessedStream
(
graph_ptr
);
std
::
vector
<
CNodePtr
>
update_cnode_list
;
CNodePtr
cur_cnode_ptr
=
nullptr
;
...
...
@@ -314,7 +314,7 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph
bool
processed
=
IsProcessedStream
(
cur_stream_id
);
// 1)inner stream assign, need insert active op
if
(
!
processed
)
{
MS_LOG
(
INFO
)
<<
"
c
ommon stream active info:"
<<
pre_stream_id
<<
"->active"
<<
cur_stream_id
;
MS_LOG
(
INFO
)
<<
"
C
ommon stream active info:"
<<
pre_stream_id
<<
"->active"
<<
cur_stream_id
;
CNodePtr
active_ptr
=
KernelAdjust
::
GetInstance
().
CreateStreamActiveOp
(
graph_ptr
);
// 1.set stream id
AnfAlgo
::
SetStreamId
(
pre_stream_id
,
active_ptr
.
get
());
...
...
@@ -336,7 +336,7 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph
pre_cnode_ptr
=
cur_cnode_ptr
;
}
graph_ptr
->
set_execution_order
(
update_cnode_list
);
MS_LOG
(
INFO
)
<<
"
e
nd"
;
MS_LOG
(
INFO
)
<<
"
E
nd"
;
}
void
AscendStreamAssign
::
GetProcessedStream
(
const
NotNull
<
KernelGraphPtr
>
&
graph_ptr
)
{
...
...
@@ -364,7 +364,7 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph
}
}
for
(
const
auto
&
item
:
processed_streams_
)
{
MS_LOG
(
INFO
)
<<
"
b
efore active:"
<<
item
<<
" is been processed"
;
MS_LOG
(
INFO
)
<<
"
B
efore active:"
<<
item
<<
" is been processed"
;
}
}
...
...
@@ -385,7 +385,7 @@ void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph
MS_EXCEPTION_IF_NULL
(
switch_ptr
);
auto
true_stream_id
=
GetValue
<
uint32_t
>
(
primitive
->
GetAttr
(
kAttrTrueBranchStream
));
MS_LOG
(
INFO
)
<<
"
s
treamswtich stream id:"
<<
AnfAlgo
::
GetStreamId
(
switch_ptr
)
MS_LOG
(
INFO
)
<<
"
S
treamswtich stream id:"
<<
AnfAlgo
::
GetStreamId
(
switch_ptr
)
<<
"; active stream id:"
<<
true_stream_id
;
CNodePtr
active_ptr
=
KernelAdjust
::
GetInstance
().
CreateStreamActiveOp
(
graph_ptr
);
...
...
@@ -425,11 +425,11 @@ bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
// section5
void
AscendStreamAssign
::
InsertEventForHcomParallel
(
const
NotNull
<
KernelGraphPtr
>
&
graph_ptr
)
{
MS_LOG
(
INFO
)
<<
"
s
tart"
;
MS_LOG
(
INFO
)
<<
"
S
tart"
;
InsertEventCommonDependHcom
(
graph_ptr
);
InsertEventHcomDependCommon
(
graph_ptr
);
InsertEventHcomDependHcom
(
graph_ptr
);
MS_LOG
(
INFO
)
<<
"
e
nd"
;
MS_LOG
(
INFO
)
<<
"
E
nd"
;
}
void
AscendStreamAssign
::
InsertEventCommonDependHcom
(
const
NotNull
<
KernelGraphPtr
>
&
graph_ptr
)
{
...
...
@@ -447,7 +447,7 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
auto
target
=
FindTargetOp
(
it
,
cnodes
.
end
(),
*
(
it
-
1
));
if
(
target
==
cnodes
.
end
())
{
MS_LOG
(
WARNING
)
<<
"
h
com node:"
<<
(
*
(
it
-
1
))
->
fullname_with_scope
()
MS_LOG
(
WARNING
)
<<
"
H
com node:"
<<
(
*
(
it
-
1
))
->
fullname_with_scope
()
<<
", can't find target for insert recv op, no insert send/recv"
;
it
=
cnodes
.
erase
(
it
);
continue
;
...
...
@@ -469,7 +469,7 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
// one event allocated additional, should delete
resource_manager
.
DeleteEvent
();
graph_ptr
->
set_execution_order
(
cnodes
);
MS_LOG
(
INFO
)
<<
"
a
fter common depend hcom, total event nums:"
<<
resource_manager
.
get_cur_event_num
();
MS_LOG
(
INFO
)
<<
"
A
fter common depend hcom, total event nums:"
<<
resource_manager
.
get_cur_event_num
();
}
void
AscendStreamAssign
::
InsertEventHcomDependCommon
(
const
NotNull
<
KernelGraphPtr
>
&
graph_ptr
)
{
...
...
@@ -512,7 +512,7 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt
}
graph_ptr
->
set_execution_order
(
cnodes
);
MS_LOG
(
INFO
)
<<
"
a
fter hcom depend common, total event nums:"
<<
resource_manager
.
get_cur_event_num
();
MS_LOG
(
INFO
)
<<
"
A
fter hcom depend common, total event nums:"
<<
resource_manager
.
get_cur_event_num
();
}
void
AscendStreamAssign
::
InsertEventHcomDependHcom
(
const
NotNull
<
KernelGraphPtr
>
&
graph_ptr
)
{
...
...
@@ -547,11 +547,11 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
}
if
(
hcom_index
.
size
()
<
2
)
{
MS_LOG
(
INFO
)
<<
"
d
ifferent stream hcom size is less than 2, no need insert event between them"
;
MS_LOG
(
INFO
)
<<
"
D
ifferent stream hcom size is less than 2, no need insert event between them"
;
return
;
}
InsertEventBetweenHcom
(
graph_ptr
,
hcom_index
,
first_hcom_stream
,
last_hcom_stream
);
MS_LOG
(
INFO
)
<<
"
a
fter hcom depend hcom, total event nums:"
<<
resource_manager
.
get_cur_event_num
();
MS_LOG
(
INFO
)
<<
"
A
fter hcom depend hcom, total event nums:"
<<
resource_manager
.
get_cur_event_num
();
}
void
AscendStreamAssign
::
InsertEventBetweenHcom
(
const
NotNull
<
KernelGraphPtr
>
&
graph_ptr
,
...
...
@@ -630,7 +630,7 @@ bool AscendStreamAssign::IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>
// section6
void
AscendStreamAssign
::
InsertEventForIndependentParallel
(
const
NotNull
<
KernelGraphPtr
>
&
graph_ptr
)
{
MS_LOG
(
INFO
)
<<
"
s
tart"
;
MS_LOG
(
INFO
)
<<
"
S
tart"
;
AscendResourceMng
&
resource_manager
=
AscendResourceMng
::
GetInstance
();
auto
cnode_ptr_list
=
graph_ptr
->
execution_order
();
vector
<
CNodePtr
>
cnodes
=
cnode_ptr_list
;
...
...
@@ -639,13 +639,13 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
while
(
it
!=
cnodes
.
end
())
{
MS_EXCEPTION_IF_NULL
(
*
it
);
if
(
IsIndependentNode
(
*
it
))
{
MS_LOG
(
INFO
)
<<
"
d
eal independent op["
<<
(
*
it
)
->
DebugString
()
<<
"]"
;
MS_LOG
(
INFO
)
<<
"
D
eal independent op["
<<
(
*
it
)
->
DebugString
()
<<
"]"
;
CNodePtr
send_cnode_ptr
=
CreateSendApplyKernel
(
graph_ptr
,
cur_event_id
,
AnfAlgo
::
GetStreamId
(
*
it
));
it
=
cnodes
.
insert
(
it
+
1
,
send_cnode_ptr
);
auto
target
=
FindTargetOp
(
it
,
cnodes
.
end
(),
*
(
it
-
1
));
if
(
target
==
cnodes
.
end
())
{
MS_LOG
(
DEBUG
)
<<
"
i
ndepend node["
<<
(
*
(
it
-
1
))
->
fullname_with_scope
()
MS_LOG
(
DEBUG
)
<<
"
I
ndepend node["
<<
(
*
(
it
-
1
))
->
fullname_with_scope
()
<<
"] can't find target for insert recv op, no insert send/recv"
;
it
=
cnodes
.
erase
(
it
);
continue
;
...
...
@@ -662,8 +662,8 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
// one event allocated additional, should delete
resource_manager
.
DeleteEvent
();
graph_ptr
->
set_execution_order
(
cnodes
);
MS_LOG
(
INFO
)
<<
"
a
fter independent parallel, total event nums:"
<<
resource_manager
.
get_cur_event_num
();
MS_LOG
(
INFO
)
<<
"
e
nd"
;
MS_LOG
(
INFO
)
<<
"
A
fter independent parallel, total event nums:"
<<
resource_manager
.
get_cur_event_num
();
MS_LOG
(
INFO
)
<<
"
E
nd"
;
}
// section7
...
...
@@ -687,7 +687,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
auto
need_active
=
GetValue
<
bool
>
(
value_ptr
);
if
(
need_active
)
{
auto
stream_id
=
AnfAlgo
::
GetStreamId
(
cur_cnode_ptr
);
MS_LOG
(
INFO
)
<<
"
s
tream id:"
<<
stream_id
<<
" is need actived at first"
;
MS_LOG
(
INFO
)
<<
"
S
tream id:"
<<
stream_id
<<
" is need actived at first"
;
need_first_active_streams_
.
push_back
(
stream_id
);
}
}
...
...
@@ -724,7 +724,7 @@ void AscendStreamAssign::CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_
MS_EXCEPTION_IF_NULL
(
cur_cnode_ptr
);
uint32_t
stream_id
=
AnfAlgo
::
GetStreamId
(
cur_cnode_ptr
);
if
(
stream_id
==
kInvalidStreamId
)
{
MS_LOG
(
EXCEPTION
)
<<
"
n
ode:"
<<
AnfAlgo
::
GetCNodeName
(
cur_cnode_ptr
)
<<
"had not been assigned stream"
;
MS_LOG
(
EXCEPTION
)
<<
"
N
ode:"
<<
AnfAlgo
::
GetCNodeName
(
cur_cnode_ptr
)
<<
"had not been assigned stream"
;
}
(
void
)
streams
.
emplace
(
stream_id
);
...
...
@@ -739,11 +739,11 @@ void AscendStreamAssign::CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_
// check stream assign
if
(
!
streams
.
empty
())
{
if
(
min_stream
!=
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"
s
tream should start from 0, now is from "
<<
min_stream
;
MS_LOG
(
EXCEPTION
)
<<
"
S
tream should start from 0, now is from "
<<
min_stream
;
}
uint32_t
assigned_stream_num
=
resource_manager
.
get_cur_stream_num
();
if
((
max_stream
!=
assigned_stream_num
-
1
)
||
(
streams
.
size
()
!=
assigned_stream_num
))
{
MS_LOG
(
EXCEPTION
)
<<
"
s
tream should be consecutive, max stream id:"
<<
max_stream
MS_LOG
(
EXCEPTION
)
<<
"
S
tream should be consecutive, max stream id:"
<<
max_stream
<<
"; alloc stream nums:"
<<
assigned_stream_num
<<
"; streams size:"
<<
streams
.
size
();
}
}
...
...
@@ -779,20 +779,20 @@ void AscendStreamAssign::CheckEventAssign(const NotNull<KernelGraphPtr> &graph_p
// check event assign
if
(
!
event_map
.
empty
())
{
if
(
min_event_id
!=
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"
e
vent should start from 0, now is from "
<<
min_event_id
;
MS_LOG
(
EXCEPTION
)
<<
"
E
vent should start from 0, now is from "
<<
min_event_id
;
}
uint32_t
assigned_event_num
=
resource_manager
.
get_cur_event_num
();
if
((
max_event_id
!=
assigned_event_num
-
1
)
||
(
event_map
.
size
()
!=
assigned_event_num
))
{
MS_LOG
(
EXCEPTION
)
<<
"
e
vent should be consecutive"
;
MS_LOG
(
EXCEPTION
)
<<
"
E
vent should be consecutive"
;
}
for
(
const
auto
&
item
:
event_map
)
{
if
(
item
.
second
.
size
()
!=
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"
s
end/recv should be in pair and share one event id"
;
MS_LOG
(
EXCEPTION
)
<<
"
S
end/recv should be in pair and share one event id"
;
}
auto
first_name
=
AnfAlgo
::
GetCNodeName
(
item
.
second
[
0
]);
auto
second_name
=
AnfAlgo
::
GetCNodeName
(
item
.
second
[
1
]);
if
(
!
(
first_name
==
kSendOpName
&&
second_name
==
kRecvOpName
))
{
MS_LOG
(
EXCEPTION
)
<<
"
s
end should be before recv"
;
MS_LOG
(
EXCEPTION
)
<<
"
S
end should be before recv"
;
}
}
}
...
...
@@ -858,7 +858,7 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it
}
else
{
auto
real_input
=
AnfAlgo
::
VisitKernel
(
input
,
0
);
if
(
node
==
real_input
.
first
)
{
MS_LOG
(
INFO
)
<<
"
f
ind target op["
<<
(
*
begin
)
->
DebugString
()
<<
"]"
;
MS_LOG
(
INFO
)
<<
"
F
ind target op["
<<
(
*
begin
)
->
DebugString
()
<<
"]"
;
return
begin
;
}
}
...
...
@@ -872,10 +872,10 @@ bool AscendStreamAssign::IsTaskSink() {
auto
ms_context
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
ms_context
);
if
(
!
ms_context
->
enable_task_sink
())
{
MS_LOG
(
INFO
)
<<
"
t
ask sink mode is not enable"
;
MS_LOG
(
INFO
)
<<
"
T
ask sink mode is not enable"
;
return
false
;
}
else
{
MS_LOG
(
INFO
)
<<
"
t
ask sink mode is enable"
;
MS_LOG
(
INFO
)
<<
"
T
ask sink mode is enable"
;
return
true
;
}
}
...
...
@@ -885,7 +885,7 @@ void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_lis
AscendResourceMng
&
resource_manager
=
AscendResourceMng
::
GetInstance
();
uint32_t
total_stream_num
=
resource_manager
.
get_cur_stream_num
();
if
(
total_stream_num
==
0
)
{
MS_LOG
(
INFO
)
<<
"total_common_stream_num is zero"
;
MS_LOG
(
INFO
)
<<
"
The
total_common_stream_num is zero"
;
return
;
}
...
...
@@ -893,7 +893,7 @@ void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_lis
for
(
uint32_t
i
=
0
;
i
<
total_stream_num
;
i
++
)
{
auto
it
=
std
::
find
(
need_first_active_streams_
.
begin
(),
need_first_active_streams_
.
end
(),
i
);
if
(
it
==
need_first_active_streams_
.
end
())
{
MS_LOG
(
INFO
)
<<
"
w
ait common stream id = "
<<
i
;
MS_LOG
(
INFO
)
<<
"
W
ait common stream id = "
<<
i
;
wait_active_stream_list
->
push_back
(
i
);
}
}
...
...
mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc
浏览文件 @
7951b318
...
...
@@ -142,6 +142,37 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
return
std
::
make_shared
<
CPUDeviceAddress
>
(
device_ptr
,
device_size
,
format
,
type_id
);
}
tensor
::
TensorPtr
CPUKernelRuntime
::
CreatTensorForOutput
(
const
CNodePtr
&
node
,
size_t
index
,
std
::
set
<
DeviceAddressPtr
>
*
bound_addresses
,
std
::
vector
<
tensor
::
TensorPtr
>
*
need_sync_outputs
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
bound_addresses
);
MS_EXCEPTION_IF_NULL
(
need_sync_outputs
);
size_t
output_size
=
AnfAlgo
::
GetOutputTensorNum
(
node
);
if
(
index
>=
output_size
)
{
MS_LOG
(
EXCEPTION
)
<<
"Invalid input index "
<<
index
;
}
auto
address
=
AnfAlgo
::
GetMutableOutputAddr
(
node
,
index
);
MS_EXCEPTION_IF_NULL
(
address
);
auto
shape
=
AnfAlgo
::
GetOutputInferShape
(
node
,
index
);
std
::
vector
<
int
>
temp_shape
;
(
void
)
temp_shape
.
insert
(
temp_shape
.
end
(),
shape
.
begin
(),
shape
.
end
());
TypeId
type_id
=
AnfAlgo
::
GetOutputInferDataType
(
node
,
index
);
type_id
=
GetCPUSupportOutputTypeId
(
type_id
);
tensor
::
TensorPtr
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
type_id
,
temp_shape
);
MS_EXCEPTION_IF_NULL
(
tensor
);
if
(
bound_addresses
->
find
(
address
)
!=
bound_addresses
->
end
())
{
tensor
->
set_device_address
(
address
);
need_sync_outputs
->
emplace_back
(
tensor
);
}
else
{
address
->
ptr_
=
tensor
->
data_c
();
address
->
ref_count_
=
INIT_NODE_REF
;
(
void
)
bound_addresses
->
insert
(
address
);
}
tensor
->
set_dirty
(
false
);
return
tensor
;
}
BaseRef
CPUKernelRuntime
::
CreatTensorForOutput
(
const
session
::
KernelWithIndex
&
kernel_with_index
,
const
std
::
unordered_map
<
AnfNode
*
,
tensor
::
TensorPtr
>
&
input_map
,
std
::
set
<
DeviceAddressPtr
>
*
bound_addresses
,
...
...
@@ -161,29 +192,7 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &k
}
return
ret
;
}
size_t
output_size
=
AnfAlgo
::
GetOutputTensorNum
(
node
);
if
(
index
>=
output_size
)
{
MS_LOG
(
EXCEPTION
)
<<
"Invalid input index "
<<
index
;
}
auto
address
=
AnfAlgo
::
GetMutableOutputAddr
(
node
,
index
);
MS_EXCEPTION_IF_NULL
(
address
);
auto
shape
=
AnfAlgo
::
GetOutputInferShape
(
node
,
index
);
std
::
vector
<
int
>
temp_shape
;
(
void
)
temp_shape
.
insert
(
temp_shape
.
end
(),
shape
.
begin
(),
shape
.
end
());
TypeId
type_id
=
AnfAlgo
::
GetOutputInferDataType
(
node
,
index
);
type_id
=
GetCPUSupportOutputTypeId
(
type_id
);
tensor
::
TensorPtr
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
type_id
,
temp_shape
);
MS_EXCEPTION_IF_NULL
(
tensor
);
if
(
bound_addresses
->
find
(
address
)
!=
bound_addresses
->
end
())
{
tensor
->
set_device_address
(
address
);
need_sync_outputs
->
emplace_back
(
tensor
);
}
else
{
address
->
ptr_
=
tensor
->
data_c
();
address
->
ref_count_
=
INIT_NODE_REF
;
(
void
)
bound_addresses
->
insert
(
address
);
}
tensor
->
set_dirty
(
false
);
return
tensor
;
return
CreatTensorForOutput
(
node
,
index
,
bound_addresses
,
need_sync_outputs
);
}
else
if
(
input_node
->
isa
<
Parameter
>
()
||
input_node
->
isa
<
ValueNode
>
())
{
auto
iter
=
input_map
.
find
(
input_node
.
get
());
if
(
iter
!=
input_map
.
end
())
{
...
...
@@ -247,6 +256,7 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
void
CPUKernelRuntime
::
AddRuntimeAddress
(
DeviceAddress
*
address
,
std
::
vector
<
kernel
::
AddressPtr
>
*
input_list
)
{
MS_EXCEPTION_IF_NULL
(
address
);
MS_EXCEPTION_IF_NULL
(
input_list
);
kernel
::
AddressPtr
input
=
std
::
make_shared
<
kernel
::
Address
>
();
MS_EXCEPTION_IF_NULL
(
input
);
if
(
address
->
ptr_
==
nullptr
)
{
...
...
mindspore/ccsrc/device/cpu/cpu_kernel_runtime.h
浏览文件 @
7951b318
...
...
@@ -49,6 +49,10 @@ class CPUKernelRuntime : public KernelRuntime {
TypeId
type_id
)
override
;
private:
tensor
::
TensorPtr
CreatTensorForOutput
(
const
CNodePtr
&
node
,
size_t
index
,
std
::
set
<
DeviceAddressPtr
>
*
bound_addresses
,
std
::
vector
<
tensor
::
TensorPtr
>
*
need_sync_outputs
);
BaseRef
CreatTensorForOutput
(
const
session
::
KernelWithIndex
&
kernel_with_index
,
const
std
::
unordered_map
<
AnfNode
*
,
tensor
::
TensorPtr
>
&
input_map
,
std
::
set
<
DeviceAddressPtr
>
*
bound_addresses
,
...
...
mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.cc
浏览文件 @
7951b318
...
...
@@ -56,7 +56,13 @@ void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) {
graph_mem_size_
[
graph
]
=
total_mem_size
;
}
size_t
CPUSimpleMemPlan
::
GetGraphMemSize
(
const
session
::
KernelGraph
*
graph
)
{
return
graph_mem_size_
[
graph
];
}
size_t
CPUSimpleMemPlan
::
GetGraphMemSize
(
const
session
::
KernelGraph
*
graph
)
const
{
auto
iter
=
graph_mem_size_
.
find
(
graph
);
if
(
iter
!=
graph_mem_size_
.
end
())
{
return
iter
->
second
;
}
return
0
;
}
void
CPUSimpleMemPlan
::
MemAssign
(
const
session
::
KernelGraph
*
graph
,
uint8_t
*
base_ptr
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
...
...
mindspore/ccsrc/device/cpu/cpu_simple_mem_plan.h
浏览文件 @
7951b318
...
...
@@ -31,7 +31,7 @@ class CPUSimpleMemPlan {
void
MemPlan
(
const
session
::
KernelGraph
*
graph
);
void
MemAssign
(
const
session
::
KernelGraph
*
graph
,
uint8_t
*
base_ptr
);
size_t
GetGraphMemSize
(
const
session
::
KernelGraph
*
graph
);
size_t
GetGraphMemSize
(
const
session
::
KernelGraph
*
graph
)
const
;
private:
std
::
unordered_map
<
const
session
::
KernelGraph
*
,
size_t
>
graph_mem_size_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录