Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
578cdb42
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
578cdb42
编写于
5月 26, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 26, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1454 Allow TaskGroup::join_all() to block for GNN
Merge pull request !1454 from JesseKLee/zirui
上级
0a8ef2fe
a8fa8475
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
36 additions
and
25 deletions
+36
-25
mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc
mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc
+8
-5
mindspore/ccsrc/dataset/util/task.cc
mindspore/ccsrc/dataset/util/task.cc
+17
-10
mindspore/ccsrc/dataset/util/task.h
mindspore/ccsrc/dataset/util/task.h
+3
-2
mindspore/ccsrc/dataset/util/task_manager.cc
mindspore/ccsrc/dataset/util/task_manager.cc
+3
-3
mindspore/ccsrc/dataset/util/task_manager.h
mindspore/ccsrc/dataset/util/task_manager.h
+1
-1
tests/ut/cpp/dataset/interrupt_test.cc
tests/ut/cpp/dataset/interrupt_test.cc
+3
-3
tests/ut/cpp/dataset/task_manager_test.cc
tests/ut/cpp/dataset/task_manager_test.cc
+1
-1
未找到文件。
mindspore/ccsrc/dataset/engine/gnn/graph_loader.cc
浏览文件 @
578cdb42
...
...
@@ -22,6 +22,7 @@
#include "mindspore/ccsrc/mindrecord/include/shard_error.h"
#include "dataset/engine/gnn/local_edge.h"
#include "dataset/engine/gnn/local_node.h"
#include "dataset/util/task_manager.h"
using
ShardTuple
=
std
::
vector
<
std
::
tuple
<
std
::
vector
<
uint8_t
>
,
mindspore
::
mindrecord
::
json
>>
;
...
...
@@ -80,7 +81,7 @@ Status GraphLoader::InitAndLoad() {
n_feature_maps_
.
resize
(
num_workers_
);
e_feature_maps_
.
resize
(
num_workers_
);
default_feature_maps_
.
resize
(
num_workers_
);
std
::
vector
<
std
::
future
<
Status
>>
r_codes
(
num_workers_
)
;
TaskGroup
vg
;
shard_reader_
=
std
::
make_unique
<
ShardReader
>
();
CHECK_FAIL_RETURN_UNEXPECTED
(
shard_reader_
->
Open
({
mr_path_
},
true
,
num_workers_
)
==
MSRStatus
::
SUCCESS
,
...
...
@@ -97,12 +98,11 @@ Status GraphLoader::InitAndLoad() {
// launching worker threads
for
(
int
wkr_id
=
0
;
wkr_id
<
num_workers_
;
++
wkr_id
)
{
r_codes
[
wkr_id
]
=
std
::
async
(
std
::
launch
::
async
,
&
GraphLoader
::
WorkerEntry
,
this
,
wkr_id
);
RETURN_IF_NOT_OK
(
vg
.
CreateAsyncTask
(
"GraphLoader"
,
std
::
bind
(
&
GraphLoader
::
WorkerEntry
,
this
,
wkr_id
))
);
}
// wait for threads to finish and check its return code
for
(
int
wkr_id
=
0
;
wkr_id
<
num_workers_
;
++
wkr_id
)
{
RETURN_IF_NOT_OK
(
r_codes
[
wkr_id
].
get
());
}
vg
.
join_all
(
Task
::
WaitFlag
::
kBlocking
);
RETURN_IF_NOT_OK
(
vg
.
GetTaskErrorIfAny
());
return
Status
::
OK
();
}
...
...
@@ -201,8 +201,11 @@ Status GraphLoader::LoadFeatureIndex(const std::string &key, const std::vector<u
}
Status
GraphLoader
::
WorkerEntry
(
int32_t
worker_id
)
{
// Handshake
TaskManager
::
FindMe
()
->
Post
();
ShardTuple
rows
=
shard_reader_
->
GetNextById
(
row_id_
++
,
worker_id
);
while
(
rows
.
empty
()
==
false
)
{
RETURN_IF_INTERRUPTED
();
for
(
const
auto
&
tupled_row
:
rows
)
{
std
::
vector
<
uint8_t
>
col_blob
=
std
::
get
<
0
>
(
tupled_row
);
mindrecord
::
json
col_jsn
=
std
::
get
<
1
>
(
tupled_row
);
...
...
mindspore/ccsrc/dataset/util/task.cc
浏览文件 @
578cdb42
...
...
@@ -108,20 +108,27 @@ Status Task::Run() {
return
rc
;
}
Status
Task
::
Join
()
{
Status
Task
::
Join
(
WaitFlag
blocking
)
{
if
(
running_
)
{
RETURN_UNEXPECTED_IF_NULL
(
MyTaskGroup
());
auto
interrupt_svc
=
MyTaskGroup
()
->
GetIntrpService
();
try
{
// There is a race condition in the global resource tracking such that a thread can miss the
// interrupt and become blocked on a condition variable forever. As a result, calling
// join() will not come back. We need some timeout version of join such that if the thread
// doesn't come back in a reasonable amount of time, we will send the interrupt again.
while
(
thrd_
.
wait_for
(
std
::
chrono
::
seconds
(
1
))
!=
std
::
future_status
::
ready
)
{
// We can't tell which conditional_variable this thread is waiting on. So we may need
// to interrupt everything one more time.
MS_LOG
(
INFO
)
<<
"Some threads not responding. Interrupt again"
;
interrupt_svc
->
InterruptAll
();
if
(
blocking
==
WaitFlag
::
kBlocking
)
{
// If we are asked to wait, then wait
thrd_
.
get
();
}
else
if
(
blocking
==
WaitFlag
::
kNonBlocking
)
{
// There is a race condition in the global resource tracking such that a thread can miss the
// interrupt and become blocked on a condition variable forever. As a result, calling
// join() will not come back. We need some timeout version of join such that if the thread
// doesn't come back in a reasonable amount of time, we will send the interrupt again.
while
(
thrd_
.
wait_for
(
std
::
chrono
::
seconds
(
1
))
!=
std
::
future_status
::
ready
)
{
// We can't tell which conditional_variable this thread is waiting on. So we may need
// to interrupt everything one more time.
MS_LOG
(
INFO
)
<<
"Some threads not responding. Interrupt again"
;
interrupt_svc
->
InterruptAll
();
}
}
else
{
RETURN_STATUS_UNEXPECTED
(
"Unknown WaitFlag"
);
}
std
::
stringstream
ss
;
ss
<<
get_id
();
...
...
mindspore/ccsrc/dataset/util/task.h
浏览文件 @
578cdb42
...
...
@@ -42,9 +42,10 @@ class TaskManager;
class
Task
:
public
IntrpResource
{
public:
friend
class
TaskManager
;
friend
class
TaskGroup
;
enum
class
WaitFlag
:
int
{
kBlocking
,
kNonBlocking
};
Task
(
const
std
::
string
&
myName
,
const
std
::
function
<
Status
()
>
&
f
);
// Future objects are not copyable.
...
...
@@ -74,7 +75,7 @@ class Task : public IntrpResource {
// Run the task
Status
Run
();
Status
Join
();
Status
Join
(
WaitFlag
wf
=
WaitFlag
::
kBlocking
);
bool
Running
()
const
{
return
running_
;
}
...
...
mindspore/ccsrc/dataset/util/task_manager.cc
浏览文件 @
578cdb42
...
...
@@ -278,12 +278,12 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio
void
TaskGroup
::
interrupt_all
()
noexcept
{
intrp_svc_
->
InterruptAll
();
}
Status
TaskGroup
::
join_all
()
{
Status
TaskGroup
::
join_all
(
Task
::
WaitFlag
wf
)
{
Status
rc
;
Status
rc2
;
SharedLock
lck
(
&
rw_lock_
);
for
(
Task
&
tk
:
grp_list_
)
{
rc
=
tk
.
Join
();
rc
=
tk
.
Join
(
wf
);
if
(
rc
.
IsError
())
{
rc2
=
rc
;
}
...
...
@@ -294,7 +294,7 @@ Status TaskGroup::join_all() {
Status
TaskGroup
::
DoServiceStop
()
{
intrp_svc_
->
ServiceStop
();
interrupt_all
();
return
(
join_all
());
return
(
join_all
(
Task
::
WaitFlag
::
kNonBlocking
));
}
TaskGroup
::
TaskGroup
()
:
grp_list_
(
&
Task
::
group
),
intrp_svc_
(
nullptr
)
{
...
...
mindspore/ccsrc/dataset/util/task_manager.h
浏览文件 @
578cdb42
...
...
@@ -122,7 +122,7 @@ class TaskGroup : public Service {
void
interrupt_all
()
noexcept
;
Status
join_all
();
Status
join_all
(
Task
::
WaitFlag
wf
=
Task
::
WaitFlag
::
kBlocking
);
int
size
()
const
noexcept
{
return
grp_list_
.
count
;
}
...
...
tests/ut/cpp/dataset/interrupt_test.cc
浏览文件 @
578cdb42
...
...
@@ -48,7 +48,7 @@ TEST_F(MindDataTestIntrpService, Test1) {
return
rc
;
});
vg_
.
GetIntrpService
()
->
InterruptAll
();
vg_
.
join_all
();
vg_
.
join_all
(
Task
::
WaitFlag
::
kNonBlocking
);
}
TEST_F
(
MindDataTestIntrpService
,
Test2
)
{
...
...
@@ -64,5 +64,5 @@ TEST_F(MindDataTestIntrpService, Test2) {
return
rc
;
});
vg_
.
GetIntrpService
()
->
InterruptAll
();
vg_
.
join_all
();
}
vg_
.
join_all
(
Task
::
WaitFlag
::
kNonBlocking
);
}
\ No newline at end of file
tests/ut/cpp/dataset/task_manager_test.cc
浏览文件 @
578cdb42
...
...
@@ -80,5 +80,5 @@ TEST_F(MindDataTestTaskManager, Test2) {
vg
.
interrupt_all
();
EXPECT_TRUE
(
rc
.
IsOk
());
// Now we test the async Join
ASSERT_TRUE
(
vg
.
join_all
().
IsOk
());
ASSERT_TRUE
(
vg
.
join_all
(
Task
::
WaitFlag
::
kNonBlocking
).
IsOk
());
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录