Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
cf5ed195
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
cf5ed195
编写于
7月 05, 2017
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Define Common Handle: WaitUntilNoReadableRegst, Common
上级
0c550414
变更
17
隐藏空白更改
内联
并排
Showing
17 changed file
with
53 addition
and
51 deletion
+53
-51
oneflow/core/actor/actor.h
oneflow/core/actor/actor.h
+2
-0
oneflow/core/actor/boxing_actor.cpp
oneflow/core/actor/boxing_actor.cpp
+4
-4
oneflow/core/actor/boxing_actor.h
oneflow/core/actor/boxing_actor.h
+2
-2
oneflow/core/actor/bp_data_comp_actor.cpp
oneflow/core/actor/bp_data_comp_actor.cpp
+5
-5
oneflow/core/actor/bp_data_comp_actor.h
oneflow/core/actor/bp_data_comp_actor.h
+2
-2
oneflow/core/actor/copy_comm_net_actor.cpp
oneflow/core/actor/copy_comm_net_actor.cpp
+4
-6
oneflow/core/actor/copy_comm_net_actor.h
oneflow/core/actor/copy_comm_net_actor.h
+2
-2
oneflow/core/actor/copy_hd_actor.cpp
oneflow/core/actor/copy_hd_actor.cpp
+4
-4
oneflow/core/actor/copy_hd_actor.h
oneflow/core/actor/copy_hd_actor.h
+2
-2
oneflow/core/actor/fw_data_comp_actor.cpp
oneflow/core/actor/fw_data_comp_actor.cpp
+5
-5
oneflow/core/actor/fw_data_comp_actor.h
oneflow/core/actor/fw_data_comp_actor.h
+2
-2
oneflow/core/actor/model_diff_accumulate_actor.cpp
oneflow/core/actor/model_diff_accumulate_actor.cpp
+4
-4
oneflow/core/actor/model_diff_accumulate_actor.h
oneflow/core/actor/model_diff_accumulate_actor.h
+2
-2
oneflow/core/actor/model_save_comp_actor.cpp
oneflow/core/actor/model_save_comp_actor.cpp
+3
-3
oneflow/core/actor/model_save_comp_actor.h
oneflow/core/actor/model_save_comp_actor.h
+4
-1
oneflow/core/actor/model_update_comp_actor.cpp
oneflow/core/actor/model_update_comp_actor.cpp
+4
-5
oneflow/core/actor/model_update_comp_actor.h
oneflow/core/actor/model_update_comp_actor.h
+2
-2
未找到文件。
oneflow/core/actor/actor.h
浏览文件 @
cf5ed195
...
...
@@ -50,6 +50,8 @@ class Actor {
} while (0)
// Common Handles
virtual
int
HandleNormal
(
const
ActorMsg
&
msg
)
=
0
;
virtual
int
HandleWaitUntilNoReadableRegst
(
const
ActorMsg
&
msg
)
=
0
;
int
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
msg
);
// Status of Produced Registers
...
...
oneflow/core/actor/boxing_actor.cpp
浏览文件 @
cf5ed195
...
...
@@ -12,15 +12,15 @@ void BoxingActor::Init(const TaskProto& task_proto,
num_of_eord_
=
0
;
CHECK
(
thread_ctx
.
cpu_stream
);
mut_device_ctx
().
reset
(
new
CpuDeviceCtx
(
thread_ctx
.
cpu_stream
));
OF_SET_MSG_HANDLE
(
&
BoxingActor
::
Handle
Boxing
);
OF_SET_MSG_HANDLE
(
&
BoxingActor
::
Handle
Normal
);
}
int
BoxingActor
::
Handle
Boxing
(
const
ActorMsg
&
msg
)
{
int
BoxingActor
::
Handle
Normal
(
const
ActorMsg
&
msg
)
{
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
num_of_eord_
+=
1
;
if
(
num_of_eord_
==
num_of_subscribed_regsts_
)
{
OF_SET_MSG_HANDLE
(
&
BoxingActor
::
Handle
BoxingWhenNoReadableRegstMsg
);
OF_SET_MSG_HANDLE
(
&
BoxingActor
::
Handle
WaitUntilNoReadableRegst
);
}
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
...
...
@@ -36,7 +36,7 @@ int BoxingActor::HandleBoxing(const ActorMsg& msg) {
return
0
;
}
int
BoxingActor
::
Handle
BoxingWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
int
BoxingActor
::
Handle
WaitUntilNoReadableRegst
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryLaunchKernelAndSendMsg
();
...
...
oneflow/core/actor/boxing_actor.h
浏览文件 @
cf5ed195
...
...
@@ -14,8 +14,8 @@ class BoxingActor final : public Actor {
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
private:
int
Handle
Boxing
(
const
ActorMsg
&
)
;
int
Handle
BoxingWhenNoReadableRegstMsg
(
const
ActorMsg
&
)
;
int
Handle
Normal
(
const
ActorMsg
&
)
override
;
int
Handle
WaitUntilNoReadableRegst
(
const
ActorMsg
&
)
override
;
void
TryLaunchKernelAndSendMsg
();
...
...
oneflow/core/actor/bp_data_comp_actor.cpp
浏览文件 @
cf5ed195
...
...
@@ -21,7 +21,7 @@ void BpDataCompActor::Init(const TaskProto& task_proto,
cuda_handle_
.
cublas_handle
(),
cuda_handle_
.
cudnn_handle
()));
}
OF_SET_MSG_HANDLE
(
&
BpDataCompActor
::
Handle
BpComp
);
OF_SET_MSG_HANDLE
(
&
BpDataCompActor
::
Handle
Normal
);
}
bool
BpDataCompActor
::
IsReadReady
()
{
...
...
@@ -37,12 +37,12 @@ bool BpDataCompActor::IsReadReady() {
return
!
num_of_read_empty_
;
}
int
BpDataCompActor
::
Handle
BpComp
(
const
ActorMsg
&
msg
)
{
int
BpDataCompActor
::
Handle
Normal
(
const
ActorMsg
&
msg
)
{
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
num_of_eord_
+=
1
;
if
(
num_of_eord_
==
6
)
{
OF_SET_MSG_HANDLE
(
&
BpDataCompActor
::
Handle
BpCompWhenNoReadableRegstMsg
);
OF_SET_MSG_HANDLE
(
&
BpDataCompActor
::
Handle
WaitUntilNoReadableRegst
);
}
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
...
...
@@ -63,7 +63,7 @@ int BpDataCompActor::HandleBpComp(const ActorMsg& msg) {
return
0
;
}
int
BpDataCompActor
::
Handle
BpCompWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
int
BpDataCompActor
::
Handle
WaitUntilNoReadableRegst
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryLaunchKernelAndSendMsg
();
...
...
@@ -81,7 +81,7 @@ int BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
OF_SET_MSG_HANDLE
(
nullptr
);
return
1
;
}
else
{
OF_SET_MSG_HANDLE
(
nullptr
);
OF_SET_MSG_HANDLE
(
&
BpDataCompActor
::
HandleWaitUntilReadingCntEqualZero
);
return
0
;
}
}
...
...
oneflow/core/actor/bp_data_comp_actor.h
浏览文件 @
cf5ed195
...
...
@@ -14,8 +14,8 @@ class BpDataCompActor final : public Actor {
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
private:
int
Handle
BpComp
(
const
ActorMsg
&
)
;
int
Handle
BpCompWhenNoReadableRegstMsg
(
const
ActorMsg
&
)
;
int
Handle
Normal
(
const
ActorMsg
&
)
override
;
int
Handle
WaitUntilNoReadableRegst
(
const
ActorMsg
&
)
override
;
bool
IsReadReady
();
void
TryLaunchKernelAndSendMsg
();
...
...
oneflow/core/actor/copy_comm_net_actor.cpp
浏览文件 @
cf5ed195
...
...
@@ -9,14 +9,13 @@ void CopyCommNetActor::Init(const TaskProto& task_proto,
Actor
::
Init
(
task_proto
,
thread_ctx
);
CHECK
(
thread_ctx
.
cpu_stream
);
mut_device_ctx
().
reset
(
new
CpuDeviceCtx
(
thread_ctx
.
cpu_stream
));
OF_SET_MSG_HANDLE
(
&
CopyCommNetActor
::
Handle
CopyCommNet
);
OF_SET_MSG_HANDLE
(
&
CopyCommNetActor
::
Handle
Normal
);
}
int
CopyCommNetActor
::
Handle
CopyCommNet
(
const
ActorMsg
&
msg
)
{
int
CopyCommNetActor
::
Handle
Normal
(
const
ActorMsg
&
msg
)
{
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
OF_SET_MSG_HANDLE
(
&
CopyCommNetActor
::
HandleCopyCommNetWhenNoReadableRegstMsg
);
OF_SET_MSG_HANDLE
(
&
CopyCommNetActor
::
HandleWaitUntilNoReadableRegst
);
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
auto
regst_wp
=
msg
.
regst_warpper
();
if
(
TryUpdtStateAsProducedRegst
(
regst_wp
->
regst_raw_ptr
())
!=
0
)
{
...
...
@@ -28,8 +27,7 @@ int CopyCommNetActor::HandleCopyCommNet(const ActorMsg& msg) {
return
0
;
}
int
CopyCommNetActor
::
HandleCopyCommNetWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
int
CopyCommNetActor
::
HandleWaitUntilNoReadableRegst
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryLaunchKernelAndSendMsg
();
...
...
oneflow/core/actor/copy_comm_net_actor.h
浏览文件 @
cf5ed195
...
...
@@ -14,8 +14,8 @@ class CopyCommNetActor final : public Actor {
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
private:
int
Handle
CopyCommNet
(
const
ActorMsg
&
)
;
int
Handle
CopyCommNetWhenNoReadableRegstMsg
(
const
ActorMsg
&
)
;
int
Handle
Normal
(
const
ActorMsg
&
)
override
;
int
Handle
WaitUntilNoReadableRegst
(
const
ActorMsg
&
)
override
;
void
TryLaunchKernelAndSendMsg
();
HashMap
<
int64_t
,
std
::
shared_ptr
<
RegstWarpper
>>
piece_id2waiting_in_regst_
;
...
...
oneflow/core/actor/copy_hd_actor.cpp
浏览文件 @
cf5ed195
...
...
@@ -10,13 +10,13 @@ void CopyHdActor::Init(const TaskProto& task_proto,
CHECK
(
thread_ctx
.
copy_hd_cuda_stream
);
mut_device_ctx
().
reset
(
new
CudaDeviceCtx
(
thread_ctx
.
copy_hd_cuda_stream
,
nullptr
,
nullptr
));
OF_SET_MSG_HANDLE
(
&
CopyHdActor
::
Handle
CopyHd
);
OF_SET_MSG_HANDLE
(
&
CopyHdActor
::
Handle
Normal
);
}
int
CopyHdActor
::
Handle
CopyHd
(
const
ActorMsg
&
msg
)
{
int
CopyHdActor
::
Handle
Normal
(
const
ActorMsg
&
msg
)
{
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
OF_SET_MSG_HANDLE
(
&
CopyHdActor
::
Handle
CopyHdWhenNoReadableRegstMsg
);
OF_SET_MSG_HANDLE
(
&
CopyHdActor
::
Handle
WaitUntilNoReadableRegst
);
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
...
...
@@ -27,7 +27,7 @@ int CopyHdActor::HandleCopyHd(const ActorMsg& msg) {
return
0
;
}
int
CopyHdActor
::
Handle
CopyHdWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
int
CopyHdActor
::
Handle
WaitUntilNoReadableRegst
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryLaunchKernelAndSendMsg
();
...
...
oneflow/core/actor/copy_hd_actor.h
浏览文件 @
cf5ed195
...
...
@@ -14,8 +14,8 @@ class CopyHdActor final : public Actor {
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
private:
int
Handle
CopyHd
(
const
ActorMsg
&
)
;
int
Handle
CopyHdWhenNoReadableRegstMsg
(
const
ActorMsg
&
)
;
int
Handle
Normal
(
const
ActorMsg
&
)
override
;
int
Handle
WaitUntilNoReadableRegst
(
const
ActorMsg
&
)
override
;
void
TryLaunchKernelAndSendMsg
();
std
::
queue
<
std
::
shared_ptr
<
RegstWarpper
>>
waiting_in_regst_
;
...
...
oneflow/core/actor/fw_data_comp_actor.cpp
浏览文件 @
cf5ed195
...
...
@@ -27,7 +27,7 @@ void FwDataCompActor::Init(const TaskProto& task_proto,
}
else
{
num_of_not_eord_
=
1
+
(
model_regst_desc_id_
!=
-
1
)
+
(
model_tmp_regst_desc_id_
!=
-
1
);
OF_SET_MSG_HANDLE
(
&
FwDataCompActor
::
Handle
FwComp
);
OF_SET_MSG_HANDLE
(
&
FwDataCompActor
::
Handle
Normal
);
}
}
...
...
@@ -53,16 +53,16 @@ bool FwDataCompActor::IsReadReady() {
int
FwDataCompActor
::
WaitToStart
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kStart
);
TryLaunchKernelAndSendMsg
();
OF_SET_MSG_HANDLE
(
&
FwDataCompActor
::
Handle
FwCompWhenNoReadableRegstMsg
);
OF_SET_MSG_HANDLE
(
&
FwDataCompActor
::
Handle
WaitUntilNoReadableRegst
);
return
0
;
}
int
FwDataCompActor
::
Handle
FwComp
(
const
ActorMsg
&
msg
)
{
int
FwDataCompActor
::
Handle
Normal
(
const
ActorMsg
&
msg
)
{
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
num_of_not_eord_
-=
1
;
if
(
!
num_of_not_eord_
)
{
OF_SET_MSG_HANDLE
(
&
FwDataCompActor
::
Handle
FwCompWhenNoReadableRegstMsg
);
OF_SET_MSG_HANDLE
(
&
FwDataCompActor
::
Handle
WaitUntilNoReadableRegst
);
}
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
...
...
@@ -87,7 +87,7 @@ int FwDataCompActor::HandleFwComp(const ActorMsg& msg) {
return
0
;
}
int
FwDataCompActor
::
Handle
FwCompWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
int
FwDataCompActor
::
Handle
WaitUntilNoReadableRegst
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryLaunchKernelAndSendMsg
();
...
...
oneflow/core/actor/fw_data_comp_actor.h
浏览文件 @
cf5ed195
...
...
@@ -15,8 +15,8 @@ class FwDataCompActor final : public CompActor {
private:
int
WaitToStart
(
const
ActorMsg
&
);
int
Handle
FwComp
(
const
ActorMsg
&
)
;
int
Handle
FwCompWhenNoReadableRegstMsg
(
const
ActorMsg
&
)
;
int
Handle
Normal
(
const
ActorMsg
&
)
override
;
int
Handle
WaitUntilNoReadableRegst
(
const
ActorMsg
&
)
override
;
bool
IsReadReady
();
void
TryLaunchKernelAndSendMsg
();
...
...
oneflow/core/actor/model_diff_accumulate_actor.cpp
浏览文件 @
cf5ed195
...
...
@@ -14,15 +14,15 @@ void MdDiffAccActor::Init(const TaskProto& task_proto,
cuda_handle_
.
cublas_handle
(),
cuda_handle_
.
cudnn_handle
()));
}
OF_SET_MSG_HANDLE
(
&
MdDiffAccActor
::
Handle
MdDiffAcc
);
OF_SET_MSG_HANDLE
(
&
MdDiffAccActor
::
Handle
Normal
);
ForEachCurWriteableRegst
(
[
this
](
Regst
*
regst
)
{
model_diff_acc_cnt_
[
regst
]
=
0
;
});
}
int
MdDiffAccActor
::
Handle
MdDiffAcc
(
const
ActorMsg
&
msg
)
{
int
MdDiffAccActor
::
Handle
Normal
(
const
ActorMsg
&
msg
)
{
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
OF_SET_MSG_HANDLE
(
&
MdDiffAccActor
::
Handle
MdDiffAccWhenNoReadableRegstMsg
);
OF_SET_MSG_HANDLE
(
&
MdDiffAccActor
::
Handle
WaitUntilNoReadableRegst
);
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
...
...
@@ -33,7 +33,7 @@ int MdDiffAccActor::HandleMdDiffAcc(const ActorMsg& msg) {
return
0
;
}
int
MdDiffAccActor
::
Handle
MdDiffAccWhenNoReadableRegstMsg
(
const
ActorMsg
&
msg
)
{
int
MdDiffAccActor
::
Handle
WaitUntilNoReadableRegst
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
TryLaunchKernelAndSendMsg
();
...
...
oneflow/core/actor/model_diff_accumulate_actor.h
浏览文件 @
cf5ed195
...
...
@@ -14,8 +14,8 @@ class MdDiffAccActor final : public CompActor {
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
private:
int
Handle
MdDiffAcc
(
const
ActorMsg
&
)
;
int
Handle
MdDiffAccWhenNoReadableRegstMsg
(
const
ActorMsg
&
)
;
int
Handle
Normal
(
const
ActorMsg
&
)
override
;
int
Handle
WaitUntilNoReadableRegst
(
const
ActorMsg
&
)
override
;
void
TryLaunchKernelAndSendMsg
();
...
...
oneflow/core/actor/model_save_comp_actor.cpp
浏览文件 @
cf5ed195
...
...
@@ -9,12 +9,12 @@ void MdSaveCompActor::Init(const TaskProto& task_proto,
model_regst_desc_id_
=
RegstDescId4Name
(
"model"
);
CHECK
(
thread_ctx
.
cpu_stream
);
mut_device_ctx
().
reset
(
new
CpuDeviceCtx
(
thread_ctx
.
cpu_stream
));
OF_SET_MSG_HANDLE
(
&
MdSaveCompActor
::
Handle
SaveMode
l
);
OF_SET_MSG_HANDLE
(
&
MdSaveCompActor
::
Handle
Norma
l
);
}
int
MdSaveCompActor
::
Handle
SaveMode
l
(
const
ActorMsg
&
actor_msg
)
{
int
MdSaveCompActor
::
Handle
Norma
l
(
const
ActorMsg
&
actor_msg
)
{
if
(
actor_msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
CHECK
(
actor_msg
.
actor_cmd
()
==
ActorCmd
::
kEORD
);
CHECK
_EQ
(
actor_msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
return
1
;
}
else
if
(
actor_msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
std
::
shared_ptr
<
RegstWarpper
>
regst_warpper
=
actor_msg
.
regst_warpper
();
...
...
oneflow/core/actor/model_save_comp_actor.h
浏览文件 @
cf5ed195
...
...
@@ -14,7 +14,10 @@ class MdSaveCompActor final : public CompActor {
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
private:
int
HandleSaveModel
(
const
ActorMsg
&
);
int
HandleNormal
(
const
ActorMsg
&
)
override
;
int
HandleWaitUntilNoReadableRegst
(
const
ActorMsg
&
msg
)
override
{
UNEXPECTED_RUN
();
}
int64_t
model_regst_desc_id_
;
};
...
...
oneflow/core/actor/model_update_comp_actor.cpp
浏览文件 @
cf5ed195
...
...
@@ -56,7 +56,7 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) {
SetReadOnlyForRegstDescId
(
model_tmp_regst_desc_id_
);
AsyncSendEORDMsgToSubscribers
(
model_tmp_regst_desc_id_
);
if
(
JobDesc
::
Singleton
()
->
is_train
())
{
OF_SET_MSG_HANDLE
(
&
MdUpdtCompActor
::
Handle
UpdateMode
l
);
OF_SET_MSG_HANDLE
(
&
MdUpdtCompActor
::
Handle
Norma
l
);
}
else
{
AsyncSendEORDMsgToSubscribers
(
model_regst_desc_id_
);
OF_SET_MSG_HANDLE
(
&
MdUpdtCompActor
::
HandleWaitUntilReadingCntEqualZero
);
...
...
@@ -64,10 +64,10 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) {
return
0
;
}
int
MdUpdtCompActor
::
Handle
UpdateMode
l
(
const
ActorMsg
&
actor_msg
)
{
int
MdUpdtCompActor
::
Handle
Norma
l
(
const
ActorMsg
&
actor_msg
)
{
if
(
actor_msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
CHECK_EQ
(
actor_msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
OF_SET_MSG_HANDLE
(
&
MdUpdtCompActor
::
Handle
UpdtModelWhenNoReadableRegstMsg
);
OF_SET_MSG_HANDLE
(
&
MdUpdtCompActor
::
Handle
WaitUntilNoReadableRegst
);
}
else
if
(
actor_msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
auto
regst_warpper
=
actor_msg
.
regst_warpper
();
if
(
TryUpdtStateAsProducedRegst
(
regst_warpper
->
regst_raw_ptr
())
!=
0
)
{
...
...
@@ -80,8 +80,7 @@ int MdUpdtCompActor::HandleUpdateModel(const ActorMsg& actor_msg) {
return
0
;
}
int
MdUpdtCompActor
::
HandleUpdtModelWhenNoReadableRegstMsg
(
const
ActorMsg
&
actor_msg
)
{
int
MdUpdtCompActor
::
HandleWaitUntilNoReadableRegst
(
const
ActorMsg
&
actor_msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
actor_msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
...
...
oneflow/core/actor/model_update_comp_actor.h
浏览文件 @
cf5ed195
...
...
@@ -17,8 +17,8 @@ class MdUpdtCompActor final : public CompActor {
int
HandleBeforeInitDeviceCtx
(
const
ActorMsg
&
);
int
HandleBeforeInitializeModel
(
const
ActorMsg
&
);
int
HandleBeforeSendInitialModel
(
const
ActorMsg
&
);
int
Handle
UpdateModel
(
const
ActorMsg
&
)
;
int
Handle
UpdtModelWhenNoReadableRegstMsg
(
const
ActorMsg
&
)
;
int
Handle
Normal
(
const
ActorMsg
&
)
override
;
int
Handle
WaitUntilNoReadableRegst
(
const
ActorMsg
&
)
override
;
void
TryLaunchKernelAndSendMsg
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录