Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
cc2ee9fc
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
cc2ee9fc
编写于
1月 30, 2018
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
naive readable register manager
上级
6f0754f6
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
135 addition
and
68 deletion
+135
-68
oneflow/core/actor/actor.h
oneflow/core/actor/actor.h
+2
-0
oneflow/core/actor/boxing_actor.cpp
oneflow/core/actor/boxing_actor.cpp
+33
-42
oneflow/core/actor/boxing_actor.h
oneflow/core/actor/boxing_actor.h
+2
-3
oneflow/core/actor/model_update_compute_actor.cpp
oneflow/core/actor/model_update_compute_actor.cpp
+9
-21
oneflow/core/actor/model_update_compute_actor.h
oneflow/core/actor/model_update_compute_actor.h
+2
-2
oneflow/core/actor/naive_readable_register_manager.cpp
oneflow/core/actor/naive_readable_register_manager.cpp
+53
-0
oneflow/core/actor/naive_readable_register_manager.h
oneflow/core/actor/naive_readable_register_manager.h
+34
-0
未找到文件。
oneflow/core/actor/actor.h
浏览文件 @
cc2ee9fc
...
...
@@ -36,6 +36,8 @@ class Actor {
int64_t
actor_id
()
const
{
return
actor_id_
;
}
protected:
friend
class
NaiveReadableRegstMgr
;
struct
ExecKernel
{
std
::
unique_ptr
<
const
Kernel
>
kernel
;
HashMap
<
std
::
string
,
int64_t
>
bn_in_op2regst_desc_id
;
...
...
oneflow/core/actor/boxing_actor.cpp
浏览文件 @
cc2ee9fc
...
...
@@ -4,11 +4,8 @@
namespace
oneflow
{
void
BoxingActor
::
VirtualActorInit
(
const
TaskProto
&
task_proto
)
{
for
(
const
auto
&
pair
:
task_proto
.
consumed_regst_desc_id
())
{
readable_regst_
[
pair
.
second
]
=
{};
}
readable_regst_mgr_
.
Init
(
task_proto
);
previous_pid_cid_
=
new
HashMap
<
int64_t
,
std
::
pair
<
int64_t
,
int32_t
>>
;
readable_regst_cnt_
=
0
;
col_id_order_
=
ColIdOrder
::
kUnCertain
;
is_eord_
=
false
;
OF_SET_MSG_HANDLER
(
&
BoxingActor
::
HandlerNormal
);
...
...
@@ -48,9 +45,7 @@ int BoxingActor::HandlerNormal(const ActorMsg& msg) {
&&
col_id_order_
==
ColIdOrder
::
kUnCertain
)
{
TrySetColIdOrder
(
msg
.
regst
());
}
std
::
queue
<
Regst
*>&
rq
=
readable_regst_
.
at
(
msg
.
regst
()
->
regst_desc_id
());
if
(
rq
.
empty
())
{
readable_regst_cnt_
+=
1
;
}
rq
.
push
(
msg
.
regst
());
readable_regst_mgr_
.
Push
(
msg
.
regst
());
}
ActUntilFail
();
}
else
{
...
...
@@ -60,57 +55,53 @@ int BoxingActor::HandlerNormal(const ActorMsg& msg) {
}
void
BoxingActor
::
Act
()
{
int64_t
piece_id
=
readable_regst_
.
begin
()
->
second
.
front
()
->
piece_id
();
AsyncLaunchKernel
(
GenDefaultKernelCtx
(),
[
this
](
int64_t
regst_desc_id
)
->
Regst
*
{
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
return
readable_regst_
.
at
(
regst_desc_id
).
front
(
);
}
else
{
return
regst
;
}
});
int64_t
piece_id
=
readable_regst_
mgr_
.
GetFirstCurReadable
()
->
piece_id
();
AsyncLaunchKernel
(
GenDefaultKernelCtx
(),
[
this
](
int64_t
regst_desc_id
)
->
Regst
*
{
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
return
readable_regst_mgr_
.
GetCurReadable
(
regst_desc_id
);
}
else
{
return
regst
;
}
});
AsyncSendRegstMsgToConsumer
([
&
](
Regst
*
regst
)
{
regst
->
set_piece_id
(
piece_id
);
return
regst
->
col_id
()
<=
regst
->
max_col_id
();
});
int32_t
cur_max_cid
=
0
;
int32_t
cur_max_maxcid
=
0
;
for
(
const
auto
&
pair
:
readable_regst_
)
{
cur_max_cid
=
std
::
max
(
cur_max_cid
,
pair
.
second
.
front
()
->
col_id
());
cur_max_maxcid
=
std
::
max
(
cur_max_maxcid
,
pair
.
second
.
front
()
->
max_col_id
());
}
for
(
auto
&
pair
:
readable_regst_
)
{
if
(
col_id_order_
==
ColIdOrder
::
kAscending
)
{
if
(
pair
.
second
.
front
()
->
IsMaxCol
()
&&
cur_max_cid
<
cur_max_maxcid
)
{
continue
;
}
}
else
if
(
col_id_order_
==
ColIdOrder
::
kDescending
)
{
if
(
pair
.
second
.
front
()
->
col_id
()
<
cur_max_cid
)
{
continue
;
}
}
else
{
// do nothing
}
AsyncSendRegstMsgToProducer
(
pair
.
second
.
front
());
pair
.
second
.
pop
();
if
(
pair
.
second
.
empty
())
{
readable_regst_cnt_
-=
1
;
}
}
readable_regst_mgr_
.
ForEachCurReadableRegst
([
&
](
Regst
*
regst
)
{
cur_max_cid
=
std
::
max
(
cur_max_cid
,
regst
->
col_id
());
cur_max_maxcid
=
std
::
max
(
cur_max_maxcid
,
regst
->
max_col_id
());
});
readable_regst_mgr_
.
ReturnToProducerAndPopCurReadable
(
this
,
[
&
](
Regst
*
regst
)
{
if
(
col_id_order_
==
ColIdOrder
::
kAscending
)
{
if
(
regst
->
IsMaxCol
()
&&
cur_max_cid
<
cur_max_maxcid
)
{
return
false
;
}
}
else
if
(
col_id_order_
==
ColIdOrder
::
kDescending
)
{
if
(
regst
->
col_id
()
<
cur_max_cid
)
{
return
false
;
}
}
else
{
// do nothing
}
return
true
;
});
}
bool
BoxingActor
::
IsReadReady
()
{
return
readable_regst_
.
size
()
==
readable_regst_cnt_
;
}
bool
BoxingActor
::
IsReadReady
()
{
return
readable_regst_mgr_
.
IsReadReady
();
}
bool
BoxingActor
::
IsReadAlwaysUnReadyFromNow
()
{
return
is_eord_
&&
readable_regst_
cnt_
==
0
;
return
is_eord_
&&
readable_regst_
mgr_
.
IsEmpty
()
;
}
void
BoxingActor
::
AsyncReturnAllReadableRegst
()
{
CHECK
_EQ
(
readable_regst_cnt_
,
0
);
CHECK
(
readable_regst_mgr_
.
IsEmpty
()
);
}
void
BoxingActor
::
ForEachCurReadableRegst
(
std
::
function
<
void
(
const
Regst
*
)
>
handler
)
{
for
(
const
auto
&
pair
:
readable_regst_
)
{
handler
(
pair
.
second
.
front
());
}
std
::
function
<
void
(
const
Regst
*
)
>
func
)
{
readable_regst_mgr_
.
ForEachCurReadableRegst
(
func
);
}
REGISTER_ACTOR
(
TaskType
::
kBoxing
,
BoxingActor
);
...
...
oneflow/core/actor/boxing_actor.h
浏览文件 @
cc2ee9fc
...
...
@@ -2,6 +2,7 @@
#define ONEFLOW_CORE_ACTOR_BOXING_ACTOR_H_
#include "oneflow/core/actor/actor.h"
#include "oneflow/core/actor/naive_readable_register_manager.h"
namespace
oneflow
{
...
...
@@ -25,11 +26,9 @@ class BoxingActor final : public Actor {
void
TrySetColIdOrder
(
const
Regst
*
);
// <regst_desc_id, regst*>
HashMap
<
int64_t
,
std
::
queue
<
Regst
*>>
readable_regst_
;
NaiveReadableRegstMgr
readable_regst_mgr_
;
// <regst_desc_id, <pid, cid>>
HashMap
<
int64_t
,
std
::
pair
<
int64_t
,
int32_t
>>*
previous_pid_cid_
;
int8_t
readable_regst_cnt_
;
ColIdOrder
col_id_order_
;
bool
is_eord_
;
};
...
...
oneflow/core/actor/model_update_compute_actor.cpp
浏览文件 @
cc2ee9fc
...
...
@@ -14,11 +14,7 @@ void MdUpdtCompActor::VirtualCompActorInit(const TaskProto& task_proto) {
related_save_model_actor_id_
=
task_proto
.
related_save_model_task_id
();
related_init_model_actor_id_
=
task_proto
.
related_init_model_task_id
();
pre_model_regst_
=
nullptr
;
for
(
const
auto
&
kv
:
task_proto
.
consumed_regst_desc_id
())
{
CHECK
(
model_diff_acc_regsts_
.
emplace
(
kv
.
second
,
std
::
queue
<
Regst
*>
()).
second
);
}
readable_model_diff_acc_cnt_
=
0
;
readable_regst_mgr_
.
Init
(
task_proto
);
OF_SET_MSG_HANDLER
(
&
MdUpdtCompActor
::
HandlerInitModelAndModelTmp
);
}
...
...
@@ -71,9 +67,7 @@ int MdUpdtCompActor::HandlerNormal(const ActorMsg& actor_msg) {
}
else
if
(
actor_msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
Regst
*
regst
=
actor_msg
.
regst
();
if
(
TryUpdtStateAsProducedRegst
(
regst
)
!=
0
)
{
auto
it
=
model_diff_acc_regsts_
.
find
(
regst
->
regst_desc_id
());
if
(
it
->
second
.
empty
())
{
readable_model_diff_acc_cnt_
+=
1
;
}
it
->
second
.
push
(
regst
);
readable_regst_mgr_
.
Push
(
regst
);
}
ActUntilFail
();
}
else
{
...
...
@@ -93,16 +87,12 @@ void MdUpdtCompActor::Act() {
AsyncLaunchKernel
(
kernel_ctx
,
[
&
](
int64_t
regst_desc_id
)
->
Regst
*
{
Regst
*
regst
=
GetCurWriteableRegst
(
regst_desc_id
);
if
(
regst
==
nullptr
)
{
return
model_diff_acc_regsts_
.
at
(
regst_desc_id
).
front
(
);
return
readable_regst_mgr_
.
GetCurReadable
(
regst_desc_id
);
}
else
{
return
regst
;
}
});
for
(
auto
&
kv
:
model_diff_acc_regsts_
)
{
AsyncSendRegstMsgToProducer
(
kv
.
second
.
front
());
kv
.
second
.
pop
();
if
(
kv
.
second
.
empty
())
{
readable_model_diff_acc_cnt_
-=
1
;
}
}
readable_regst_mgr_
.
ReturnToProducerAndPopCurReadable
(
this
);
const
JobDesc
*
job_desc
=
JobDesc
::
Singleton
();
auto
RegstPreProcess
=
[
&
](
Regst
*
regst
)
{
return
regst
==
cur_model_regst
;
};
if
(
next_model_version_id_
==
job_desc
->
TotalBatchNum
())
{
...
...
@@ -122,11 +112,11 @@ void MdUpdtCompActor::Act() {
}
bool
MdUpdtCompActor
::
IsReadReady
()
{
return
readable_
model_diff_acc_cnt_
==
model_diff_acc_regsts_
.
size
();
return
readable_
regst_mgr_
.
IsReadReady
();
}
bool
MdUpdtCompActor
::
IsReadAlwaysUnReadyFromNow
()
{
return
is_model_diff_acc_eord_
&&
readable_
model_diff_acc_cnt_
==
0
;
return
is_model_diff_acc_eord_
&&
readable_
regst_mgr_
.
IsEmpty
()
;
}
bool
MdUpdtCompActor
::
IsWriteReady
()
{
...
...
@@ -134,14 +124,12 @@ bool MdUpdtCompActor::IsWriteReady() {
}
void
MdUpdtCompActor
::
AsyncReturnAllReadableRegst
()
{
CHECK
_EQ
(
0
,
readable_model_diff_acc_cnt_
);
CHECK
(
readable_regst_mgr_
.
IsEmpty
()
);
}
void
MdUpdtCompActor
::
ForEachCurReadableRegst
(
std
::
function
<
void
(
const
Regst
*
)
>
handler
)
{
for
(
const
auto
&
pair
:
model_diff_acc_regsts_
)
{
handler
(
pair
.
second
.
front
());
}
std
::
function
<
void
(
const
Regst
*
)
>
func
)
{
readable_regst_mgr_
.
ForEachCurReadableRegst
(
func
);
}
REGISTER_ACTOR
(
TaskType
::
kMdUpdt
,
MdUpdtCompActor
);
...
...
oneflow/core/actor/model_update_compute_actor.h
浏览文件 @
cc2ee9fc
...
...
@@ -2,6 +2,7 @@
#define ONEFLOW_CORE_ACTOR_MODEL_UPDATE_COMPUTE_ACTOR_H_
#include "oneflow/core/actor/compute_actor.h"
#include "oneflow/core/actor/naive_readable_register_manager.h"
namespace
oneflow
{
...
...
@@ -32,8 +33,7 @@ class MdUpdtCompActor final : public CompActor {
int64_t
model_tmp_regst_desc_id_
;
int8_t
init_remaining_cnt_
;
bool
is_model_diff_acc_eord_
;
int64_t
readable_model_diff_acc_cnt_
;
HashMap
<
int64_t
,
std
::
queue
<
Regst
*>>
model_diff_acc_regsts_
;
NaiveReadableRegstMgr
readable_regst_mgr_
;
int64_t
next_model_version_id_
;
int64_t
related_save_model_actor_id_
;
int64_t
related_init_model_actor_id_
;
...
...
oneflow/core/actor/naive_readable_register_manager.cpp
0 → 100644
浏览文件 @
cc2ee9fc
#include "oneflow/core/actor/naive_readable_register_manager.h"
namespace
oneflow
{
void
NaiveReadableRegstMgr
::
Init
(
const
TaskProto
&
task_proto
)
{
for
(
const
auto
&
pair
:
task_proto
.
consumed_regst_desc_id
())
{
readable_regst_
[
pair
.
second
]
=
{};
}
readable_regst_cnt_
=
0
;
}
void
NaiveReadableRegstMgr
::
Push
(
Regst
*
regst
)
{
std
::
queue
<
Regst
*>&
rq
=
readable_regst_
.
at
(
regst
->
regst_desc_id
());
if
(
rq
.
empty
())
{
readable_regst_cnt_
+=
1
;
}
rq
.
push
(
regst
);
}
void
NaiveReadableRegstMgr
::
ReturnToProducerAndPopCurReadable
(
Actor
*
actor
,
std
::
function
<
bool
(
Regst
*
)
>
IsAllowed
)
{
for
(
auto
&
pair
:
readable_regst_
)
{
CHECK_EQ
(
pair
.
second
.
empty
(),
false
);
if
(
IsAllowed
(
pair
.
second
.
front
())
==
false
)
{
continue
;
}
actor
->
AsyncSendRegstMsgToProducer
(
pair
.
second
.
front
());
pair
.
second
.
pop
();
if
(
pair
.
second
.
empty
())
{
readable_regst_cnt_
-=
1
;
}
}
}
void
NaiveReadableRegstMgr
::
ReturnToProducerAndPopCurReadable
(
Actor
*
actor
)
{
ReturnToProducerAndPopCurReadable
(
actor
,
[](
Regst
*
)
{
return
true
;
});
}
Regst
*
NaiveReadableRegstMgr
::
GetCurReadable
(
int64_t
regst_desc_id
)
{
auto
it
=
readable_regst_
.
find
(
regst_desc_id
);
if
(
it
!=
readable_regst_
.
end
()
&&
it
->
second
.
empty
()
==
false
)
{
return
it
->
second
.
front
();
}
else
{
return
nullptr
;
}
}
void
NaiveReadableRegstMgr
::
ForEachCurReadableRegst
(
std
::
function
<
void
(
Regst
*
)
>
func
)
{
for
(
const
auto
&
pair
:
readable_regst_
)
{
if
(
pair
.
second
.
empty
()
==
false
)
{
func
(
pair
.
second
.
front
());
}
}
}
bool
NaiveReadableRegstMgr
::
IsReadReady
()
{
return
readable_regst_
.
size
()
==
readable_regst_cnt_
;
}
}
// namespace oneflow
oneflow/core/actor/naive_readable_register_manager.h
0 → 100644
浏览文件 @
cc2ee9fc
#ifndef ONEFLOW_CORE_ACTOR_NAIVE_READABLE_REGISTER_MANAGER_H_
#define ONEFLOW_CORE_ACTOR_NAIVE_READABLE_REGISTER_MANAGER_H_
#include "oneflow/core/actor/actor.h"
namespace
oneflow
{
class
NaiveReadableRegstMgr
final
{
public:
OF_DISALLOW_COPY_AND_MOVE
(
NaiveReadableRegstMgr
);
NaiveReadableRegstMgr
()
:
readable_regst_cnt_
(
0
)
{}
~
NaiveReadableRegstMgr
()
=
default
;
void
Init
(
const
TaskProto
&
task_proto
);
void
Push
(
Regst
*
regst
);
void
ReturnToProducerAndPopCurReadable
(
Actor
*
actor
,
std
::
function
<
bool
(
Regst
*
)
>
IsAllowed
);
void
ReturnToProducerAndPopCurReadable
(
Actor
*
actor
);
Regst
*
GetCurReadable
(
int64_t
regst_desc_id
);
Regst
*
GetFirstCurReadable
()
{
return
readable_regst_
.
begin
()
->
second
.
front
();
}
void
ForEachCurReadableRegst
(
std
::
function
<
void
(
Regst
*
)
>
func
);
bool
IsReadReady
();
bool
IsEmpty
()
{
return
readable_regst_cnt_
==
0
;
}
private:
HashMap
<
int64_t
,
std
::
queue
<
Regst
*>>
readable_regst_
;
// regst_desc_id
size_t
readable_regst_cnt_
;
};
}
// namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_NAIVE_READABLE_REGISTER_MANAGER_H_
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录