Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Oneflow-Inc
oneflow
提交
073d368b
O
oneflow
项目概览
Oneflow-Inc
/
oneflow
上一次同步 2 年多
通知
13
Star
2733
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
O
oneflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
073d368b
编写于
9月 17, 2020
作者:
S
Shenghang Tsai
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'master' of
https://github.com/Oneflow-Inc/oneflow
into master
上级
c6e07900
c74a53dd
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
283 addition
and
56 deletion
+283
-56
oneflow/core/control/ctrl_client.cpp
oneflow/core/control/ctrl_client.cpp
+29
-0
oneflow/core/control/ctrl_client.h
oneflow/core/control/ctrl_client.h
+6
-0
oneflow/core/control/ctrl_server.cpp
oneflow/core/control/ctrl_server.cpp
+2
-1
oneflow/core/eager/eager_instruction.proto
oneflow/core/eager/eager_instruction.proto
+10
-0
oneflow/core/eager/eager_oneflow.cpp
oneflow/core/eager/eager_oneflow.cpp
+42
-24
oneflow/core/eager/eager_oneflow.h
oneflow/core/eager/eager_oneflow.h
+42
-0
oneflow/core/job/cluster.cpp
oneflow/core/job/cluster.cpp
+43
-8
oneflow/core/job/cluster_instruction.cpp
oneflow/core/job/cluster_instruction.cpp
+76
-18
oneflow/core/job/cluster_instruction.h
oneflow/core/job/cluster_instruction.h
+3
-1
oneflow/core/job/cluster_instruction.proto
oneflow/core/job/cluster_instruction.proto
+3
-0
oneflow/python/oneflow_internal_helper.h
oneflow/python/oneflow_internal_helper.h
+5
-3
oneflow/python/test/ops/test_ccrelu.py
oneflow/python/test/ops/test_ccrelu.py
+22
-1
未找到文件。
oneflow/core/control/ctrl_client.cpp
浏览文件 @
073d368b
...
...
@@ -100,6 +100,13 @@ void CtrlClient::PushKV(const std::string& k, std::function<void(std::string*)>
call
(
GetResponsibleStub
(
k
));
}
void
CtrlClient
::
PushMasterKV
(
const
std
::
string
&
k
,
std
::
function
<
void
(
std
::
string
*
)
>
VSetter
)
{
ClientCall
<
CtrlMethod
::
kPushKV
>
call
;
call
.
mut_request
()
->
set_key
(
k
);
VSetter
(
call
.
mut_request
()
->
mutable_val
());
call
(
GetMasterStub
());
}
void
CtrlClient
::
PushKV
(
const
std
::
string
&
k
,
const
std
::
string
&
v
)
{
PushKV
(
k
,
[
&
](
std
::
string
*
o
)
{
*
o
=
v
;
});
}
...
...
@@ -108,12 +115,22 @@ void CtrlClient::PushKV(const std::string& k, const PbMessage& msg) {
PushKV
(
k
,
[
&
](
std
::
string
*
o
)
{
msg
.
SerializeToString
(
o
);
});
}
void
CtrlClient
::
PushMasterKV
(
const
std
::
string
&
k
,
const
PbMessage
&
msg
)
{
PushMasterKV
(
k
,
[
&
](
std
::
string
*
o
)
{
msg
.
SerializeToString
(
o
);
});
}
void
CtrlClient
::
ClearKV
(
const
std
::
string
&
k
)
{
ClientCall
<
CtrlMethod
::
kClearKV
>
call
;
call
.
mut_request
()
->
set_key
(
k
);
call
(
GetResponsibleStub
(
k
));
}
void
CtrlClient
::
ClearMasterKV
(
const
std
::
string
&
k
)
{
ClientCall
<
CtrlMethod
::
kClearKV
>
call
;
call
.
mut_request
()
->
set_key
(
k
);
call
(
GetMasterStub
());
}
void
CtrlClient
::
PullKV
(
const
std
::
string
&
k
,
std
::
function
<
void
(
const
std
::
string
&
)
>
VGetter
)
{
ClientCall
<
CtrlMethod
::
kPullKV
>
call
;
call
.
mut_request
()
->
set_key
(
k
);
...
...
@@ -121,6 +138,14 @@ void CtrlClient::PullKV(const std::string& k, std::function<void(const std::stri
VGetter
(
call
.
response
().
val
());
}
void
CtrlClient
::
PullMasterKV
(
const
std
::
string
&
k
,
std
::
function
<
void
(
const
std
::
string
&
)
>
VGetter
)
{
ClientCall
<
CtrlMethod
::
kPullKV
>
call
;
call
.
mut_request
()
->
set_key
(
k
);
call
(
GetMasterStub
());
VGetter
(
call
.
response
().
val
());
}
void
CtrlClient
::
PullKV
(
const
std
::
string
&
k
,
std
::
string
*
v
)
{
PullKV
(
k
,
[
&
](
const
std
::
string
&
i
)
{
*
v
=
i
;
});
}
...
...
@@ -129,6 +154,10 @@ void CtrlClient::PullKV(const std::string& k, PbMessage* msg) {
PullKV
(
k
,
[
&
](
const
std
::
string
&
i
)
{
msg
->
ParseFromString
(
i
);
});
}
void
CtrlClient
::
PullMasterKV
(
const
std
::
string
&
k
,
PbMessage
*
msg
)
{
PullMasterKV
(
k
,
[
&
](
const
std
::
string
&
i
)
{
msg
->
ParseFromString
(
i
);
});
}
void
CtrlClient
::
PushActEvent
(
const
ActEvent
&
act_event
)
{
ClientCall
<
CtrlMethod
::
kPushActEvent
>
call
;
*
(
call
.
mut_request
()
->
mutable_act_event
())
=
act_event
;
...
...
oneflow/core/control/ctrl_client.h
浏览文件 @
073d368b
...
...
@@ -38,15 +38,19 @@ class CtrlClient final {
void
PushKV
(
const
std
::
string
&
k
,
std
::
function
<
void
(
std
::
string
*
)
>
VSetter
);
void
PushKV
(
const
std
::
string
&
k
,
const
std
::
string
&
v
);
void
PushKV
(
const
std
::
string
&
k
,
const
PbMessage
&
msg
);
void
PushMasterKV
(
const
std
::
string
&
k
,
const
PbMessage
&
msg
);
template
<
typename
T
>
typename
std
::
enable_if
<
std
::
is_arithmetic
<
T
>::
value
>::
type
PushKVT
(
const
std
::
string
&
k
,
T
v
)
{
PushKV
(
k
,
std
::
to_string
(
v
));
}
void
ClearKV
(
const
std
::
string
&
k
);
void
ClearMasterKV
(
const
std
::
string
&
k
);
void
PullKV
(
const
std
::
string
&
k
,
std
::
function
<
void
(
const
std
::
string
&
)
>
VGetter
);
void
PullKV
(
const
std
::
string
&
k
,
std
::
string
*
v
);
void
PullKV
(
const
std
::
string
&
k
,
PbMessage
*
msg
);
void
PullMasterKV
(
const
std
::
string
&
k
,
PbMessage
*
msg
);
template
<
typename
T
>
typename
std
::
enable_if
<
std
::
is_arithmetic
<
T
>::
value
>::
type
PullKVT
(
const
std
::
string
&
k
,
T
*
v
)
{
std
::
string
v_str
;
...
...
@@ -65,6 +69,8 @@ class CtrlClient final {
friend
class
Global
<
CtrlClient
>
;
CtrlClient
();
void
LoadServer
(
const
std
::
string
&
server_addr
,
CtrlService
::
Stub
*
stub
);
void
PushMasterKV
(
const
std
::
string
&
k
,
std
::
function
<
void
(
std
::
string
*
)
>
VSetter
);
void
PullMasterKV
(
const
std
::
string
&
k
,
std
::
function
<
void
(
const
std
::
string
&
)
>
VGetter
);
CtrlService
::
Stub
*
GetMasterStub
()
{
return
stubs_
[
0
].
get
();
}
CtrlService
::
Stub
*
GetThisStub
();
CtrlService
::
Stub
*
GetResponsibleStub
(
const
std
::
string
&
key
);
...
...
oneflow/core/control/ctrl_server.cpp
浏览文件 @
073d368b
...
...
@@ -194,7 +194,8 @@ void CtrlServer::Init() {
Add
([
this
](
CtrlCall
<
CtrlMethod
::
kClear
>*
call
)
{
name2lock_status_
.
clear
();
kv_
.
clear
();
CHECK
(
pending_kv_calls_
.
empty
());
CHECK
(
pending_kv_calls_
.
empty
())
<<
"size(): "
<<
pending_kv_calls_
.
size
()
<<
", begin()->key: "
<<
pending_kv_calls_
.
begin
()
->
first
;
call
->
SendResponse
();
EnqueueRequest
<
CtrlMethod
::
kClear
>
();
});
...
...
oneflow/core/eager/eager_instruction.proto
0 → 100644
浏览文件 @
073d368b
syntax
=
"proto2"
;
package
oneflow
.
eager
;
import
"oneflow/core/vm/instruction.proto"
;
import
"oneflow/core/eager/eager_symbol.proto"
;
message
EagerInstruction
{
optional
vm.InstructionListProto
instruction_list
=
1
;
optional
EagerSymbolList
eager_symbol_list
=
2
;
};
oneflow/core/eager/eager_
util
.cpp
→
oneflow/core/eager/eager_
oneflow
.cpp
浏览文件 @
073d368b
...
...
@@ -13,13 +13,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/eager/eager_
util
.h"
#include "oneflow/core/eager/eager_
oneflow
.h"
#include "oneflow/core/eager/eager_symbol.pb.h"
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/vm/instruction.pb.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/job/machine_context.h"
#include "oneflow/core/job/cluster_instruction.h"
#include "oneflow/core/job/placement.pb.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/operator/op_attribute.pb.h"
...
...
@@ -52,41 +54,57 @@ void StorageAdd(const EagerSymbol& symbol) {
}
}
Maybe
<
void
>
RunLogicalInstruction
(
const
vm
::
InstructionListProto
&
instruction_list_proto
,
const
EagerSymbolList
&
eager_symbol_list
)
{
for
(
const
auto
&
eager_symbol
:
eager_symbol_list
.
eager_symbol
())
{
StorageAdd
(
eager_symbol
);
}
return
vm
::
Run
(
instruction_list_proto
);
}
}
// namespace
Maybe
<
void
>
RunPhysicalInstruction
(
const
vm
::
InstructionListProto
&
instruction_list_proto
,
const
EagerSymbolList
&
eager_symbol_list
)
{
Maybe
<
void
>
EagerOneflow
::
RunPhysicalInstruction
(
const
std
::
shared_ptr
<
const
ClusterInstructionProto
>&
cluster_instruction
)
{
const
vm
::
InstructionListProto
&
instruction_list_proto
=
cluster_instruction
->
eager_instruction
().
instruction_list
();
const
EagerSymbolList
&
eager_symbol_list
=
cluster_instruction
->
eager_instruction
().
eager_symbol_list
();
for
(
const
auto
&
eager_symbol
:
eager_symbol_list
.
eager_symbol
())
{
StorageAdd
(
eager_symbol
);
}
return
vm
::
Run
(
instruction_list_proto
);
}
}
// namespace
Maybe
<
void
>
RunPhysicalInstruction
(
const
std
::
string
&
instruction_list_proto_str
,
const
std
::
string
&
eager_symbol_list_str
)
{
vm
::
InstructionListProto
instruction_list_proto
;
CHECK_OR_RETURN
(
TxtString2PbMessage
(
instruction_list_proto_str
,
&
instruction_list_proto
))
Maybe
<
void
>
EagerOneflow
::
RunPhysicalInstruction
(
const
std
::
string
&
instruction_list_proto_str
,
const
std
::
string
&
eager_symbol_list_str
)
{
auto
cluster_instruction
=
std
::
make_shared
<
ClusterInstructionProto
>
();
vm
::
InstructionListProto
*
instruction_list_proto
=
cluster_instruction
->
mutable_eager_instruction
()
->
mutable_instruction_list
()
;
CHECK_OR_RETURN
(
TxtString2PbMessage
(
instruction_list_proto_str
,
instruction_list_proto
))
<<
"InstructionListProto parse failed"
;
EagerSymbolList
eager_symbol_list
;
CHECK_OR_RETURN
(
TxtString2PbMessage
(
eager_symbol_list_str
,
&
eager_symbol_list
))
EagerSymbolList
*
eager_symbol_list
=
cluster_instruction
->
mutable_eager_instruction
()
->
mutable_eager_symbol_list
();
CHECK_OR_RETURN
(
TxtString2PbMessage
(
eager_symbol_list_str
,
eager_symbol_list
))
<<
"EagerSymbolList parse failed"
;
return
RunPhysicalInstruction
(
instruction_list_proto
,
eager_symbol_list
);
return
RunPhysicalInstruction
(
std
::
const_pointer_cast
<
const
ClusterInstructionProto
>
(
cluster_instruction
));
}
Maybe
<
void
>
EagerOneflow
::
RunLogicalInstruction
(
const
std
::
shared_ptr
<
const
ClusterInstructionProto
>&
cluster_instruction
)
{
CHECK
(
cluster_instruction
->
has_eager_instruction
());
CHECK
(
Global
<
MachineCtx
>::
Get
()
->
IsThisMachineMaster
());
ClusterInstruction
::
MasterSendEagerInstruction
(
*
cluster_instruction
);
return
RunPhysicalInstruction
(
cluster_instruction
);
}
Maybe
<
void
>
RunLogicalInstruction
(
const
std
::
string
&
instruction_list_proto_str
,
const
std
::
string
&
eager_symbol_list_str
)
{
vm
::
InstructionListProto
instruction_list_proto
;
CHECK_OR_RETURN
(
TxtString2PbMessage
(
instruction_list_proto_str
,
&
instruction_list_proto
))
Maybe
<
void
>
EagerOneflow
::
RunLogicalInstruction
(
const
std
::
string
&
instruction_list_proto_str
,
const
std
::
string
&
eager_symbol_list_str
)
{
auto
cluster_instruction
=
std
::
make_shared
<
ClusterInstructionProto
>
();
vm
::
InstructionListProto
*
instruction_list_proto
=
cluster_instruction
->
mutable_eager_instruction
()
->
mutable_instruction_list
();
CHECK_OR_RETURN
(
TxtString2PbMessage
(
instruction_list_proto_str
,
instruction_list_proto
))
<<
"InstructionListProto parse failed"
;
EagerSymbolList
eager_symbol_list
;
CHECK_OR_RETURN
(
TxtString2PbMessage
(
eager_symbol_list_str
,
&
eager_symbol_list
))
EagerSymbolList
*
eager_symbol_list
=
cluster_instruction
->
mutable_eager_instruction
()
->
mutable_eager_symbol_list
();
CHECK_OR_RETURN
(
TxtString2PbMessage
(
eager_symbol_list_str
,
eager_symbol_list
))
<<
"EagerSymbolList parse failed"
;
return
RunLogicalInstruction
(
instruction_list_proto
,
eager_symbol_list
);
return
RunLogicalInstruction
(
std
::
const_pointer_cast
<
const
ClusterInstructionProto
>
(
cluster_instruction
));
}
COMMAND
(
Global
<
EagerOneflow
>::
SetAllocated
(
new
EagerOneflow
()));
}
// namespace eager
}
// namespace oneflow
oneflow/core/eager/eager_
util
.h
→
oneflow/core/eager/eager_
oneflow
.h
浏览文件 @
073d368b
...
...
@@ -13,20 +13,30 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_EAGER_
UTIL
_H_
#define ONEFLOW_CORE_EAGER_EAGER_
UTIL
_H_
#ifndef ONEFLOW_CORE_EAGER_EAGER_
ONEFLOW
_H_
#define ONEFLOW_CORE_EAGER_EAGER_
ONEFLOW
_H_
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/job/cluster_instruction.pb.h"
namespace
oneflow
{
namespace
eager
{
Maybe
<
void
>
RunPhysicalInstruction
(
const
std
::
string
&
instruction_list_proto_str
,
const
std
::
string
&
eager_symbol_list_str
);
Maybe
<
void
>
RunLogicalInstruction
(
const
std
::
string
&
instruction_list_proto_str
,
const
std
::
string
&
eager_symbol_list_str
);
class
EagerOneflow
final
{
public:
Maybe
<
void
>
RunLogicalInstruction
(
const
std
::
shared_ptr
<
const
ClusterInstructionProto
>&
cluster_instruction
);
Maybe
<
void
>
RunLogicalInstruction
(
const
std
::
string
&
instruction_list_proto_str
,
const
std
::
string
&
eager_symbol_list_str
);
Maybe
<
void
>
RunPhysicalInstruction
(
const
std
::
string
&
instruction_list_proto_str
,
const
std
::
string
&
eager_symbol_list_str
);
Maybe
<
void
>
RunPhysicalInstruction
(
const
std
::
shared_ptr
<
const
ClusterInstructionProto
>&
cluster_instruction
);
};
}
// namespace eager
}
// namespace oneflow
#endif // ONEFLOW_CORE_EAGER_EAGER_
UTIL
_H_
#endif // ONEFLOW_CORE_EAGER_EAGER_
ONEFLOW
_H_
oneflow/core/job/cluster.cpp
浏览文件 @
073d368b
...
...
@@ -16,33 +16,68 @@ limitations under the License.
#include "oneflow/core/job/cluster.h"
#include "oneflow/core/job/cluster_instruction.pb.h"
#include "oneflow/core/job/cluster_instruction.h"
#include "oneflow/core/eager/eager_oneflow.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/job/oneflow.h"
#include "oneflow/core/job/machine_context.h"
#include "oneflow/core/job/session_global_objects_scope.h"
#include "oneflow/core/job/env_global_objects_scope.h"
#include "oneflow/core/job/job_set.pb.h"
#include "oneflow/core/thread/thread_pool.h"
namespace
oneflow
{
Maybe
<
void
>
Cluster
::
WorkerLoop
()
{
CHECK_OR_RETURN
(
!
Global
<
MachineCtx
>::
Get
()
->
IsThisMachineMaster
());
ClusterInstructionProto
cluster_instruction
;
while
(
ClusterInstruction
::
WorkerReceiveHalt
(
&
cluster_instruction
)
==
false
)
{
namespace
{
void
AsyncRunLazyJobSet
(
ThreadPool
*
lazy_runtime_thread
)
{
lazy_runtime_thread
->
AddWork
([]
{
ConfigProto
config_proto
;
Global
<
CtrlClient
>::
Get
()
->
PullKV
(
"config_proto"
,
&
config_proto
);
int32_t
machine_num
=
config_proto
.
resource
().
machine_num
();
if
(
Global
<
MachineCtx
>::
Get
()
->
this_machine_id
()
>=
machine_num
)
{
continue
;
}
// do nothing if it's not my business
if
(
Global
<
MachineCtx
>::
Get
()
->
this_machine_id
()
>=
machine_num
)
{
return
;
}
Global
<
SessionGlobalObjectsScope
>::
New
();
JUST
(
Global
<
SessionGlobalObjectsScope
>::
Get
()
->
Init
(
config_proto
));
CHECK_JUST
(
Global
<
SessionGlobalObjectsScope
>::
Get
()
->
Init
(
config_proto
));
JobSet
job_set
;
Global
<
CtrlClient
>::
Get
()
->
PullKV
(
"session_job_set"
,
&
job_set
);
{
Oneflow
oneflow
;
JUST
(
oneflow
.
Init
(
job_set
));
CHECK_
JUST
(
oneflow
.
Init
(
job_set
));
}
Global
<
SessionGlobalObjectsScope
>::
Delete
();
});
}
}
// namespace
Maybe
<
void
>
Cluster
::
WorkerLoop
()
{
// The reason why excluding master machine is that
// eager instruction for compile-time symbol constructing must be done synchronously
CHECK_OR_RETURN
(
!
Global
<
MachineCtx
>::
Get
()
->
IsThisMachineMaster
());
{
// Oneflow::~Oneflow blocking in current thread is not acceptable
// Two reasons why `lazy_runtime_thread` is needed:
// 1. making current thread non-block by
// taking over the execution of Oneflow::~Oneflow
// 2. as a Synchronizing guard for all unfinished Oneflow::~Oneflow
//
// thread_num must be 1.
ThreadPool
lazy_runtime_thread
(
1
);
while
(
true
)
{
auto
mut_cluster_instruction
=
std
::
make_shared
<
ClusterInstructionProto
>
();
ClusterInstruction
::
WorkerReceiveInstruction
(
mut_cluster_instruction
.
get
());
if
(
mut_cluster_instruction
->
has_cluster_ctrl_halt
())
{
break
;
}
else
if
(
mut_cluster_instruction
->
has_cluster_ctrl_session_start
())
{
ClusterInstruction
::
NewSessionBarrier
();
AsyncRunLazyJobSet
(
&
lazy_runtime_thread
);
}
else
if
(
mut_cluster_instruction
->
has_eager_instruction
())
{
Global
<
eager
::
EagerOneflow
>::
Get
()
->
RunPhysicalInstruction
(
std
::
const_pointer_cast
<
const
ClusterInstructionProto
>
(
mut_cluster_instruction
));
}
else
{
OF_UNIMPLEMENTED
();
}
}
}
ClusterInstruction
::
HaltBarrier
();
Global
<
EnvGlobalObjectsScope
>::
Delete
();
...
...
oneflow/core/job/cluster_instruction.cpp
浏览文件 @
073d368b
...
...
@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <mutex>
#include "oneflow/core/job/cluster_instruction.h"
#include "oneflow/core/job/cluster_instruction.pb.h"
#include "oneflow/core/control/ctrl_server.h"
...
...
@@ -24,46 +25,103 @@ namespace oneflow {
namespace
{
void
BarrierClear
()
{
OF_BARRIER_ALL
();
Global
<
CtrlClient
>::
Get
()
->
Clear
();
OF_BARRIER_ALL
();
}
std
::
string
GetHaltAckCtrlKey
(
int64_t
machine_id
)
{
return
"HaltAckCtrlKey/"
+
std
::
to_string
(
machine_id
);
}
// return unique sequential key
// because ctrl key is not allowed to push/pull twice
std
::
string
Get
HaltOrSessionStartCtrl
Key
()
{
std
::
string
Get
ClusterInstruction
Key
()
{
static
int64_t
seq
=
0
;
return
"HaltOrSessionStart/"
+
std
::
to_string
(
seq
++
);
return
"ClusterInstructionKey/"
+
std
::
to_string
(
seq
++
);
}
class
ObsoleteCtrlKeys
{
public:
ObsoleteCtrlKeys
()
=
default
;
~
ObsoleteCtrlKeys
()
=
default
;
template
<
typename
CallbackT
>
void
ForEach
(
const
CallbackT
&
Callback
)
const
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mutex_
);
for
(
const
std
::
string
&
k
:
keys_
)
{
Callback
(
k
);
}
}
void
Clear
()
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mutex_
);
keys_
.
clear
();
}
void
Add
(
const
std
::
string
&
key
)
{
std
::
unique_lock
<
std
::
mutex
>
lck
(
mutex_
);
keys_
.
push_back
(
key
);
}
private:
mutable
std
::
mutex
mutex_
;
std
::
vector
<
std
::
string
>
keys_
;
};
COMMAND
(
Global
<
ObsoleteCtrlKeys
>::
SetAllocated
(
new
ObsoleteCtrlKeys
()));
void
OccasionallyClearCtrlKV
(
const
std
::
string
&
key
)
{
static
std
::
atomic
<
int64_t
>
seq
(
0LL
);
const
static
int64_t
interval
=
65536
;
Global
<
ObsoleteCtrlKeys
>::
Get
()
->
Add
(
key
);
// 1 instead of 0 is better for avoid clearing no ctrl kv
if
((
seq
++
)
%
interval
==
1
)
{
OF_BARRIER_ALL
();
if
(
Global
<
MachineCtx
>::
Get
()
->
IsThisMachineMaster
())
{
Global
<
ObsoleteCtrlKeys
>::
Get
()
->
ForEach
(
[](
const
std
::
string
&
k
)
{
Global
<
CtrlClient
>::
Get
()
->
ClearMasterKV
(
k
);
});
}
Global
<
ObsoleteCtrlKeys
>::
Get
()
->
Clear
();
OF_BARRIER_ALL
();
}
}
void
PushClusterInstruction
(
const
ClusterInstructionProto
&
cluster_instruction
)
{
const
std
::
string
&
key
=
GetClusterInstructionKey
();
Global
<
CtrlClient
>::
Get
()
->
PushMasterKV
(
key
,
cluster_instruction
);
OccasionallyClearCtrlKV
(
key
);
}
void
PullClusterInstruction
(
ClusterInstructionProto
*
cluster_instruction
)
{
const
std
::
string
&
key
=
GetClusterInstructionKey
();
Global
<
CtrlClient
>::
Get
()
->
PullMasterKV
(
key
,
cluster_instruction
);
OccasionallyClearCtrlKV
(
key
);
}
}
// namespace
void
ClusterInstruction
::
NewSessionBarrier
()
{
OF_BARRIER_ALL
();
Global
<
CtrlClient
>::
Get
()
->
Clear
();
Global
<
ObsoleteCtrlKeys
>::
Get
()
->
Clear
();
OF_BARRIER_ALL
();
}
void
ClusterInstruction
::
MasterSendSessionStart
()
{
BarrierClear
();
ClusterInstructionProto
cluster_instruction
;
cluster_instruction
.
mutable_cluster_ctrl_session_start
();
Global
<
CtrlClient
>::
Get
()
->
PushKV
(
GetHaltOrSessionStartCtrlKey
(),
cluster_instruction
);
PushClusterInstruction
(
cluster_instruction
);
NewSessionBarrier
();
}
void
ClusterInstruction
::
MasterSendHalt
()
{
BarrierClear
();
ClusterInstructionProto
cluster_instruction
;
cluster_instruction
.
mutable_cluster_ctrl_halt
();
Global
<
CtrlClient
>::
Get
()
->
PushKV
(
GetHaltOrSessionStartCtrlKey
(),
cluster_instruction
);
PushClusterInstruction
(
cluster_instruction
);
HaltBarrier
();
}
bool
ClusterInstruction
::
WorkerReceiveHalt
(
ClusterInstructionProto
*
cluster_instruction
)
{
BarrierClear
();
Global
<
CtrlClient
>::
Get
()
->
PullKV
(
GetHaltOrSessionStartCtrlKey
(),
cluster_instruction
);
if
(
cluster_instruction
->
has_cluster_ctrl_halt
())
{
return
true
;
}
CHECK
(
cluster_instruction
->
has_cluster_ctrl_session_start
());
return
false
;
void
ClusterInstruction
::
MasterSendEagerInstruction
(
const
ClusterInstructionProto
&
cluster_instruction
)
{
CHECK
(
cluster_instruction
.
has_eager_instruction
());
PushClusterInstruction
(
cluster_instruction
);
}
void
ClusterInstruction
::
WorkerReceiveInstruction
(
ClusterInstructionProto
*
cluster_instruction
)
{
PullClusterInstruction
(
cluster_instruction
);
}
void
ClusterInstruction
::
HaltBarrier
()
{
OF_BARRIER_ALL
();
}
...
...
oneflow/core/job/cluster_instruction.h
浏览文件 @
073d368b
...
...
@@ -22,8 +22,10 @@ namespace oneflow {
struct
ClusterInstruction
final
{
static
void
MasterSendSessionStart
();
static
bool
WorkerReceiveHalt
(
ClusterInstructionProto
*
cluster_instruction
);
static
void
MasterSendHalt
();
static
void
MasterSendEagerInstruction
(
const
ClusterInstructionProto
&
cluster_instruction
);
static
void
WorkerReceiveInstruction
(
ClusterInstructionProto
*
cluster_instruction
);
static
void
NewSessionBarrier
();
static
void
HaltBarrier
();
};
...
...
oneflow/core/job/cluster_instruction.proto
浏览文件 @
073d368b
syntax
=
"proto2"
;
package
oneflow
;
import
"oneflow/core/eager/eager_instruction.proto"
;
message
ClusterCtrlSessionStart
{}
message
ClusterCtrlHalt
{}
...
...
@@ -8,5 +10,6 @@ message ClusterInstructionProto {
oneof
instruction_type
{
ClusterCtrlSessionStart
cluster_ctrl_session_start
=
1
;
ClusterCtrlHalt
cluster_ctrl_halt
=
2
;
eager.EagerInstruction
eager_instruction
=
3
;
}
}
oneflow/python/oneflow_internal_helper.h
浏览文件 @
073d368b
...
...
@@ -44,7 +44,7 @@ limitations under the License.
#include "oneflow/core/vm/instruction.pb.h"
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/vm/id_util.h"
#include "oneflow/core/eager/eager_
util
.h"
#include "oneflow/core/eager/eager_
oneflow
.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#ifdef WITH_TENSORRT
...
...
@@ -271,12 +271,14 @@ Maybe<long> GetOpParallelSymbolId(const std::string& op_conf_str) {
Maybe
<
void
>
RunLogicalInstruction
(
const
std
::
string
&
instruction_list_str
,
const
std
::
string
&
eager_symbol_list_str
)
{
return
eager
::
RunLogicalInstruction
(
instruction_list_str
,
eager_symbol_list_str
);
return
Global
<
eager
::
EagerOneflow
>::
Get
()
->
RunLogicalInstruction
(
instruction_list_str
,
eager_symbol_list_str
);
}
Maybe
<
void
>
RunPhysicalInstruction
(
const
std
::
string
&
instruction_list_str
,
const
std
::
string
&
eager_symbol_list_str
)
{
return
eager
::
RunPhysicalInstruction
(
instruction_list_str
,
eager_symbol_list_str
);
return
Global
<
eager
::
EagerOneflow
>::
Get
()
->
RunPhysicalInstruction
(
instruction_list_str
,
eager_symbol_list_str
);
}
Maybe
<
long
long
>
CurrentMachineId
()
{
...
...
oneflow/python/test/ops/test_ccrelu.py
浏览文件 @
073d368b
...
...
@@ -89,7 +89,28 @@ def test_1n2c_mirror_dynamic_ccrelu(test_case):
@
flow
.
unittest
.
num_nodes_required
(
2
)
def
test_ccrelu_2n1c
(
test_case
):
def
test_ccrelu_2n1c_0
(
test_case
):
func_config
=
flow
.
FunctionConfig
()
func_config
.
default_logical_view
(
flow
.
scope
.
consistent_view
())
fixed_tensor_def_test
(
test_case
,
func_config
)
@
flow
.
unittest
.
num_nodes_required
(
2
)
def
test_ccrelu_2n1c_1
(
test_case
):
func_config
=
flow
.
FunctionConfig
()
func_config
.
default_logical_view
(
flow
.
scope
.
consistent_view
())
fixed_tensor_def_test
(
test_case
,
func_config
)
@
flow
.
unittest
.
num_nodes_required
(
2
)
def
test_ccrelu_2n1c_2
(
test_case
):
func_config
=
flow
.
FunctionConfig
()
func_config
.
default_logical_view
(
flow
.
scope
.
consistent_view
())
fixed_tensor_def_test
(
test_case
,
func_config
)
@
flow
.
unittest
.
num_nodes_required
(
2
)
def
test_ccrelu_2n1c_3
(
test_case
):
func_config
=
flow
.
FunctionConfig
()
func_config
.
default_logical_view
(
flow
.
scope
.
consistent_view
())
fixed_tensor_def_test
(
test_case
,
func_config
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录