提交 073d368b 编写于 作者: S Shenghang Tsai

Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into master

......@@ -100,6 +100,13 @@ void CtrlClient::PushKV(const std::string& k, std::function<void(std::string*)>
call(GetResponsibleStub(k));
}
void CtrlClient::PushMasterKV(const std::string& k, std::function<void(std::string*)> VSetter) {
ClientCall<CtrlMethod::kPushKV> call;
call.mut_request()->set_key(k);
VSetter(call.mut_request()->mutable_val());
call(GetMasterStub());
}
void CtrlClient::PushKV(const std::string& k, const std::string& v) {
PushKV(k, [&](std::string* o) { *o = v; });
}
......@@ -108,12 +115,22 @@ void CtrlClient::PushKV(const std::string& k, const PbMessage& msg) {
PushKV(k, [&](std::string* o) { msg.SerializeToString(o); });
}
void CtrlClient::PushMasterKV(const std::string& k, const PbMessage& msg) {
PushMasterKV(k, [&](std::string* o) { msg.SerializeToString(o); });
}
void CtrlClient::ClearKV(const std::string& k) {
ClientCall<CtrlMethod::kClearKV> call;
call.mut_request()->set_key(k);
call(GetResponsibleStub(k));
}
void CtrlClient::ClearMasterKV(const std::string& k) {
ClientCall<CtrlMethod::kClearKV> call;
call.mut_request()->set_key(k);
call(GetMasterStub());
}
void CtrlClient::PullKV(const std::string& k, std::function<void(const std::string&)> VGetter) {
ClientCall<CtrlMethod::kPullKV> call;
call.mut_request()->set_key(k);
......@@ -121,6 +138,14 @@ void CtrlClient::PullKV(const std::string& k, std::function<void(const std::stri
VGetter(call.response().val());
}
void CtrlClient::PullMasterKV(const std::string& k,
std::function<void(const std::string&)> VGetter) {
ClientCall<CtrlMethod::kPullKV> call;
call.mut_request()->set_key(k);
call(GetMasterStub());
VGetter(call.response().val());
}
void CtrlClient::PullKV(const std::string& k, std::string* v) {
PullKV(k, [&](const std::string& i) { *v = i; });
}
......@@ -129,6 +154,10 @@ void CtrlClient::PullKV(const std::string& k, PbMessage* msg) {
PullKV(k, [&](const std::string& i) { msg->ParseFromString(i); });
}
void CtrlClient::PullMasterKV(const std::string& k, PbMessage* msg) {
PullMasterKV(k, [&](const std::string& i) { msg->ParseFromString(i); });
}
void CtrlClient::PushActEvent(const ActEvent& act_event) {
ClientCall<CtrlMethod::kPushActEvent> call;
*(call.mut_request()->mutable_act_event()) = act_event;
......
......@@ -38,15 +38,19 @@ class CtrlClient final {
void PushKV(const std::string& k, std::function<void(std::string*)> VSetter);
void PushKV(const std::string& k, const std::string& v);
void PushKV(const std::string& k, const PbMessage& msg);
void PushMasterKV(const std::string& k, const PbMessage& msg);
template<typename T>
typename std::enable_if<std::is_arithmetic<T>::value>::type PushKVT(const std::string& k, T v) {
PushKV(k, std::to_string(v));
}
void ClearKV(const std::string& k);
void ClearMasterKV(const std::string& k);
void PullKV(const std::string& k, std::function<void(const std::string&)> VGetter);
void PullKV(const std::string& k, std::string* v);
void PullKV(const std::string& k, PbMessage* msg);
void PullMasterKV(const std::string& k, PbMessage* msg);
template<typename T>
typename std::enable_if<std::is_arithmetic<T>::value>::type PullKVT(const std::string& k, T* v) {
std::string v_str;
......@@ -65,6 +69,8 @@ class CtrlClient final {
friend class Global<CtrlClient>;
CtrlClient();
void LoadServer(const std::string& server_addr, CtrlService::Stub* stub);
void PushMasterKV(const std::string& k, std::function<void(std::string*)> VSetter);
void PullMasterKV(const std::string& k, std::function<void(const std::string&)> VGetter);
CtrlService::Stub* GetMasterStub() { return stubs_[0].get(); }
CtrlService::Stub* GetThisStub();
CtrlService::Stub* GetResponsibleStub(const std::string& key);
......
......@@ -194,7 +194,8 @@ void CtrlServer::Init() {
Add([this](CtrlCall<CtrlMethod::kClear>* call) {
name2lock_status_.clear();
kv_.clear();
CHECK(pending_kv_calls_.empty());
CHECK(pending_kv_calls_.empty()) << "size(): " << pending_kv_calls_.size()
<< ", begin()->key: " << pending_kv_calls_.begin()->first;
call->SendResponse();
EnqueueRequest<CtrlMethod::kClear>();
});
......
syntax = "proto2";
package oneflow.eager;
import "oneflow/core/vm/instruction.proto";
import "oneflow/core/eager/eager_symbol.proto";
message EagerInstruction {
optional vm.InstructionListProto instruction_list = 1;
optional EagerSymbolList eager_symbol_list = 2;
};
......@@ -13,13 +13,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/eager/eager_util.h"
#include "oneflow/core/eager/eager_oneflow.h"
#include "oneflow/core/eager/eager_symbol.pb.h"
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/vm/instruction.pb.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/scope.h"
#include "oneflow/core/job/machine_context.h"
#include "oneflow/core/job/cluster_instruction.h"
#include "oneflow/core/job/placement.pb.h"
#include "oneflow/core/operator/op_conf.pb.h"
#include "oneflow/core/operator/op_attribute.pb.h"
......@@ -52,41 +54,57 @@ void StorageAdd(const EagerSymbol& symbol) {
}
}
Maybe<void> RunLogicalInstruction(const vm::InstructionListProto& instruction_list_proto,
const EagerSymbolList& eager_symbol_list) {
for (const auto& eager_symbol : eager_symbol_list.eager_symbol()) { StorageAdd(eager_symbol); }
return vm::Run(instruction_list_proto);
}
} // namespace
Maybe<void> RunPhysicalInstruction(const vm::InstructionListProto& instruction_list_proto,
const EagerSymbolList& eager_symbol_list) {
Maybe<void> EagerOneflow::RunPhysicalInstruction(
const std::shared_ptr<const ClusterInstructionProto>& cluster_instruction) {
const vm::InstructionListProto& instruction_list_proto =
cluster_instruction->eager_instruction().instruction_list();
const EagerSymbolList& eager_symbol_list =
cluster_instruction->eager_instruction().eager_symbol_list();
for (const auto& eager_symbol : eager_symbol_list.eager_symbol()) { StorageAdd(eager_symbol); }
return vm::Run(instruction_list_proto);
}
} // namespace
Maybe<void> RunPhysicalInstruction(const std::string& instruction_list_proto_str,
const std::string& eager_symbol_list_str) {
vm::InstructionListProto instruction_list_proto;
CHECK_OR_RETURN(TxtString2PbMessage(instruction_list_proto_str, &instruction_list_proto))
Maybe<void> EagerOneflow::RunPhysicalInstruction(const std::string& instruction_list_proto_str,
const std::string& eager_symbol_list_str) {
auto cluster_instruction = std::make_shared<ClusterInstructionProto>();
vm::InstructionListProto* instruction_list_proto =
cluster_instruction->mutable_eager_instruction()->mutable_instruction_list();
CHECK_OR_RETURN(TxtString2PbMessage(instruction_list_proto_str, instruction_list_proto))
<< "InstructionListProto parse failed";
EagerSymbolList eager_symbol_list;
CHECK_OR_RETURN(TxtString2PbMessage(eager_symbol_list_str, &eager_symbol_list))
EagerSymbolList* eager_symbol_list =
cluster_instruction->mutable_eager_instruction()->mutable_eager_symbol_list();
CHECK_OR_RETURN(TxtString2PbMessage(eager_symbol_list_str, eager_symbol_list))
<< "EagerSymbolList parse failed";
return RunPhysicalInstruction(instruction_list_proto, eager_symbol_list);
return RunPhysicalInstruction(
std::const_pointer_cast<const ClusterInstructionProto>(cluster_instruction));
}
Maybe<void> EagerOneflow::RunLogicalInstruction(
const std::shared_ptr<const ClusterInstructionProto>& cluster_instruction) {
CHECK(cluster_instruction->has_eager_instruction());
CHECK(Global<MachineCtx>::Get()->IsThisMachineMaster());
ClusterInstruction::MasterSendEagerInstruction(*cluster_instruction);
return RunPhysicalInstruction(cluster_instruction);
}
Maybe<void> RunLogicalInstruction(const std::string& instruction_list_proto_str,
const std::string& eager_symbol_list_str) {
vm::InstructionListProto instruction_list_proto;
CHECK_OR_RETURN(TxtString2PbMessage(instruction_list_proto_str, &instruction_list_proto))
Maybe<void> EagerOneflow::RunLogicalInstruction(const std::string& instruction_list_proto_str,
const std::string& eager_symbol_list_str) {
auto cluster_instruction = std::make_shared<ClusterInstructionProto>();
vm::InstructionListProto* instruction_list_proto =
cluster_instruction->mutable_eager_instruction()->mutable_instruction_list();
CHECK_OR_RETURN(TxtString2PbMessage(instruction_list_proto_str, instruction_list_proto))
<< "InstructionListProto parse failed";
EagerSymbolList eager_symbol_list;
CHECK_OR_RETURN(TxtString2PbMessage(eager_symbol_list_str, &eager_symbol_list))
EagerSymbolList* eager_symbol_list =
cluster_instruction->mutable_eager_instruction()->mutable_eager_symbol_list();
CHECK_OR_RETURN(TxtString2PbMessage(eager_symbol_list_str, eager_symbol_list))
<< "EagerSymbolList parse failed";
return RunLogicalInstruction(instruction_list_proto, eager_symbol_list);
return RunLogicalInstruction(
std::const_pointer_cast<const ClusterInstructionProto>(cluster_instruction));
}
COMMAND(Global<EagerOneflow>::SetAllocated(new EagerOneflow()));
} // namespace eager
} // namespace oneflow
......@@ -13,20 +13,30 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EAGER_EAGER_UTIL_H_
#define ONEFLOW_CORE_EAGER_EAGER_UTIL_H_
#ifndef ONEFLOW_CORE_EAGER_EAGER_ONEFLOW_H_
#define ONEFLOW_CORE_EAGER_EAGER_ONEFLOW_H_
#include "oneflow/core/common/maybe.h"
#include "oneflow/core/job/cluster_instruction.pb.h"
namespace oneflow {
namespace eager {
Maybe<void> RunPhysicalInstruction(const std::string& instruction_list_proto_str,
const std::string& eager_symbol_list_str);
Maybe<void> RunLogicalInstruction(const std::string& instruction_list_proto_str,
const std::string& eager_symbol_list_str);
class EagerOneflow final {
public:
Maybe<void> RunLogicalInstruction(
const std::shared_ptr<const ClusterInstructionProto>& cluster_instruction);
Maybe<void> RunLogicalInstruction(const std::string& instruction_list_proto_str,
const std::string& eager_symbol_list_str);
Maybe<void> RunPhysicalInstruction(const std::string& instruction_list_proto_str,
const std::string& eager_symbol_list_str);
Maybe<void> RunPhysicalInstruction(
const std::shared_ptr<const ClusterInstructionProto>& cluster_instruction);
};
} // namespace eager
} // namespace oneflow
#endif // ONEFLOW_CORE_EAGER_EAGER_UTIL_H_
#endif // ONEFLOW_CORE_EAGER_EAGER_ONEFLOW_H_
......@@ -16,33 +16,68 @@ limitations under the License.
#include "oneflow/core/job/cluster.h"
#include "oneflow/core/job/cluster_instruction.pb.h"
#include "oneflow/core/job/cluster_instruction.h"
#include "oneflow/core/eager/eager_oneflow.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/job/oneflow.h"
#include "oneflow/core/job/machine_context.h"
#include "oneflow/core/job/session_global_objects_scope.h"
#include "oneflow/core/job/env_global_objects_scope.h"
#include "oneflow/core/job/job_set.pb.h"
#include "oneflow/core/thread/thread_pool.h"
namespace oneflow {
Maybe<void> Cluster::WorkerLoop() {
CHECK_OR_RETURN(!Global<MachineCtx>::Get()->IsThisMachineMaster());
ClusterInstructionProto cluster_instruction;
while (ClusterInstruction::WorkerReceiveHalt(&cluster_instruction) == false) {
namespace {
void AsyncRunLazyJobSet(ThreadPool* lazy_runtime_thread) {
lazy_runtime_thread->AddWork([] {
ConfigProto config_proto;
Global<CtrlClient>::Get()->PullKV("config_proto", &config_proto);
int32_t machine_num = config_proto.resource().machine_num();
if (Global<MachineCtx>::Get()->this_machine_id() >= machine_num) { continue; }
// do nothing if it's not my business
if (Global<MachineCtx>::Get()->this_machine_id() >= machine_num) { return; }
Global<SessionGlobalObjectsScope>::New();
JUST(Global<SessionGlobalObjectsScope>::Get()->Init(config_proto));
CHECK_JUST(Global<SessionGlobalObjectsScope>::Get()->Init(config_proto));
JobSet job_set;
Global<CtrlClient>::Get()->PullKV("session_job_set", &job_set);
{
Oneflow oneflow;
JUST(oneflow.Init(job_set));
CHECK_JUST(oneflow.Init(job_set));
}
Global<SessionGlobalObjectsScope>::Delete();
});
}
} // namespace
Maybe<void> Cluster::WorkerLoop() {
// The reason why excluding master machine is that
// eager instruction for compile-time symbol constructing must be done synchronously
CHECK_OR_RETURN(!Global<MachineCtx>::Get()->IsThisMachineMaster());
{
// Oneflow::~Oneflow blocking in current thread is not acceptable
// Two reasons why `lazy_runtime_thread` is needed:
// 1. making current thread non-block by
// taking over the execution of Oneflow::~Oneflow
// 2. as a Synchronizing guard for all unfinished Oneflow::~Oneflow
//
// thread_num must be 1.
ThreadPool lazy_runtime_thread(1);
while (true) {
auto mut_cluster_instruction = std::make_shared<ClusterInstructionProto>();
ClusterInstruction::WorkerReceiveInstruction(mut_cluster_instruction.get());
if (mut_cluster_instruction->has_cluster_ctrl_halt()) {
break;
} else if (mut_cluster_instruction->has_cluster_ctrl_session_start()) {
ClusterInstruction::NewSessionBarrier();
AsyncRunLazyJobSet(&lazy_runtime_thread);
} else if (mut_cluster_instruction->has_eager_instruction()) {
Global<eager::EagerOneflow>::Get()->RunPhysicalInstruction(
std::const_pointer_cast<const ClusterInstructionProto>(mut_cluster_instruction));
} else {
OF_UNIMPLEMENTED();
}
}
}
ClusterInstruction::HaltBarrier();
Global<EnvGlobalObjectsScope>::Delete();
......
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <mutex>
#include "oneflow/core/job/cluster_instruction.h"
#include "oneflow/core/job/cluster_instruction.pb.h"
#include "oneflow/core/control/ctrl_server.h"
......@@ -24,46 +25,103 @@ namespace oneflow {
namespace {
void BarrierClear() {
OF_BARRIER_ALL();
Global<CtrlClient>::Get()->Clear();
OF_BARRIER_ALL();
}
std::string GetHaltAckCtrlKey(int64_t machine_id) {
return "HaltAckCtrlKey/" + std::to_string(machine_id);
}
// return unique sequential key
// because ctrl key is not allowed to push/pull twice
std::string GetHaltOrSessionStartCtrlKey() {
std::string GetClusterInstructionKey() {
static int64_t seq = 0;
return "HaltOrSessionStart/" + std::to_string(seq++);
return "ClusterInstructionKey/" + std::to_string(seq++);
}
class ObsoleteCtrlKeys {
public:
ObsoleteCtrlKeys() = default;
~ObsoleteCtrlKeys() = default;
template<typename CallbackT>
void ForEach(const CallbackT& Callback) const {
std::unique_lock<std::mutex> lck(mutex_);
for (const std::string& k : keys_) { Callback(k); }
}
void Clear() {
std::unique_lock<std::mutex> lck(mutex_);
keys_.clear();
}
void Add(const std::string& key) {
std::unique_lock<std::mutex> lck(mutex_);
keys_.push_back(key);
}
private:
mutable std::mutex mutex_;
std::vector<std::string> keys_;
};
COMMAND(Global<ObsoleteCtrlKeys>::SetAllocated(new ObsoleteCtrlKeys()));
void OccasionallyClearCtrlKV(const std::string& key) {
static std::atomic<int64_t> seq(0LL);
const static int64_t interval = 65536;
Global<ObsoleteCtrlKeys>::Get()->Add(key);
// 1 instead of 0 is better for avoid clearing no ctrl kv
if ((seq++) % interval == 1) {
OF_BARRIER_ALL();
if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
Global<ObsoleteCtrlKeys>::Get()->ForEach(
[](const std::string& k) { Global<CtrlClient>::Get()->ClearMasterKV(k); });
}
Global<ObsoleteCtrlKeys>::Get()->Clear();
OF_BARRIER_ALL();
}
}
void PushClusterInstruction(const ClusterInstructionProto& cluster_instruction) {
const std::string& key = GetClusterInstructionKey();
Global<CtrlClient>::Get()->PushMasterKV(key, cluster_instruction);
OccasionallyClearCtrlKV(key);
}
void PullClusterInstruction(ClusterInstructionProto* cluster_instruction) {
const std::string& key = GetClusterInstructionKey();
Global<CtrlClient>::Get()->PullMasterKV(key, cluster_instruction);
OccasionallyClearCtrlKV(key);
}
} // namespace
void ClusterInstruction::NewSessionBarrier() {
OF_BARRIER_ALL();
Global<CtrlClient>::Get()->Clear();
Global<ObsoleteCtrlKeys>::Get()->Clear();
OF_BARRIER_ALL();
}
void ClusterInstruction::MasterSendSessionStart() {
BarrierClear();
ClusterInstructionProto cluster_instruction;
cluster_instruction.mutable_cluster_ctrl_session_start();
Global<CtrlClient>::Get()->PushKV(GetHaltOrSessionStartCtrlKey(), cluster_instruction);
PushClusterInstruction(cluster_instruction);
NewSessionBarrier();
}
void ClusterInstruction::MasterSendHalt() {
BarrierClear();
ClusterInstructionProto cluster_instruction;
cluster_instruction.mutable_cluster_ctrl_halt();
Global<CtrlClient>::Get()->PushKV(GetHaltOrSessionStartCtrlKey(), cluster_instruction);
PushClusterInstruction(cluster_instruction);
HaltBarrier();
}
bool ClusterInstruction::WorkerReceiveHalt(ClusterInstructionProto* cluster_instruction) {
BarrierClear();
Global<CtrlClient>::Get()->PullKV(GetHaltOrSessionStartCtrlKey(), cluster_instruction);
if (cluster_instruction->has_cluster_ctrl_halt()) { return true; }
CHECK(cluster_instruction->has_cluster_ctrl_session_start());
return false;
void ClusterInstruction::MasterSendEagerInstruction(
const ClusterInstructionProto& cluster_instruction) {
CHECK(cluster_instruction.has_eager_instruction());
PushClusterInstruction(cluster_instruction);
}
void ClusterInstruction::WorkerReceiveInstruction(ClusterInstructionProto* cluster_instruction) {
PullClusterInstruction(cluster_instruction);
}
void ClusterInstruction::HaltBarrier() { OF_BARRIER_ALL(); }
......
......@@ -22,8 +22,10 @@ namespace oneflow {
struct ClusterInstruction final {
static void MasterSendSessionStart();
static bool WorkerReceiveHalt(ClusterInstructionProto* cluster_instruction);
static void MasterSendHalt();
static void MasterSendEagerInstruction(const ClusterInstructionProto& cluster_instruction);
static void WorkerReceiveInstruction(ClusterInstructionProto* cluster_instruction);
static void NewSessionBarrier();
static void HaltBarrier();
};
......
syntax = "proto2";
package oneflow;
import "oneflow/core/eager/eager_instruction.proto";
message ClusterCtrlSessionStart {}
message ClusterCtrlHalt {}
......@@ -8,5 +10,6 @@ message ClusterInstructionProto {
oneof instruction_type {
ClusterCtrlSessionStart cluster_ctrl_session_start = 1;
ClusterCtrlHalt cluster_ctrl_halt = 2;
eager.EagerInstruction eager_instruction = 3;
}
}
......@@ -44,7 +44,7 @@ limitations under the License.
#include "oneflow/core/vm/instruction.pb.h"
#include "oneflow/core/vm/vm_util.h"
#include "oneflow/core/vm/id_util.h"
#include "oneflow/core/eager/eager_util.h"
#include "oneflow/core/eager/eager_oneflow.h"
#include "oneflow/core/eager/eager_symbol_storage.h"
#ifdef WITH_TENSORRT
......@@ -271,12 +271,14 @@ Maybe<long> GetOpParallelSymbolId(const std::string& op_conf_str) {
Maybe<void> RunLogicalInstruction(const std::string& instruction_list_str,
const std::string& eager_symbol_list_str) {
return eager::RunLogicalInstruction(instruction_list_str, eager_symbol_list_str);
return Global<eager::EagerOneflow>::Get()->RunLogicalInstruction(instruction_list_str,
eager_symbol_list_str);
}
Maybe<void> RunPhysicalInstruction(const std::string& instruction_list_str,
const std::string& eager_symbol_list_str) {
return eager::RunPhysicalInstruction(instruction_list_str, eager_symbol_list_str);
return Global<eager::EagerOneflow>::Get()->RunPhysicalInstruction(instruction_list_str,
eager_symbol_list_str);
}
Maybe<long long> CurrentMachineId() {
......
......@@ -89,7 +89,28 @@ def test_1n2c_mirror_dynamic_ccrelu(test_case):
@flow.unittest.num_nodes_required(2)
def test_ccrelu_2n1c(test_case):
def test_ccrelu_2n1c_0(test_case):
func_config = flow.FunctionConfig()
func_config.default_logical_view(flow.scope.consistent_view())
fixed_tensor_def_test(test_case, func_config)
@flow.unittest.num_nodes_required(2)
def test_ccrelu_2n1c_1(test_case):
func_config = flow.FunctionConfig()
func_config.default_logical_view(flow.scope.consistent_view())
fixed_tensor_def_test(test_case, func_config)
@flow.unittest.num_nodes_required(2)
def test_ccrelu_2n1c_2(test_case):
func_config = flow.FunctionConfig()
func_config.default_logical_view(flow.scope.consistent_view())
fixed_tensor_def_test(test_case, func_config)
@flow.unittest.num_nodes_required(2)
def test_ccrelu_2n1c_3(test_case):
func_config = flow.FunctionConfig()
func_config.default_logical_view(flow.scope.consistent_view())
fixed_tensor_def_test(test_case, func_config)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册