Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
ad0783cb
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 接近 3 年
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ad0783cb
编写于
6月 23, 2017
作者:
W
willzhang4a58
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
make msg_handle in base actor
上级
8fcade0c
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
44 addition
and
94 deletion
+44
-94
oneflow/core/actor/actor.cpp
oneflow/core/actor/actor.cpp
+9
-0
oneflow/core/actor/actor.h
oneflow/core/actor/actor.h
+14
-1
oneflow/core/actor/boxing_actor.cpp
oneflow/core/actor/boxing_actor.cpp
+4
-17
oneflow/core/actor/boxing_actor.h
oneflow/core/actor/boxing_actor.h
+0
-3
oneflow/core/actor/bp_data_comp_actor.cpp
oneflow/core/actor/bp_data_comp_actor.cpp
+4
-17
oneflow/core/actor/bp_data_comp_actor.h
oneflow/core/actor/bp_data_comp_actor.h
+0
-3
oneflow/core/actor/fw_data_comp_actor.cpp
oneflow/core/actor/fw_data_comp_actor.cpp
+4
-17
oneflow/core/actor/fw_data_comp_actor.h
oneflow/core/actor/fw_data_comp_actor.h
+0
-3
oneflow/core/actor/model_save_comp_actor.cpp
oneflow/core/actor/model_save_comp_actor.cpp
+1
-5
oneflow/core/actor/model_save_comp_actor.h
oneflow/core/actor/model_save_comp_actor.h
+1
-3
oneflow/core/actor/model_update_comp_actor.cpp
oneflow/core/actor/model_update_comp_actor.cpp
+7
-22
oneflow/core/actor/model_update_comp_actor.h
oneflow/core/actor/model_update_comp_actor.h
+0
-3
未找到文件。
oneflow/core/actor/actor.cpp
浏览文件 @
ad0783cb
...
...
@@ -60,6 +60,15 @@ KernelCtx Actor::GenDefaultKernelCtx() const {
return
ctx
;
}
int
Actor
::
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
if
(
total_reading_cnt_
==
0
)
{
msg_handle_
=
nullptr
;
return
1
;
}
return
0
;
}
void
Actor
::
AsyncWardKernel
(
const
KernelCtx
&
kernel_ctx
,
std
::
function
<
std
::
shared_ptr
<
RegstWarpper
>
(
int64_t
)
>
Regst4RegstDescId
)
{
...
...
oneflow/core/actor/actor.h
浏览文件 @
ad0783cb
...
...
@@ -24,7 +24,9 @@ class Actor {
virtual
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
=
0
;
// 1: success, and actor finish
// 0: success, and actor not finish
virtual
int
ProcessMsg
(
const
ActorMsg
&
)
=
0
;
int
ProcessMsg
(
const
ActorMsg
&
msg
)
{
return
(
this
->*
msg_handle_
)(
msg
);
}
int64_t
actor_id
()
const
{
return
actor_id_
;
}
...
...
@@ -40,6 +42,15 @@ class Actor {
std
::
unique_ptr
<
DeviceCtx
>&
mut_device_ctx
()
{
return
device_ctx_
;
}
KernelCtx
GenDefaultKernelCtx
()
const
;
// Msg Handle
using
MsgHandle
=
int
(
Actor
::*
)(
const
ActorMsg
&
);
void
set_msg_handle
(
MsgHandle
val
)
{
msg_handle_
=
val
;
}
#define OF_SET_MSG_HANDLE(val) \
do { \
set_msg_handle(static_cast<MsgHandle>(val)); \
} while(0)
int
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
msg
);
// Status of Produced Registers
int64_t
expected_piece_id
()
const
{
return
expected_piece_id_
;
}
void
AsyncWardKernel
(
...
...
@@ -70,6 +81,8 @@ class Actor {
HashMap
<
std
::
string
,
int64_t
>
name2regst_desc_id_
;
std
::
unique_ptr
<
DeviceCtx
>
device_ctx_
;
MsgHandle
msg_handle_
;
// Status of Produced Registers
int64_t
expected_piece_id_
;
...
...
oneflow/core/actor/boxing_actor.cpp
浏览文件 @
ad0783cb
...
...
@@ -11,11 +11,7 @@ void BoxingActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx)
num_of_eord_
=
0
;
CHECK
(
thread_ctx
.
cpu_stream
);
mut_device_ctx
().
reset
(
new
CpuDeviceCtx
(
thread_ctx
.
cpu_stream
));
cur_msg_handle_
=
&
BoxingActor
::
HandleBoxing
;
}
int
BoxingActor
::
ProcessMsg
(
const
ActorMsg
&
msg
)
{
return
(
this
->*
cur_msg_handle_
)(
msg
);
OF_SET_MSG_HANDLE
(
&
BoxingActor
::
HandleBoxing
);
}
int
BoxingActor
::
HandleBoxing
(
const
ActorMsg
&
msg
)
{
...
...
@@ -23,7 +19,7 @@ int BoxingActor::HandleBoxing(const ActorMsg& msg) {
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
num_of_eord_
+=
1
;
if
(
num_of_eord_
==
num_of_subscribed_regsts_
)
{
cur_msg_handle_
=
&
BoxingActor
::
HandleBoxingWhenNoReadableRegstMsg
;
OF_SET_MSG_HANDLE
(
&
BoxingActor
::
HandleBoxingWhenNoReadableRegstMsg
)
;
}
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
...
...
@@ -44,25 +40,16 @@ int BoxingActor::HandleBoxingWhenNoReadableRegstMsg(const ActorMsg& msg) {
if
(
num_of_read_empty_
==
num_of_subscribed_regsts_
)
{
AsyncSendEORDMsgForAllProducedRegstDesc
();
if
(
total_reading_cnt
()
==
0
)
{
cur_msg_handle_
=
nullptr
;
OF_SET_MSG_HANDLE
(
nullptr
)
;
return
1
;
}
else
{
cur_msg_handle_
=
&
BoxingActor
::
HandleWaitUntilReadingCntEqualZero
;
OF_SET_MSG_HANDLE
(
&
BoxingActor
::
HandleWaitUntilReadingCntEqualZero
)
;
return
0
;
}
}
return
0
;
}
int
BoxingActor
::
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
if
(
total_reading_cnt
()
==
0
)
{
cur_msg_handle_
=
nullptr
;
return
1
;
}
return
0
;
}
void
BoxingActor
::
TryWardKernelAndSendMsg
()
{
if
(
!
num_of_read_empty_
&&
IsWriteReady
())
{
int64_t
piece_id
=
expected_piece_id
();
...
...
oneflow/core/actor/boxing_actor.h
浏览文件 @
ad0783cb
...
...
@@ -12,17 +12,14 @@ class BoxingActor final : public Actor {
~
BoxingActor
()
=
default
;
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
int
ProcessMsg
(
const
ActorMsg
&
)
override
;
private:
int
HandleInitDeviceCtx
(
const
ActorMsg
&
);
int
HandleBoxing
(
const
ActorMsg
&
);
int
HandleBoxingWhenNoReadableRegstMsg
(
const
ActorMsg
&
);
int
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
);
void
TryWardKernelAndSendMsg
();
int
(
BoxingActor
::*
cur_msg_handle_
)(
const
ActorMsg
&
);
int
num_of_subscribed_regsts_
;
int
num_of_read_empty_
;
int
num_of_eord_
;
...
...
oneflow/core/actor/bp_data_comp_actor.cpp
浏览文件 @
ad0783cb
...
...
@@ -20,7 +20,7 @@ void BpDataCompActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_
cuda_handle_
.
cublas_handle
(),
cuda_handle_
.
cudnn_handle
()));
}
cur_msg_handle_
=
&
BpDataCompActor
::
HandleBpComp
;
OF_SET_MSG_HANDLE
(
&
BpDataCompActor
::
HandleBpComp
)
;
}
bool
BpDataCompActor
::
IsReadReady
()
{
...
...
@@ -36,16 +36,12 @@ bool BpDataCompActor::IsReadReady() {
return
!
num_of_read_empty_
;
}
int
BpDataCompActor
::
ProcessMsg
(
const
ActorMsg
&
msg
)
{
return
(
this
->*
cur_msg_handle_
)(
msg
);
}
int
BpDataCompActor
::
HandleBpComp
(
const
ActorMsg
&
msg
)
{
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
num_of_eord_
+=
1
;
if
(
num_of_eord_
==
6
)
{
cur_msg_handle_
=
&
BpDataCompActor
::
HandleBpCompWhenNoReadableRegstMsg
;
OF_SET_MSG_HANDLE
(
&
BpDataCompActor
::
HandleBpCompWhenNoReadableRegstMsg
)
;
}
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
...
...
@@ -78,25 +74,16 @@ int BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
AsyncSendEORDMsgForAllProducedRegstDesc
();
num_of_read_empty_
=
6
;
if
(
total_reading_cnt
()
==
0
)
{
cur_msg_handle_
=
nullptr
;
OF_SET_MSG_HANDLE
(
nullptr
)
;
return
1
;
}
else
{
cur_msg_handle_
=
&
BpDataCompActor
::
HandleWaitUntilReadingCntEqualZero
;
OF_SET_MSG_HANDLE
(
nullptr
)
;
return
0
;
}
}
return
0
;
}
int
BpDataCompActor
::
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
if
(
total_reading_cnt
()
==
0
)
{
cur_msg_handle_
=
nullptr
;
return
1
;
}
return
0
;
}
void
BpDataCompActor
::
TryWardKernelAndSendMsg
()
{
while
(
IsReadReady
()
&&
IsWriteReady
())
{
int64_t
cur_model
=
read_regst_
.
at
(
model_regst_desc_id_
).
front
()
->
model_version_id
();
...
...
oneflow/core/actor/bp_data_comp_actor.h
浏览文件 @
ad0783cb
...
...
@@ -12,19 +12,16 @@ public:
~
BpDataCompActor
()
=
default
;
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
int
ProcessMsg
(
const
ActorMsg
&
)
override
;
private:
int
HandleInitDeviceCtx
(
const
ActorMsg
&
);
int
HandleBpComp
(
const
ActorMsg
&
);
int
HandleBpCompWhenNoReadableRegstMsg
(
const
ActorMsg
&
);
int
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
);
bool
IsReadReady
();
void
TryWardKernelAndSendMsg
();
CudaStreamHandle
cuda_handle_
;
int
(
BpDataCompActor
::*
cur_msg_handle_
)(
const
ActorMsg
&
);
int
num_of_read_empty_
;
int
num_of_eord_
;
int64_t
expected_model_version_id_
;
...
...
oneflow/core/actor/fw_data_comp_actor.cpp
浏览文件 @
ad0783cb
...
...
@@ -18,7 +18,7 @@ void FwDataCompActor::Init(const TaskProto& task_proto,
cuda_handle_
.
cublas_handle
(),
cuda_handle_
.
cudnn_handle
()));
}
cur_msg_handle_
=
&
FwDataCompActor
::
HandleFwComp
;
OF_SET_MSG_HANDLE
(
&
FwDataCompActor
::
HandleFwComp
)
;
}
bool
FwDataCompActor
::
IsReadReady
()
{
...
...
@@ -33,16 +33,12 @@ bool FwDataCompActor::IsReadReady() {
return
false
;
}
int
FwDataCompActor
::
ProcessMsg
(
const
ActorMsg
&
msg
)
{
return
(
this
->*
cur_msg_handle_
)(
msg
);
}
int
FwDataCompActor
::
HandleFwComp
(
const
ActorMsg
&
msg
)
{
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
CHECK_EQ
(
msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
num_of_eord_
+=
1
;
if
(
num_of_eord_
==
3
)
{
cur_msg_handle_
=
&
FwDataCompActor
::
HandleFwCompWhenNoReadableRegstMsg
;
OF_SET_MSG_HANDLE
(
&
FwDataCompActor
::
HandleFwCompWhenNoReadableRegstMsg
)
;
}
}
else
if
(
msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
if
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
())
!=
0
)
{
...
...
@@ -78,25 +74,16 @@ int FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
model_tmp_regst_
=
nullptr
;
AsyncSendEORDMsgForAllProducedRegstDesc
();
if
(
total_reading_cnt
()
==
0
)
{
cur_msg_handle_
=
nullptr
;
OF_SET_MSG_HANDLE
(
nullptr
)
;
return
1
;
}
else
{
cur_msg_handle_
=
&
FwDataCompActor
::
HandleWaitUntilReadingCntEqualZero
;
OF_SET_MSG_HANDLE
(
&
FwDataCompActor
::
HandleWaitUntilReadingCntEqualZero
)
;
return
0
;
}
}
return
0
;
}
int
FwDataCompActor
::
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
if
(
total_reading_cnt
()
==
0
)
{
cur_msg_handle_
=
nullptr
;
return
1
;
}
return
0
;
}
void
FwDataCompActor
::
TryWardKernelAndSendMsg
()
{
while
(
IsReadReady
()
&&
IsWriteReady
())
{
CHECK_EQ
(
in_
.
front
()
->
piece_id
(),
expected_piece_id
());
...
...
oneflow/core/actor/fw_data_comp_actor.h
浏览文件 @
ad0783cb
...
...
@@ -12,18 +12,15 @@ public:
~
FwDataCompActor
()
=
default
;
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
int
ProcessMsg
(
const
ActorMsg
&
)
override
;
private:
int
HandleFwComp
(
const
ActorMsg
&
);
int
HandleFwCompWhenNoReadableRegstMsg
(
const
ActorMsg
&
);
int
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
);
bool
IsReadReady
();
void
TryWardKernelAndSendMsg
();
CudaStreamHandle
cuda_handle_
;
int
(
FwDataCompActor
::*
cur_msg_handle_
)(
const
ActorMsg
&
);
int
num_of_eord_
;
int64_t
expected_model_version_id_
;
int64_t
model_regst_desc_id_
;
...
...
oneflow/core/actor/model_save_comp_actor.cpp
浏览文件 @
ad0783cb
...
...
@@ -8,11 +8,7 @@ void MdSaveCompActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_
model_regst_desc_id_
=
RegstDescId4Name
(
"model"
);
CHECK
(
thread_ctx
.
cpu_stream
);
mut_device_ctx
().
reset
(
new
CpuDeviceCtx
(
thread_ctx
.
cpu_stream
));
cur_msg_handle_
=
&
MdSaveCompActor
::
HandleSaveModel
;
}
int
MdSaveCompActor
::
ProcessMsg
(
const
ActorMsg
&
actor_msg
)
{
return
(
this
->*
cur_msg_handle_
)(
actor_msg
);
OF_SET_MSG_HANDLE
(
&
MdSaveCompActor
::
HandleSaveModel
);
}
int
MdSaveCompActor
::
HandleSaveModel
(
const
ActorMsg
&
actor_msg
)
{
...
...
oneflow/core/actor/model_save_comp_actor.h
浏览文件 @
ad0783cb
...
...
@@ -12,11 +12,9 @@ class MdSaveCompActor final : public CompActor {
~
MdSaveCompActor
()
=
default
;
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
int
ProcessMsg
(
const
ActorMsg
&
)
override
;
private:
int
HandleSaveModel
(
const
ActorMsg
&
);
int
(
MdSaveCompActor
::*
cur_msg_handle_
)(
const
ActorMsg
&
);
int
HandleSaveModel
(
const
ActorMsg
&
);
int64_t
model_regst_desc_id_
;
};
...
...
oneflow/core/actor/model_update_comp_actor.cpp
浏览文件 @
ad0783cb
...
...
@@ -15,11 +15,7 @@ void MdUpdtCompActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_
cuda_handle_
.
cublas_handle
(),
cuda_handle_
.
cudnn_handle
()));
}
cur_msg_handle_
=
&
MdUpdtCompActor
::
HandleBeforeInitializeModel
;
}
int
MdUpdtCompActor
::
ProcessMsg
(
const
ActorMsg
&
actor_msg
)
{
return
(
this
->*
cur_msg_handle_
)(
actor_msg
);
OF_SET_MSG_HANDLE
(
&
MdUpdtCompActor
::
HandleBeforeInitializeModel
);
}
int
MdUpdtCompActor
::
HandleBeforeInitializeModel
(
const
ActorMsg
&
actor_msg
)
{
...
...
@@ -50,7 +46,7 @@ int MdUpdtCompActor::HandleBeforeInitializeModel(const ActorMsg& actor_msg) {
return
ret
;
});
}
cur_msg_handle_
=
&
MdUpdtCompActor
::
HandleBeforeSendInitialModel
;
OF_SET_MSG_HANDLE
(
&
MdUpdtCompActor
::
HandleBeforeSendInitialModel
)
;
return
0
;
}
...
...
@@ -60,10 +56,10 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) {
SetReadOnlyForRegstDescId
(
model_tmp_regst_desc_id_
);
AsyncSendEORDMsgToSubscribers
(
model_tmp_regst_desc_id_
);
if
(
JobDesc
::
Singleton
().
is_train
())
{
cur_msg_handle_
=
&
MdUpdtCompActor
::
HandleUpdateModel
;
OF_SET_MSG_HANDLE
(
&
MdUpdtCompActor
::
HandleUpdateModel
)
;
}
else
{
AsyncSendEORDMsgToSubscribers
(
model_regst_desc_id_
);
cur_msg_handle_
=
&
MdUpdtCompActor
::
HandleWaitUntilReadingCntEqualZero
;
OF_SET_MSG_HANDLE
(
&
MdUpdtCompActor
::
HandleWaitUntilReadingCntEqualZero
)
;
}
return
0
;
}
...
...
@@ -71,7 +67,7 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) {
int
MdUpdtCompActor
::
HandleUpdateModel
(
const
ActorMsg
&
actor_msg
)
{
if
(
actor_msg
.
msg_type
()
==
ActorMsgType
::
kCmdMsg
)
{
CHECK_EQ
(
actor_msg
.
actor_cmd
(),
ActorCmd
::
kEORD
);
cur_msg_handle_
=
&
MdUpdtCompActor
::
HandleUpdtModelWhenNoReadableRegstMsg
;
OF_SET_MSG_HANDLE
(
&
MdUpdtCompActor
::
HandleUpdtModelWhenNoReadableRegstMsg
)
;
}
else
if
(
actor_msg
.
msg_type
()
==
ActorMsgType
::
kRegstMsg
)
{
auto
regst_warpper
=
actor_msg
.
regst_warpper
();
if
(
TryUpdtStateAsProducedRegst
(
regst_warpper
->
regst_raw_ptr
())
!=
0
)
{
...
...
@@ -92,27 +88,16 @@ int MdUpdtCompActor::HandleUpdtModelWhenNoReadableRegstMsg(
if
(
waiting_model_diff_acc_queue_
.
empty
())
{
AsyncSendEORDMsgToSubscribers
(
model_regst_desc_id_
);
if
(
total_reading_cnt
()
==
0
)
{
cur_msg_handle_
=
nullptr
;
OF_SET_MSG_HANDLE
(
nullptr
)
;
return
1
;
}
else
{
cur_msg_handle_
=
&
MdUpdtCompActor
::
HandleWaitUntilReadingCntEqualZero
;
OF_SET_MSG_HANDLE
(
&
MdUpdtCompActor
::
HandleWaitUntilReadingCntEqualZero
)
;
return
0
;
}
}
return
0
;
}
int
MdUpdtCompActor
::
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
actor_msg
)
{
CHECK_EQ
(
TryUpdtStateAsProducedRegst
(
actor_msg
.
regst_warpper
()
->
regst_raw_ptr
()),
0
);
if
(
total_reading_cnt
()
==
0
)
{
cur_msg_handle_
=
nullptr
;
return
1
;
}
return
0
;
}
void
MdUpdtCompActor
::
TryWardKernelAndSendMsg
()
{
if
(
!
waiting_model_diff_acc_queue_
.
empty
()
&&
IsWriteReady
())
{
auto
model_diff_acc_wpr
=
waiting_model_diff_acc_queue_
.
front
();
...
...
oneflow/core/actor/model_update_comp_actor.h
浏览文件 @
ad0783cb
...
...
@@ -12,7 +12,6 @@ class MdUpdtCompActor final : public CompActor {
~
MdUpdtCompActor
()
=
default
;
void
Init
(
const
TaskProto
&
,
const
ThreadCtx
&
)
override
;
int
ProcessMsg
(
const
ActorMsg
&
)
override
;
private:
int
HandleBeforeInitDeviceCtx
(
const
ActorMsg
&
);
...
...
@@ -20,12 +19,10 @@ class MdUpdtCompActor final : public CompActor {
int
HandleBeforeSendInitialModel
(
const
ActorMsg
&
);
int
HandleUpdateModel
(
const
ActorMsg
&
);
int
HandleUpdtModelWhenNoReadableRegstMsg
(
const
ActorMsg
&
);
int
HandleWaitUntilReadingCntEqualZero
(
const
ActorMsg
&
);
void
TryWardKernelAndSendMsg
();
CudaStreamHandle
cuda_handle_
;
int
(
MdUpdtCompActor
::*
cur_msg_handle_
)(
const
ActorMsg
&
);
int64_t
model_regst_desc_id_
;
int64_t
model_tmp_regst_desc_id_
;
std
::
queue
<
std
::
shared_ptr
<
RegstWarpper
>>
waiting_model_diff_acc_queue_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录