magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit c1f881e6
Authored July 10, 2020 by mindspore-ci-bot; committed via Gitee on July 10, 2020
!2592 Keep parameters of previous step in TensorLoader
Merge pull request !2592 from ShidaHe/debugger_dev
Parents: cf5a27e9, cb4c74c7
8 changed files with 158 additions and 123 deletions (+158, -123):
mindspore/ccsrc/debug/debugger/debug_graph.proto         +6    -0
mindspore/ccsrc/debug/debugger/debugger.cc               +107  -98
mindspore/ccsrc/debug/debugger/debugger.h                +18   -15
mindspore/ccsrc/debug/tensor_load.h                      +15   -2
mindspore/ccsrc/device/ascend/ascend_device_address.cc   +4    -4
mindspore/ccsrc/device/ascend/ascend_device_address.h    +2    -1
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc   +4    -2
mindspore/ccsrc/session/ascend_session.cc                +2    -1
mindspore/ccsrc/debug/debugger/debug_graph.proto
@@ -313,4 +313,10 @@ message TensorProto {
   // If the tensor content transferring is finished.
   optional bool finished = 6;
+
+  // The iteration of the tensor. Supported: "prev" or leave empty.
+  optional string iter = 7;
+
+  // If the tensor name should be truncated.
+  optional bool truncate = 8;
 }
\ No newline at end of file
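For context, a minimal sketch of how a debugger client could fill the two new fields when asking for a tensor from the previous iteration. The setters mirror the generated C++ API used later in this diff (set_node_name, set_slot, set_iter, set_truncate); the proto package name follows the debugger:: qualifications seen elsewhere in the commit, while the include path, node name, and slot value are illustrative assumptions.

// Sketch only; the include path is hypothetical.
#include "debug/debugger/debug_graph.pb.h"

debugger::TensorProto MakePrevStepRequest() {
  debugger::TensorProto req;
  req.set_node_name("Default/network/conv1.weight");  // hypothetical node name
  req.set_slot("0");
  req.set_iter("prev");    // request the value kept from the previous step
  req.set_truncate(true);  // resolve the name without its scope prefix
  return req;
}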
mindspore/ccsrc/debug/debugger/debugger.cc
@@ -178,7 +178,7 @@ void Debugger::CheckDatasetGraph() {
   is_dataset_graph_ = false;
 }
 
-GraphProto Debugger::GetGraphProto() {
+GraphProto Debugger::GetGraphProto() const {
   // convert kernel graph to debugger modelproto
   ModelProto model = GetDebuggerFuncGraphProto(graph_ptr_);
   return model.graph();
@@ -261,12 +261,9 @@ void Debugger::CommandLoop() {
         MS_LOG(INFO) << "node name: " << node.node_name();
         MS_LOG(INFO) << "node type: " << node.node_type();
       }
-      WatchCondition recieved_condition = GetWatchcondition(reply);
-      MS_LOG(INFO) << "condition: " << recieved_condition.condition();
-      int32_t id = GetWatchpointID(reply);
-      MS_LOG(INFO) << "id: " << id;
-      bool delete_ = GetWatchpointDelete(reply);
-      MS_LOG(INFO) << "delete: " << delete_;
+      MS_LOG(INFO) << "condition: " << GetWatchcondition(reply).condition();
+      MS_LOG(INFO) << "id: " << GetWatchpointID(reply);
+      MS_LOG(INFO) << "delete: " << GetWatchpointDelete(reply);
     }
     MS_LOG(INFO) << "Setting watchpoint";
     if (GetWatchpointDelete(reply)) {
@@ -284,15 +281,20 @@ void Debugger::CommandLoop() {
         MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
         MS_LOG(INFO) << "tensor slot: " << tensor.slot();
         MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
+        MS_LOG(INFO) << "tensor iter: " << tensor.iter();
+        MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
       }
     }
     MS_LOG(INFO) << "Sending tensors";
     std::list<TensorProto> tensors = LoadTensors(GetTensors(reply));
     {
       // print view cmd reply
       for (auto tensor : tensors) {
         MS_LOG(INFO) << "tensor node name: " << tensor.node_name();
         MS_LOG(INFO) << "tensor slot: " << tensor.slot();
         MS_LOG(INFO) << "tensor finished: " << std::boolalpha << tensor.finished() << std::noboolalpha;
+        MS_LOG(INFO) << "tensor iter: " << tensor.iter();
+        MS_LOG(INFO) << "tensor truncate: " << std::boolalpha << tensor.truncate() << std::noboolalpha;
         MS_LOG(INFO) << "tensor dims: ";
         for (auto dim : tensor.dims()) {
           MS_LOG(INFO) << dim << ",";
@@ -309,68 +311,6 @@ void Debugger::CommandLoop() {
   }
 }
 
-DebuggerCommand Debugger::GetCommand(const EventReply &reply) {
-  DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
-  switch (reply.cmd_case()) {
-    case debugger::EventReply::CmdCase::kExit:
-      cmd = DebuggerCommand::kExitCMD;
-      break;
-    case debugger::EventReply::CmdCase::kRunCmd:
-      cmd = DebuggerCommand::kRunCMD;
-      break;
-    case debugger::EventReply::CmdCase::kSetCmd:
-      cmd = DebuggerCommand::kSetCMD;
-      break;
-    case debugger::EventReply::CmdCase::kViewCmd:
-      cmd = DebuggerCommand::kViewCMD;
-      break;
-    default:
-      MS_LOG(ERROR) << "Error: UnknownCMD";
-      break;
-  }
-  return cmd;
-}
-
-ProtoVector<WatchNode> Debugger::GetWatchnodes(const EventReply &reply) {
-  if (!reply.has_set_cmd()) {
-    MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
-    return ProtoVector<WatchNode>();
-  }
-  return reply.set_cmd().watch_nodes();
-}
-
-WatchCondition Debugger::GetWatchcondition(const EventReply &reply) {
-  if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
-    MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
-    return WatchCondition();
-  }
-  return reply.set_cmd().watch_condition();
-}
-
-int32_t Debugger::GetWatchpointID(const EventReply &reply) {
-  if (!reply.has_set_cmd()) {
-    MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
-    return 0;
-  }
-  return reply.set_cmd().id();
-}
-
-bool Debugger::GetWatchpointDelete(const EventReply &reply) {
-  if (!reply.has_set_cmd()) {
-    MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
-    return false;
-  }
-  return reply.set_cmd().delete_();
-}
-
-ProtoVector<TensorProto> Debugger::GetTensors(const EventReply &reply) {
-  if (!reply.has_view_cmd()) {
-    MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
-    return ProtoVector<TensorProto>();
-  }
-  return reply.view_cmd().tensors();
-}
-
 void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id) {
   std::vector<std::tuple<std::string, bool>> check_node_list;
   std::transform(nodes.begin(), nodes.end(), std::back_inserter(check_node_list),
@@ -383,7 +323,7 @@ void Debugger::SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCon
 
 void Debugger::RemoveWatchpoint(const int32_t id) { debug_services_->remove_watchpoint(id); }
 
-std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) {
+std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &tensors) const {
   std::vector<std::string> name;
   std::vector<std::string> ret_name;
   std::vector<char *> data_ptr;
@@ -391,38 +331,42 @@ std::list<TensorProto> Debugger::LoadTensors(const ProtoVector<TensorProto> &ten
   std::vector<TypePtr> dtype;
   std::vector<std::vector<int>> shape;
 
-  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name),
-                 [](TensorProto tensor) -> std::string { return tensor.node_name() + ":" + tensor.slot(); });
+  std::transform(tensors.begin(), tensors.end(), std::back_inserter(name), GetTensorFullName);
 
   // ret_name will contain tensor names that are found in TensorLoader
   // items in ret_name will be in the same order with tensors if found
   debug_services_->read_nodes_tensors(name, &ret_name, &data_ptr, &data_size, &dtype, &shape);
 
   std::list<TensorProto> tensor_list;
   unsigned int result_index = 0;
-  TensorProto tensor_item;
 
   for (auto tensor : tensors) {
+    TensorProto tensor_item;
     tensor_item.set_node_name(tensor.node_name());
     tensor_item.set_slot(tensor.slot());
+    tensor_item.set_iter(tensor.iter());
+    tensor_item.set_truncate(tensor.truncate());
     tensor_item.clear_tensor_content();
     tensor_item.clear_data_type();
     tensor_item.clear_dims();
     // always set finished to true before big tensor splitting is supported
     tensor_item.set_finished(true);
     // return empty tensor if didn't find the requested tensor
-    if (result_index >= ret_name.size() || ret_name[result_index] != tensor.node_name() + ":" + tensor.slot()) {
+    if (result_index >= ret_name.size() || ret_name[result_index] != GetTensorFullName(tensor)) {
       tensor_list.push_back(tensor_item);
       continue;
     }
     tensor_item.set_tensor_content(data_ptr[result_index], data_size[result_index]);
     tensor_item.set_data_type(GetDebuggerNumberDataType(dtype[result_index]));
-    tensor_item.clear_dims();
     for (auto &elem : shape[result_index]) {
       tensor_item.add_dims(elem);
     }
     // add tensor to result list and increment result_index to check next item in ret_name
     tensor_list.push_back(tensor_item);
     result_index++;
   }
   return tensor_list;
 }
@@ -432,7 +376,7 @@ void Debugger::Exit() {
   std::exit(EXIT_FAILURE);
 }
 
-std::list<WatchpointHit> Debugger::CheckWatchpoints() {
+std::list<WatchpointHit> Debugger::CheckWatchpoints() const {
   std::vector<std::string> name;
   std::vector<std::string> slot;
   std::vector<char *> data_ptr;
@@ -442,31 +386,23 @@ std::list<WatchpointHit> Debugger::CheckWatchpoints() {
   debug_services_->check_watchpoints(&name, &slot, &data_ptr, &data_size, &condition, &watchpoint_id);
 
-  std::list<WatchpointHit> points;
+  std::list<WatchpointHit> hits;
   for (unsigned int i = 0; i < name.size(); i++) {
-    TensorProto *tensor_item;
-    tensor_item = new TensorProto();
+    WatchpointHit hit;
+    hit.set_id(watchpoint_id[i]);
+
+    // here TensorProto act as a tensor indicator, not sending tensor content
+    TensorProto *tensor_item = hit.mutable_tensor();
     tensor_item->set_node_name(name[i]);
     tensor_item->set_slot(slot[i]);
     tensor_item->set_tensor_content(data_ptr[i], data_size[i]);
+    // finished in TensorProto will always be true before we implement big tensor splitting
     tensor_item->set_finished(true);
-    WatchCondition *condition_item;
-    condition_item = new WatchCondition();
+    WatchCondition *condition_item = hit.mutable_watch_condition();
     condition_item->set_condition(debugger::WatchCondition_Condition(condition[i]));
-    WatchpointHit point;
-    point.set_allocated_tensor(tensor_item);
-    point.set_allocated_watch_condition(condition_item);
-    point.set_id(watchpoint_id[i]);
-    points.push_back(point);
+    hits.push_back(hit);
   }
-  return points;
+  return hits;
 }
 
 void Debugger::SendWatchpointsAndSuspend(const std::list<WatchpointHit> &points) {
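One design note on the CheckWatchpoints() rewrite above: the old code heap-allocated TensorProto and WatchCondition with new and handed ownership over via set_allocated_*(), while the new code asks the parent WatchpointHit for its submessages through mutable_tensor() / mutable_watch_condition(), so the parent message owns them and no manual allocation or ownership transfer is needed. A minimal sketch of the pattern, assuming the generated protobuf header for these messages (the include path and the sample arguments are hypothetical):

// Sketch only; use the generated header that actually defines these messages.
#include "debug/debugger/debug_graph.pb.h"

debugger::WatchpointHit MakeHit(const std::string &node, const std::string &slot, int32_t id) {
  debugger::WatchpointHit hit;
  hit.set_id(id);
  // mutable_* returns a pointer to a submessage owned by `hit`; no new/delete involved.
  debugger::TensorProto *tensor = hit.mutable_tensor();
  tensor->set_node_name(node);
  tensor->set_slot(slot);
  tensor->set_finished(true);
  return hit;
}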
@@ -481,8 +417,81 @@ void Debugger::SendWatchpointsAndSuspend(const std::list<WatchpointHit> &points)
   CommandLoop();
 }
 
-DebugServices *Debugger::get_debug_services() { return debug_services_.get(); }
+DebugServices *Debugger::debug_services() const { return debug_services_.get(); }
 
-bool Debugger::debugger_enabled() { return debugger_enabled_; }
+bool Debugger::debugger_enabled() const { return debugger_enabled_; }
 
+DebuggerCommand GetCommand(const EventReply &reply) {
+  DebuggerCommand cmd = DebuggerCommand::kUnknownCMD;
+  switch (reply.cmd_case()) {
+    case debugger::EventReply::CmdCase::kExit:
+      cmd = DebuggerCommand::kExitCMD;
+      break;
+    case debugger::EventReply::CmdCase::kRunCmd:
+      cmd = DebuggerCommand::kRunCMD;
+      break;
+    case debugger::EventReply::CmdCase::kSetCmd:
+      cmd = DebuggerCommand::kSetCMD;
+      break;
+    case debugger::EventReply::CmdCase::kViewCmd:
+      cmd = DebuggerCommand::kViewCMD;
+      break;
+    default:
+      MS_LOG(ERROR) << "Error: UnknownCMD";
+      break;
+  }
+  return cmd;
+}
+
+ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply) {
+  if (!reply.has_set_cmd()) {
+    MS_LOG(ERROR) << "Error: Not SetCMD, can not get WatchNodes. Returning default value: ProtoVector<WatchNode>().";
+    return ProtoVector<WatchNode>();
+  }
+  return reply.set_cmd().watch_nodes();
+}
+
+WatchCondition GetWatchcondition(const EventReply &reply) {
+  if (!reply.has_set_cmd() || !reply.set_cmd().has_watch_condition()) {
+    MS_LOG(ERROR) << "Error: Can not get WatchCondition from command. Returning default value: WatchCondition().";
+    return WatchCondition();
+  }
+  return reply.set_cmd().watch_condition();
+}
+
+int32_t GetWatchpointID(const EventReply &reply) {
+  if (!reply.has_set_cmd()) {
+    MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint ID. Returning default value: 0.";
+    return 0;
+  }
+  return reply.set_cmd().id();
+}
+
+bool GetWatchpointDelete(const EventReply &reply) {
+  if (!reply.has_set_cmd()) {
+    MS_LOG(ERROR) << "Error: Not SetCMD, can not get Watchpoint delete flag. Returning default value: false.";
+    return false;
+  }
+  return reply.set_cmd().delete_();
+}
+
+ProtoVector<TensorProto> GetTensors(const EventReply &reply) {
+  if (!reply.has_view_cmd()) {
+    MS_LOG(ERROR) << "Error: Not ViewCMD, can not get Tensors. Returning default value: ProtoVector<TensorProto>().";
+    return ProtoVector<TensorProto>();
+  }
+  return reply.view_cmd().tensors();
+}
+
+std::string GetTensorFullName(const TensorProto &tensor) {
+  string node_name = tensor.node_name();
+  if (tensor.truncate()) {
+    // scopes in node name are seperated by '/'
+    // use the name without scope if truncate is true
+    std::size_t found = node_name.find_last_of("/");
+    node_name = node_name.substr(found + 1);
+  }
+  return node_name + ":" + tensor.slot() + (tensor.iter() == "" ? "" : ":" + tensor.iter());
+}
 }  // namespace mindspore
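To make the naming convention introduced by GetTensorFullName concrete, here is a small self-contained sketch of the same string logic using plain std::string; the tensor names are made up, and only the node:slot[:iter] format and the scope truncation mirror the code above.

#include <iostream>
#include <string>

// Mirrors the logic of GetTensorFullName above, for illustration only.
std::string FullName(std::string node_name, const std::string &slot, const std::string &iter, bool truncate) {
  if (truncate) {
    // keep only the part after the last '/', i.e. drop the scope prefix
    node_name = node_name.substr(node_name.find_last_of("/") + 1);
  }
  return node_name + ":" + slot + (iter.empty() ? "" : ":" + iter);
}

int main() {
  // "Default/network/conv1.weight" is a hypothetical node name.
  std::cout << FullName("Default/network/conv1.weight", "0", "", false) << "\n";
  // -> Default/network/conv1.weight:0
  std::cout << FullName("Default/network/conv1.weight", "0", "", true) << "\n";
  // -> conv1.weight:0
  std::cout << FullName("Default/network/conv1.weight", "0", "prev", true) << "\n";
  // -> conv1.weight:0:prev  (the key under which TensorLoader keeps the previous step's value)
  return 0;
}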
mindspore/ccsrc/debug/debugger/debugger.h
@@ -72,9 +72,9 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   // suspend the execution after a debug_op
   void PostDebugOp();
 
-  DebugServices *get_debug_services();
+  DebugServices *debug_services() const;
 
-  bool debugger_enabled();
+  bool debugger_enabled() const;
 
  private:
   // private constructor for singleton
@@ -92,7 +92,7 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   void CheckDatasetGraph();
 
   // serialize graph and get proto
-  GraphProto GetGraphProto();
+  GraphProto GetGraphProto() const;
 
   // send graph and enter command wait loop
   void SendGraphAndSuspend(const GraphProto &graph_proto);
@@ -102,16 +102,6 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   // break if RunCMD
   void CommandLoop();
 
-  // process reply and command type
-  DebuggerCommand GetCommand(const EventReply &reply);
-
-  // parse other data out of EventReply
-  ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply);
-  WatchCondition GetWatchcondition(const EventReply &reply);
-  int32_t GetWatchpointID(const EventReply &reply);
-  bool GetWatchpointDelete(const EventReply &reply);
-  ProtoVector<TensorProto> GetTensors(const EventReply &reply);
-
   // set what nodes and conditions to watch
   void SetWatchpoint(const ProtoVector<WatchNode> &nodes, const WatchCondition &condition, const int32_t id);
@@ -119,14 +109,14 @@ class Debugger : public std::enable_shared_from_this<Debugger> {
   void RemoveWatchpoint(const int32_t id);
 
   // load tensor for view command
-  std::list<TensorProto> LoadTensors(const ProtoVector<TensorProto> &tensors);
+  std::list<TensorProto> LoadTensors(const ProtoVector<TensorProto> &tensors) const;
 
   // terminate training process
   void Exit();
 
   // analyze tensors and check watchpoint conditions
   // return names of tensors and what condition they hit
-  std::list<WatchpointHit> CheckWatchpoints();
+  std::list<WatchpointHit> CheckWatchpoints() const;
 
   // send watchpoints that hit and enter command wait loop
   void SendWatchpointsAndSuspend(const std::list<WatchpointHit> &points);
@@ -155,5 +145,18 @@ ModelProto GetDebuggerFuncGraphProto(const FuncGraphPtr &func_graph);
 // for getting proto DataType from Type of Tensor
 DataType GetDebuggerNumberDataType(const TypePtr &type);
 
+// process reply and command type
+DebuggerCommand GetCommand(const EventReply &reply);
+
+// parse other data out of EventReply
+ProtoVector<WatchNode> GetWatchnodes(const EventReply &reply);
+WatchCondition GetWatchcondition(const EventReply &reply);
+int32_t GetWatchpointID(const EventReply &reply);
+bool GetWatchpointDelete(const EventReply &reply);
+ProtoVector<TensorProto> GetTensors(const EventReply &reply);
+
+// get the full name of a tensor, which is the name used in TensorLoader
+std::string GetTensorFullName(const TensorProto &tensor);
+
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_DEBUG_DEBUGGER_DEBUGGER_H_
mindspore/ccsrc/debug/tensor_load.h
@@ -21,6 +21,7 @@
 #include <map>
 #include <tuple>
 #include <string>
+#include <utility>
 #include "debug/tensor_data.h"
 namespace mindspore {
 class TensorLoader {
@@ -29,7 +30,15 @@ class TensorLoader {
   ~TensorLoader() {}
 
-  bool LoadNewTensor(std::shared_ptr<TensorData> tensor) {
+  bool LoadNewTensor(std::shared_ptr<TensorData> tensor, bool keep_prev) {
+    if (keep_prev) {
+      // add prev step tensor into current step map with ":prev" suffix
+      auto handle = prev_tensor_list_map.extract(tensor->GetName());
+      if (!handle.empty()) {
+        handle.key() = tensor->GetName() + ":prev";
+        tensor_list_map.insert(std::move(handle));
+      }
+    }
+
     tensor_list.push_back(tensor);
     tensor_list_map.insert({tensor->GetName(), tensor});
     return true;
@@ -53,16 +62,20 @@ class TensorLoader {
   }
 
   bool EmptyTensor() {
-    tensor_list_map.clear();
+    prev_tensor_list_map.clear();
+    tensor_list_map.swap(prev_tensor_list_map);
     tensor_list.clear();
     return true;
   }
 
+  void EmptyPrevTensor() { prev_tensor_list_map.clear(); }
+
   void set_iter_num(uint32_t iter_num) { this->iter_num = iter_num; }
 
  private:
   std::vector<std::shared_ptr<TensorData>> tensor_list;
   std::map<std::string, std::shared_ptr<TensorData>> tensor_list_map;
+  std::map<std::string, std::shared_ptr<TensorData>> prev_tensor_list_map;
   uint32_t iter_num;
 };
 }  // namespace mindspore
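The keep_prev path above relies on std::map's node-handle API (C++17): extract() detaches the previous step's entry without copying the stored shared_ptr, the key is rewritten with a ":prev" suffix, and the node is re-inserted into the current map. A self-contained sketch of that rekeying trick, with made-up tensor names and string values standing in for TensorData:

#include <iostream>
#include <map>
#include <string>
#include <utility>

int main() {
  // prev: tensors kept from the previous step; curr: tensors of the current step.
  std::map<std::string, std::string> prev{{"conv1.weight:0", "old value"}};
  std::map<std::string, std::string> curr;

  const std::string name = "conv1.weight:0";

  // What LoadNewTensor(tensor, /*keep_prev=*/true) does: move the old entry into
  // the current map under a ":prev"-suffixed key, then insert the new value.
  auto handle = prev.extract(name);
  if (!handle.empty()) {
    handle.key() = name + ":prev";
    curr.insert(std::move(handle));
  }
  curr[name] = "new value";

  for (const auto &kv : curr) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
  // conv1.weight:0 -> new value
  // conv1.weight:0:prev -> old value
  return 0;
}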
mindspore/ccsrc/device/ascend/ascend_device_address.cc
@@ -370,10 +370,10 @@ bool AscendDeviceAddress::DumpMemToFile(bool trans_flag, const std::string &file
 #ifdef ENABLE_DEBUGGER
 bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tensor_name, int execution_order,
                                         const std::string &host_fmt, const std::vector<int> &host_shape,
-                                        TypeId host_type, size_t slot, Debugger *debugger) const {
+                                        TypeId host_type, size_t slot, Debugger *debugger, bool keep_prev) const {
   bool ret = false;
-  DebugServices *debug_services = debugger->get_debug_services();
+  DebugServices *debug_services = debugger->debug_services();
   TensorLoader *tensor_loader = debug_services->get_tensor_loader();
   if (trans_flag) {
@@ -390,7 +390,7 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens
     tensor_data->SetExecutionOrder(execution_order);
     tensor_data->SetTensor(out_tensor);
     tensor_data->SetSlot(slot);
-    ret = tensor_loader->LoadNewTensor(tensor_data);
+    ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
   } else {
     mindspore::tensor::TensorPtr out_tensor = std::make_shared<tensor::Tensor>(type_id_, host_shape);
     size_t host_size = out_tensor->data().nbytes();
@@ -401,7 +401,7 @@ bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tens
     tensor_data->SetExecutionOrder(execution_order);
     tensor_data->SetTensor(out_tensor);
     tensor_data->SetSlot(slot);
-    ret = tensor_loader->LoadNewTensor(tensor_data);
+    ret = tensor_loader->LoadNewTensor(tensor_data, keep_prev);
     if (ret_rt_memcpy != RT_ERROR_NONE) {
       MS_LOG(ERROR) << "SyncDeviceToHost: rtMemcpy mem size[" << size_ << "] fail, ret[" << ret_rt_memcpy << "]";
     }
mindspore/ccsrc/device/ascend/ascend_device_address.h
@@ -46,7 +46,8 @@ class AscendDeviceAddress : public DeviceAddress {
 #endif
 #ifdef ENABLE_DEBUGGER
   bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt,
-                     const std::vector<int> &host_shape, TypeId host_type, size_t slot, Debugger *debugger) const;
+                     const std::vector<int> &host_shape, TypeId host_type, size_t slot, Debugger *debugger,
+                     bool keep_prev) const;
 #endif
  private:
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
@@ -322,7 +322,8 @@ void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) {
       (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
                            [](size_t inner_item) { return SizeToInt(inner_item); });
     }
-    auto ret = ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, j, debugger);
+    auto ret =
+      ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, j, debugger, false);
     if (!ret) {
       MS_LOG(ERROR) << "LoadMemToHost: flag:" << trans_flag << ", tensor_name:" << tensor_name
                     << ", host_format:" << format << ".!";
@@ -356,7 +357,8 @@ void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger)
       (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes),
                            [](size_t inner_item) { return SizeToInt(inner_item); });
     }
-    auto ret = ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, 0, debugger);
+    auto ret =
+      ascend_addr->LoadMemToHost(trans_flag, tensor_name, exec_order, format, int_shapes, type, 0, debugger, true);
     if (!ret) {
       MS_LOG(ERROR) << "LoadMemToHost Failed: flag:" << trans_flag << ", path:" << tensor_name
                     << ", host_format:" << format << ".!";
mindspore/ccsrc/session/ascend_session.cc
@@ -799,12 +799,13 @@ void AscendSession::LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph)
 #ifdef ENABLE_DEBUGGER
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  DebugServices *debug_services = debugger_->get_debug_services();
+  DebugServices *debug_services = debugger_->debug_services();
   TensorLoader *tensor_loader = debug_services->get_tensor_loader();
   tensor_loader->EmptyTensor();
   uint32_t iter_num = tensor_loader->GetIterNum();
   tensor_loader->set_iter_num(++iter_num);
   (void)runtime_instance->LoadData(kernel_graph.get(), debugger_.get());
+  tensor_loader->EmptyPrevTensor();
 #endif
   MS_LOG(INFO) << "Finish!";
 }
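Putting the pieces together, the per-step flow in LoadTensor() above is: EmptyTensor() now swaps the current map into prev_tensor_list_map instead of discarding it, LoadData() reloads tensors (parameters with keep_prev = true via LoadParameters, kernel outputs with keep_prev = false via LoadOutput), and EmptyPrevTensor() drops whatever was not re-tagged as ":prev". A minimal self-contained sketch of that lifecycle, with a toy loader and made-up tensor names standing in for TensorLoader and TensorData:

#include <iostream>
#include <map>
#include <string>
#include <utility>

// Toy stand-in for TensorLoader, reproducing only the prev-step bookkeeping.
class ToyLoader {
 public:
  void LoadNewTensor(const std::string &name, const std::string &value, bool keep_prev) {
    if (keep_prev) {
      auto handle = prev_.extract(name);
      if (!handle.empty()) {
        handle.key() = name + ":prev";
        curr_.insert(std::move(handle));
      }
    }
    curr_[name] = value;
  }
  // Start of a step: the current map becomes the "previous step" map.
  void EmptyTensor() {
    prev_.clear();
    curr_.swap(prev_);
  }
  // End of a step: drop prev entries that were not re-tagged as ":prev".
  void EmptyPrevTensor() { prev_.clear(); }
  void Dump() const {
    for (const auto &kv : curr_) std::cout << kv.first << " -> " << kv.second << "\n";
  }

 private:
  std::map<std::string, std::string> curr_;
  std::map<std::string, std::string> prev_;
};

int main() {
  ToyLoader loader;
  // step 1
  loader.EmptyTensor();
  loader.LoadNewTensor("fc.weight:0", "step1 value", true);    // parameter
  loader.LoadNewTensor("conv1.output:0", "step1 out", false);  // kernel output
  loader.EmptyPrevTensor();
  // step 2: the parameter's step-1 value stays visible as "fc.weight:0:prev"
  loader.EmptyTensor();
  loader.LoadNewTensor("fc.weight:0", "step2 value", true);
  loader.LoadNewTensor("conv1.output:0", "step2 out", false);
  loader.EmptyPrevTensor();
  loader.Dump();
  return 0;
}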