Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e8980ed2
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e8980ed2
编写于
5月 25, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
5月 25, 2020
浏览文件
操作
浏览文件
下载
差异文件
!1406 Simplify CondVar class
Merge pull request !1406 from JesseKLee/CondVar
上级
59c67946
641112a4
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
88 addition
and
82 deletion
+88
-82
mindspore/ccsrc/dataset/util/cond_var.cc
mindspore/ccsrc/dataset/util/cond_var.cc
+20
-22
mindspore/ccsrc/dataset/util/cond_var.h
mindspore/ccsrc/dataset/util/cond_var.h
+1
-1
mindspore/ccsrc/dataset/util/intrp_resource.h
mindspore/ccsrc/dataset/util/intrp_resource.h
+8
-4
mindspore/ccsrc/dataset/util/intrp_service.cc
mindspore/ccsrc/dataset/util/intrp_service.cc
+6
-13
mindspore/ccsrc/dataset/util/intrp_service.h
mindspore/ccsrc/dataset/util/intrp_service.h
+1
-1
mindspore/ccsrc/dataset/util/queue.h
mindspore/ccsrc/dataset/util/queue.h
+4
-4
mindspore/ccsrc/dataset/util/services.h
mindspore/ccsrc/dataset/util/services.h
+6
-0
mindspore/ccsrc/dataset/util/task.cc
mindspore/ccsrc/dataset/util/task.cc
+9
-1
mindspore/ccsrc/dataset/util/task.h
mindspore/ccsrc/dataset/util/task.h
+4
-3
mindspore/ccsrc/dataset/util/task_manager.cc
mindspore/ccsrc/dataset/util/task_manager.cc
+5
-7
mindspore/ccsrc/dataset/util/task_manager.h
mindspore/ccsrc/dataset/util/task_manager.h
+16
-26
tests/ut/cpp/dataset/connector_test.cc
tests/ut/cpp/dataset/connector_test.cc
+8
-0
未找到文件。
mindspore/ccsrc/dataset/util/cond_var.cc
浏览文件 @
e8980ed2
...
...
@@ -14,35 +14,34 @@
* limitations under the License.
*/
#include "dataset/util/cond_var.h"
#include <exception>
#include <utility>
#include "dataset/util/services.h"
#include "dataset/util/task_manager.h"
namespace
mindspore
{
namespace
dataset
{
CondVar
::
CondVar
()
:
svc_
(
nullptr
),
my_name_
(
std
::
move
(
Services
::
GetUniqueID
()
))
{}
CondVar
::
CondVar
()
:
svc_
(
nullptr
),
my_name_
(
Services
::
GetUniqueID
(
))
{}
Status
CondVar
::
Wait
(
std
::
unique_lock
<
std
::
mutex
>
*
lck
,
const
std
::
function
<
bool
()
>
&
pred
)
{
// Append an additional condition on top of the given predicate.
// We will also bail out if this cv got interrupted.
auto
f
=
[
this
,
&
pred
]()
->
bool
{
return
(
pred
()
||
(
CurState
()
==
State
::
kInterrupted
));
};
// If we have interrupt service, just wait on the cv unconditionally.
// Otherwise fall back to the old way of checking interrupt.
if
(
svc_
)
{
cv_
.
wait
(
*
lck
,
f
);
if
(
CurState
()
==
State
::
kInterrupted
)
{
Task
*
my_task
=
TaskManager
::
FindMe
();
if
(
my_task
->
IsMasterThread
()
&&
my_task
->
CaughtSevereException
())
{
return
TaskManager
::
GetMasterThreadRc
()
;
}
else
{
return
Status
(
StatusCode
::
kInterrupted
);
try
{
if
(
svc_
!=
nullptr
)
{
// If this cv registers with a global resource tracking, then wait unconditionally.
auto
f
=
[
this
,
&
pred
]()
->
bool
{
return
(
pred
()
||
this
->
Interrupted
());
};
cv_
.
wait
(
*
lck
,
f
);
// If we are interrupted, override the return value if this is the master thread.
// Master thread is being interrupted mostly because of some thread is reporting error.
RETURN_IF_NOT_OK
(
Task
::
OverrideInterruptRc
(
this
->
GetInterruptStatus
()));
}
else
{
// Otherwise we wake up once a while to check for interrupt (for this thread).
auto
f
=
[
&
pred
]()
->
bool
{
return
(
pred
()
||
this_thread
::
is_interrupted
());
}
;
while
(
!
f
())
{
(
void
)
cv_
.
wait_for
(
*
lck
,
std
::
chrono
::
milliseconds
(
1
)
);
}
RETURN_IF_INTERRUPTED
();
}
}
else
{
RETURN_IF_NOT_OK
(
interruptible_wait
(
&
cv_
,
lck
,
pred
));
if
(
CurState
()
==
State
::
kInterrupted
)
{
return
Status
(
StatusCode
::
kInterrupted
);
}
}
catch
(
const
std
::
exception
&
e
)
{
RETURN_STATUS_UNEXPECTED
(
e
.
what
());
}
return
Status
::
OK
();
}
...
...
@@ -66,10 +65,9 @@ Status CondVar::Register(std::shared_ptr<IntrpService> svc) {
return
rc
;
}
Status
CondVar
::
Interrupt
()
{
RETURN_IF_NOT_OK
(
IntrpResource
::
Interrupt
()
);
void
CondVar
::
Interrupt
()
{
IntrpResource
::
Interrupt
(
);
cv_
.
notify_all
();
return
Status
::
OK
();
}
std
::
string
CondVar
::
my_name
()
const
{
return
my_name_
;
}
...
...
mindspore/ccsrc/dataset/util/cond_var.h
浏览文件 @
e8980ed2
...
...
@@ -35,7 +35,7 @@ class CondVar : public IntrpResource {
Status
Wait
(
std
::
unique_lock
<
std
::
mutex
>
*
lck
,
const
std
::
function
<
bool
()
>
&
pred
);
Status
Interrupt
()
override
;
void
Interrupt
()
override
;
void
NotifyOne
()
noexcept
;
...
...
mindspore/ccsrc/dataset/util/intrp_resource.h
浏览文件 @
e8980ed2
...
...
@@ -29,10 +29,7 @@ class IntrpResource {
virtual
~
IntrpResource
()
=
default
;
virtual
Status
Interrupt
()
{
st_
=
State
::
kInterrupted
;
return
Status
::
OK
();
}
virtual
void
Interrupt
()
{
st_
=
State
::
kInterrupted
;
}
virtual
void
ResetIntrpState
()
{
st_
=
State
::
kRunning
;
}
...
...
@@ -40,6 +37,13 @@ class IntrpResource {
bool
Interrupted
()
const
{
return
CurState
()
==
State
::
kInterrupted
;
}
virtual
Status
GetInterruptStatus
()
const
{
if
(
Interrupted
())
{
return
Status
(
StatusCode
::
kInterrupted
);
}
return
Status
::
OK
();
}
protected:
std
::
atomic
<
State
>
st_
;
};
...
...
mindspore/ccsrc/dataset/util/intrp_service.cc
浏览文件 @
e8980ed2
...
...
@@ -27,7 +27,7 @@ IntrpService::~IntrpService() noexcept {
MS_LOG
(
INFO
)
<<
"Number of registered resources is "
<<
high_water_mark_
<<
"."
;
if
(
!
all_intrp_resources_
.
empty
())
{
try
{
(
void
)
InterruptAll
();
InterruptAll
();
}
catch
(
const
std
::
exception
&
e
)
{
// Ignore all error as we can't throw in the destructor.
}
...
...
@@ -64,11 +64,9 @@ Status IntrpService::Deregister(const std::string &name) noexcept {
std
::
ostringstream
ss
;
ss
<<
this_thread
::
get_id
();
MS_LOG
(
DEBUG
)
<<
"De-register resource with name "
<<
name
<<
". Thread ID is "
<<
ss
.
str
()
<<
"."
;
auto
it
=
all_intrp_resources_
.
find
(
name
);
if
(
it
!=
all_intrp_resources_
.
end
())
{
(
void
)
all_intrp_resources_
.
erase
(
it
);
}
else
{
MS_LOG
(
DEBUG
)
<<
"Key "
<<
name
<<
" not found."
;
auto
n
=
all_intrp_resources_
.
erase
(
name
);
if
(
n
==
0
)
{
MS_LOG
(
INFO
)
<<
"Key "
<<
name
<<
" not found."
;
}
}
catch
(
std
::
exception
&
e
)
{
RETURN_STATUS_UNEXPECTED
(
e
.
what
());
...
...
@@ -76,21 +74,16 @@ Status IntrpService::Deregister(const std::string &name) noexcept {
return
Status
::
OK
();
}
Status
IntrpService
::
InterruptAll
()
noexcept
{
void
IntrpService
::
InterruptAll
()
noexcept
{
std
::
lock_guard
<
std
::
mutex
>
lck
(
mutex_
);
Status
rc
;
for
(
auto
const
&
it
:
all_intrp_resources_
)
{
std
::
string
kName
=
it
.
first
;
try
{
Status
rc2
=
it
.
second
->
Interrupt
();
if
(
rc2
.
IsError
())
{
rc
=
rc2
;
}
it
.
second
->
Interrupt
();
}
catch
(
const
std
::
exception
&
e
)
{
// continue the clean up.
}
}
return
rc
;
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/util/intrp_service.h
浏览文件 @
e8980ed2
...
...
@@ -47,7 +47,7 @@ class IntrpService : public Service {
Status
Deregister
(
const
std
::
string
&
name
)
noexcept
;
Status
InterruptAll
()
noexcept
;
void
InterruptAll
()
noexcept
;
Status
DoServiceStart
()
override
{
return
Status
::
OK
();
}
...
...
mindspore/ccsrc/dataset/util/queue.h
浏览文件 @
e8980ed2
...
...
@@ -110,7 +110,7 @@ class Queue {
empty_cv_
.
NotifyAll
();
_lock
.
unlock
();
}
else
{
(
void
)
empty_cv_
.
Interrupt
();
empty_cv_
.
Interrupt
();
}
return
rc
;
}
...
...
@@ -125,7 +125,7 @@ class Queue {
empty_cv_
.
NotifyAll
();
_lock
.
unlock
();
}
else
{
(
void
)
empty_cv_
.
Interrupt
();
empty_cv_
.
Interrupt
();
}
return
rc
;
}
...
...
@@ -141,7 +141,7 @@ class Queue {
empty_cv_
.
NotifyAll
();
_lock
.
unlock
();
}
else
{
(
void
)
empty_cv_
.
Interrupt
();
empty_cv_
.
Interrupt
();
}
return
rc
;
}
...
...
@@ -160,7 +160,7 @@ class Queue {
full_cv_
.
NotifyAll
();
_lock
.
unlock
();
}
else
{
(
void
)
full_cv_
.
Interrupt
();
full_cv_
.
Interrupt
();
}
return
rc
;
}
...
...
mindspore/ccsrc/dataset/util/services.h
浏览文件 @
e8980ed2
...
...
@@ -20,6 +20,7 @@
#include <mutex>
#include <string>
#include "dataset/util/memory_pool.h"
#include "dataset/util/allocator.h"
#include "dataset/util/service.h"
#define UNIQUEID_LEN 36
...
...
@@ -72,6 +73,11 @@ class Services {
static
std
::
string
GetUniqueID
();
template
<
typename
T
>
static
Allocator
<
T
>
GetAllocator
()
{
return
Allocator
<
T
>
(
Services
::
GetInstance
().
GetServiceMemPool
());
}
private:
static
std
::
once_flag
init_instance_flag_
;
static
std
::
unique_ptr
<
Services
>
instance_
;
...
...
mindspore/ccsrc/dataset/util/task.cc
浏览文件 @
e8980ed2
...
...
@@ -72,7 +72,7 @@ void Task::ShutdownGroup() { // Wake up watch dog and shutdown the engine.
}
}
Status
Task
::
GetTaskErrorIfAny
()
{
Status
Task
::
GetTaskErrorIfAny
()
const
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
mux_
);
if
(
caught_severe_exception_
)
{
return
rc_
;
...
...
@@ -141,5 +141,13 @@ TaskGroup *Task::MyTaskGroup() { return task_group_; }
void
Task
::
set_task_group
(
TaskGroup
*
vg
)
{
task_group_
=
vg
;
}
Task
::~
Task
()
{
task_group_
=
nullptr
;
}
Status
Task
::
OverrideInterruptRc
(
const
Status
&
rc
)
{
if
(
rc
.
IsInterrupted
()
&&
this_thread
::
is_master_thread
())
{
// If we are interrupted, override the return value if this is the master thread.
// Master thread is being interrupted mostly because of some thread is reporting error.
return
TaskManager
::
GetMasterThreadRc
();
}
return
rc
;
}
}
// namespace dataset
}
// namespace mindspore
mindspore/ccsrc/dataset/util/task.h
浏览文件 @
e8980ed2
...
...
@@ -60,7 +60,7 @@ class Task : public IntrpResource {
Task
&
operator
=
(
Task
&&
)
=
delete
;
Status
GetTaskErrorIfAny
();
Status
GetTaskErrorIfAny
()
const
;
void
ChangeName
(
const
std
::
string
&
newName
)
{
my_name_
=
newName
;
}
...
...
@@ -95,10 +95,10 @@ class Task : public IntrpResource {
Status
Wait
()
{
return
(
wp_
.
Wait
());
}
void
set_task_group
(
TaskGroup
*
vg
);
static
Status
OverrideInterruptRc
(
const
Status
&
rc
);
private:
std
::
mutex
mux_
;
mutable
std
::
mutex
mux_
;
std
::
string
my_name_
;
Status
rc_
;
WaitPost
wp_
;
...
...
@@ -115,6 +115,7 @@ class Task : public IntrpResource {
void
ShutdownGroup
();
TaskGroup
*
MyTaskGroup
();
void
set_task_group
(
TaskGroup
*
vg
);
};
extern
thread_local
Task
*
gMyTask
;
...
...
mindspore/ccsrc/dataset/util/task_manager.cc
浏览文件 @
e8980ed2
...
...
@@ -84,7 +84,7 @@ void TaskManager::interrupt_all() noexcept {
svc
->
InterruptAll
();
}
}
(
void
)
master_
->
Interrupt
();
master_
->
Interrupt
();
}
Task
*
TaskManager
::
FindMe
()
{
return
gMyTask
;
}
...
...
@@ -94,8 +94,7 @@ TaskManager::TaskManager() try : global_interrupt_(0),
free_lst_
(
&
Task
::
free
),
watchdog_grp_
(
nullptr
),
watchdog_
(
nullptr
)
{
std
::
shared_ptr
<
MemoryPool
>
mp
=
Services
::
GetInstance
().
GetServiceMemPool
();
Allocator
<
Task
>
alloc
(
mp
);
auto
alloc
=
Services
::
GetAllocator
<
Task
>
();
// Create a dummy Task for the master thread (this thread)
master_
=
std
::
allocate_shared
<
Task
>
(
alloc
,
"master"
,
[]()
->
Status
{
return
Status
::
OK
();
});
master_
->
id_
=
this_thread
::
get_id
();
...
...
@@ -185,7 +184,7 @@ void TaskManager::InterruptMaster(const Status &rc) {
TaskManager
&
tm
=
TaskManager
::
GetInstance
();
std
::
shared_ptr
<
Task
>
master
=
tm
.
master_
;
std
::
lock_guard
<
std
::
mutex
>
lck
(
master
->
mux_
);
(
void
)
master
->
Interrupt
();
master
->
Interrupt
();
if
(
rc
.
IsError
()
&&
master
->
rc_
.
IsOk
())
{
master
->
rc_
=
rc
;
master
->
caught_severe_exception_
=
true
;
...
...
@@ -277,7 +276,7 @@ Status TaskGroup::CreateAsyncTask(const std::string &my_name, const std::functio
return
Status
::
OK
();
}
void
TaskGroup
::
interrupt_all
()
noexcept
{
(
void
)
intrp_svc_
->
InterruptAll
();
}
void
TaskGroup
::
interrupt_all
()
noexcept
{
intrp_svc_
->
InterruptAll
();
}
Status
TaskGroup
::
join_all
()
{
Status
rc
;
...
...
@@ -299,8 +298,7 @@ Status TaskGroup::DoServiceStop() {
}
TaskGroup
::
TaskGroup
()
:
grp_list_
(
&
Task
::
group
),
intrp_svc_
(
nullptr
)
{
std
::
shared_ptr
<
MemoryPool
>
mp
=
Services
::
GetInstance
().
GetServiceMemPool
();
Allocator
<
IntrpService
>
alloc
(
mp
);
auto
alloc
=
Services
::
GetAllocator
<
IntrpService
>
();
intrp_svc_
=
std
::
allocate_shared
<
IntrpService
>
(
alloc
);
(
void
)
Service
::
ServiceStart
();
}
...
...
mindspore/ccsrc/dataset/util/task_manager.h
浏览文件 @
e8980ed2
...
...
@@ -154,37 +154,27 @@ inline bool is_interrupted() {
return
true
;
}
Task
*
my_task
=
TaskManager
::
FindMe
();
return
(
my_task
!=
nullptr
)
?
my_task
->
Interrupted
()
:
false
;
return
my_task
->
Interrupted
();
}
inline
bool
is_master_thread
()
{
Task
*
my_task
=
TaskManager
::
FindMe
();
return
my_task
->
IsMasterThread
();
}
inline
Status
GetInterruptStatus
()
{
Task
*
my_task
=
TaskManager
::
FindMe
();
return
my_task
->
GetInterruptStatus
();
}
}
// namespace this_thread
#define RETURN_IF_INTERRUPTED() \
do { \
if (mindspore::dataset::this_thread::is_interrupted()) { \
Task *myTask = TaskManager::FindMe(); \
if (myTask->IsMasterThread() && myTask->CaughtSevereException()) { \
return TaskManager::GetMasterThreadRc(); \
} else { \
return Status(StatusCode::kInterrupted); \
} \
} \
#define RETURN_IF_INTERRUPTED() \
do { \
if (mindspore::dataset::this_thread::is_interrupted()) { \
return Task::OverrideInterruptRc(this_thread::GetInterruptStatus()); \
} \
} while (false)
inline
Status
interruptible_wait
(
std
::
condition_variable
*
cv
,
std
::
unique_lock
<
std
::
mutex
>
*
lk
,
const
std
::
function
<
bool
()
>
&
pred
)
noexcept
{
if
(
!
pred
())
{
do
{
RETURN_IF_INTERRUPTED
();
try
{
(
void
)
cv
->
wait_for
(
*
lk
,
std
::
chrono
::
milliseconds
(
1
));
}
catch
(
std
::
exception
&
e
)
{
// Anything thrown by wait_for is considered system error.
RETURN_STATUS_UNEXPECTED
(
e
.
what
());
}
}
while
(
!
pred
());
}
return
Status
::
OK
();
}
}
// namespace dataset
}
// namespace mindspore
...
...
tests/ut/cpp/dataset/connector_test.cc
浏览文件 @
e8980ed2
...
...
@@ -139,6 +139,9 @@ Status MindDataTestConnector::Run_test_0() {
10
);
// capacity of each queue
DS_ASSERT
(
my_conn
!=
nullptr
);
rc
=
my_conn
->
Register
(
tg_
.
get
());
RETURN_IF_NOT_OK
(
rc
);
// Spawn a thread to read input_ vector and put it in my_conn
rc
=
tg_
->
CreateAsyncTask
(
"Worker Push"
,
std
::
bind
(
&
MindDataTestConnector
::
FirstWorkerPush
,
...
...
@@ -184,6 +187,11 @@ Status MindDataTestConnector::Run_test_1() {
l3_threads
,
conn2_qcap
);
rc
=
conn1
->
Register
(
tg_
.
get
());
RETURN_IF_NOT_OK
(
rc
);
rc
=
conn2
->
Register
(
tg_
.
get
());
RETURN_IF_NOT_OK
(
rc
);
// Instantiating the threads in the first layer
for
(
int
i
=
0
;
i
<
l1_threads
;
i
++
)
{
rc
=
tg_
->
CreateAsyncTask
(
"First Worker Push"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录