diff --git a/oneflow/core/control/ctrl_client.cpp b/oneflow/core/control/ctrl_client.cpp index b182c2258d9e9b75f4de9ed298f841807a4f8e75..95088682f9231fe6bd58ba2598cf66948117a55d 100644 --- a/oneflow/core/control/ctrl_client.cpp +++ b/oneflow/core/control/ctrl_client.cpp @@ -100,6 +100,13 @@ void CtrlClient::PushKV(const std::string& k, std::function call(GetResponsibleStub(k)); } +void CtrlClient::PushMasterKV(const std::string& k, std::function VSetter) { + ClientCall call; + call.mut_request()->set_key(k); + VSetter(call.mut_request()->mutable_val()); + call(GetMasterStub()); +} + void CtrlClient::PushKV(const std::string& k, const std::string& v) { PushKV(k, [&](std::string* o) { *o = v; }); } @@ -108,12 +115,22 @@ void CtrlClient::PushKV(const std::string& k, const PbMessage& msg) { PushKV(k, [&](std::string* o) { msg.SerializeToString(o); }); } +void CtrlClient::PushMasterKV(const std::string& k, const PbMessage& msg) { + PushMasterKV(k, [&](std::string* o) { msg.SerializeToString(o); }); +} + void CtrlClient::ClearKV(const std::string& k) { ClientCall call; call.mut_request()->set_key(k); call(GetResponsibleStub(k)); } +void CtrlClient::ClearMasterKV(const std::string& k) { + ClientCall call; + call.mut_request()->set_key(k); + call(GetMasterStub()); +} + void CtrlClient::PullKV(const std::string& k, std::function VGetter) { ClientCall call; call.mut_request()->set_key(k); @@ -121,6 +138,14 @@ void CtrlClient::PullKV(const std::string& k, std::function VGetter) { + ClientCall call; + call.mut_request()->set_key(k); + call(GetMasterStub()); + VGetter(call.response().val()); +} + void CtrlClient::PullKV(const std::string& k, std::string* v) { PullKV(k, [&](const std::string& i) { *v = i; }); } @@ -129,6 +154,10 @@ void CtrlClient::PullKV(const std::string& k, PbMessage* msg) { PullKV(k, [&](const std::string& i) { msg->ParseFromString(i); }); } +void CtrlClient::PullMasterKV(const std::string& k, PbMessage* msg) { + PullMasterKV(k, [&](const std::string& i) { msg->ParseFromString(i); }); +} + void CtrlClient::PushActEvent(const ActEvent& act_event) { ClientCall call; *(call.mut_request()->mutable_act_event()) = act_event; diff --git a/oneflow/core/control/ctrl_client.h b/oneflow/core/control/ctrl_client.h index 3513f7d3923f829242a65cd046c2048804c1dc7a..4840fff56a77d84ceeae02b4e01251ca611d719d 100644 --- a/oneflow/core/control/ctrl_client.h +++ b/oneflow/core/control/ctrl_client.h @@ -38,15 +38,19 @@ class CtrlClient final { void PushKV(const std::string& k, std::function VSetter); void PushKV(const std::string& k, const std::string& v); void PushKV(const std::string& k, const PbMessage& msg); + void PushMasterKV(const std::string& k, const PbMessage& msg); template typename std::enable_if::value>::type PushKVT(const std::string& k, T v) { PushKV(k, std::to_string(v)); } void ClearKV(const std::string& k); + void ClearMasterKV(const std::string& k); + void PullKV(const std::string& k, std::function VGetter); void PullKV(const std::string& k, std::string* v); void PullKV(const std::string& k, PbMessage* msg); + void PullMasterKV(const std::string& k, PbMessage* msg); template typename std::enable_if::value>::type PullKVT(const std::string& k, T* v) { std::string v_str; @@ -65,6 +69,8 @@ class CtrlClient final { friend class Global; CtrlClient(); void LoadServer(const std::string& server_addr, CtrlService::Stub* stub); + void PushMasterKV(const std::string& k, std::function VSetter); + void PullMasterKV(const std::string& k, std::function VGetter); CtrlService::Stub* GetMasterStub() { return stubs_[0].get(); } CtrlService::Stub* GetThisStub(); CtrlService::Stub* GetResponsibleStub(const std::string& key); diff --git a/oneflow/core/control/ctrl_server.cpp b/oneflow/core/control/ctrl_server.cpp index 7ff644619b827e093fe9fdc0c98f574c22f81875..ffc1cd2f3efb09b0e5bef1fb7584e447ae3fbbae 100644 --- a/oneflow/core/control/ctrl_server.cpp +++ b/oneflow/core/control/ctrl_server.cpp @@ -194,7 +194,8 @@ void CtrlServer::Init() { Add([this](CtrlCall* call) { name2lock_status_.clear(); kv_.clear(); - CHECK(pending_kv_calls_.empty()); + CHECK(pending_kv_calls_.empty()) << "size(): " << pending_kv_calls_.size() + << ", begin()->key: " << pending_kv_calls_.begin()->first; call->SendResponse(); EnqueueRequest(); }); diff --git a/oneflow/core/eager/eager_instruction.proto b/oneflow/core/eager/eager_instruction.proto new file mode 100644 index 0000000000000000000000000000000000000000..f5eb7260fca05ee0a08f4f91dd8e9e47bbd1e3c4 --- /dev/null +++ b/oneflow/core/eager/eager_instruction.proto @@ -0,0 +1,10 @@ +syntax = "proto2"; +package oneflow.eager; + +import "oneflow/core/vm/instruction.proto"; +import "oneflow/core/eager/eager_symbol.proto"; + +message EagerInstruction { + optional vm.InstructionListProto instruction_list = 1; + optional EagerSymbolList eager_symbol_list = 2; +}; diff --git a/oneflow/core/eager/eager_util.cpp b/oneflow/core/eager/eager_oneflow.cpp similarity index 52% rename from oneflow/core/eager/eager_util.cpp rename to oneflow/core/eager/eager_oneflow.cpp index c4f80fbc18bfe931df181d1f0f41540e234c9c50..c8eef9d4f19f66b9da5df5607ae16439461b0f45 100644 --- a/oneflow/core/eager/eager_util.cpp +++ b/oneflow/core/eager/eager_oneflow.cpp @@ -13,13 +13,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/eager/eager_util.h" +#include "oneflow/core/eager/eager_oneflow.h" #include "oneflow/core/eager/eager_symbol.pb.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/instruction.pb.h" #include "oneflow/core/eager/eager_symbol_storage.h" #include "oneflow/core/job/job_desc.h" #include "oneflow/core/job/scope.h" +#include "oneflow/core/job/machine_context.h" +#include "oneflow/core/job/cluster_instruction.h" #include "oneflow/core/job/placement.pb.h" #include "oneflow/core/operator/op_conf.pb.h" #include "oneflow/core/operator/op_attribute.pb.h" @@ -52,41 +54,57 @@ void StorageAdd(const EagerSymbol& symbol) { } } -Maybe RunLogicalInstruction(const vm::InstructionListProto& instruction_list_proto, - const EagerSymbolList& eager_symbol_list) { - for (const auto& eager_symbol : eager_symbol_list.eager_symbol()) { StorageAdd(eager_symbol); } - return vm::Run(instruction_list_proto); -} +} // namespace -Maybe RunPhysicalInstruction(const vm::InstructionListProto& instruction_list_proto, - const EagerSymbolList& eager_symbol_list) { +Maybe EagerOneflow::RunPhysicalInstruction( + const std::shared_ptr& cluster_instruction) { + const vm::InstructionListProto& instruction_list_proto = + cluster_instruction->eager_instruction().instruction_list(); + const EagerSymbolList& eager_symbol_list = + cluster_instruction->eager_instruction().eager_symbol_list(); for (const auto& eager_symbol : eager_symbol_list.eager_symbol()) { StorageAdd(eager_symbol); } return vm::Run(instruction_list_proto); } -} // namespace - -Maybe RunPhysicalInstruction(const std::string& instruction_list_proto_str, - const std::string& eager_symbol_list_str) { - vm::InstructionListProto instruction_list_proto; - CHECK_OR_RETURN(TxtString2PbMessage(instruction_list_proto_str, &instruction_list_proto)) +Maybe EagerOneflow::RunPhysicalInstruction(const std::string& instruction_list_proto_str, + const std::string& eager_symbol_list_str) { + auto cluster_instruction = std::make_shared(); + vm::InstructionListProto* instruction_list_proto = + cluster_instruction->mutable_eager_instruction()->mutable_instruction_list(); + CHECK_OR_RETURN(TxtString2PbMessage(instruction_list_proto_str, instruction_list_proto)) << "InstructionListProto parse failed"; - EagerSymbolList eager_symbol_list; - CHECK_OR_RETURN(TxtString2PbMessage(eager_symbol_list_str, &eager_symbol_list)) + EagerSymbolList* eager_symbol_list = + cluster_instruction->mutable_eager_instruction()->mutable_eager_symbol_list(); + CHECK_OR_RETURN(TxtString2PbMessage(eager_symbol_list_str, eager_symbol_list)) << "EagerSymbolList parse failed"; - return RunPhysicalInstruction(instruction_list_proto, eager_symbol_list); + return RunPhysicalInstruction( + std::const_pointer_cast(cluster_instruction)); +} + +Maybe EagerOneflow::RunLogicalInstruction( + const std::shared_ptr& cluster_instruction) { + CHECK(cluster_instruction->has_eager_instruction()); + CHECK(Global::Get()->IsThisMachineMaster()); + ClusterInstruction::MasterSendEagerInstruction(*cluster_instruction); + return RunPhysicalInstruction(cluster_instruction); } -Maybe RunLogicalInstruction(const std::string& instruction_list_proto_str, - const std::string& eager_symbol_list_str) { - vm::InstructionListProto instruction_list_proto; - CHECK_OR_RETURN(TxtString2PbMessage(instruction_list_proto_str, &instruction_list_proto)) +Maybe EagerOneflow::RunLogicalInstruction(const std::string& instruction_list_proto_str, + const std::string& eager_symbol_list_str) { + auto cluster_instruction = std::make_shared(); + vm::InstructionListProto* instruction_list_proto = + cluster_instruction->mutable_eager_instruction()->mutable_instruction_list(); + CHECK_OR_RETURN(TxtString2PbMessage(instruction_list_proto_str, instruction_list_proto)) << "InstructionListProto parse failed"; - EagerSymbolList eager_symbol_list; - CHECK_OR_RETURN(TxtString2PbMessage(eager_symbol_list_str, &eager_symbol_list)) + EagerSymbolList* eager_symbol_list = + cluster_instruction->mutable_eager_instruction()->mutable_eager_symbol_list(); + CHECK_OR_RETURN(TxtString2PbMessage(eager_symbol_list_str, eager_symbol_list)) << "EagerSymbolList parse failed"; - return RunLogicalInstruction(instruction_list_proto, eager_symbol_list); + return RunLogicalInstruction( + std::const_pointer_cast(cluster_instruction)); } +COMMAND(Global::SetAllocated(new EagerOneflow())); + } // namespace eager } // namespace oneflow diff --git a/oneflow/core/eager/eager_oneflow.h b/oneflow/core/eager/eager_oneflow.h new file mode 100644 index 0000000000000000000000000000000000000000..1ec038912fc1dcc2187d00a825a2b4ba3c307fa9 --- /dev/null +++ b/oneflow/core/eager/eager_oneflow.h @@ -0,0 +1,42 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_EAGER_EAGER_ONEFLOW_H_ +#define ONEFLOW_CORE_EAGER_EAGER_ONEFLOW_H_ + +#include "oneflow/core/common/maybe.h" +#include "oneflow/core/job/cluster_instruction.pb.h" + +namespace oneflow { +namespace eager { + +class EagerOneflow final { + public: + Maybe RunLogicalInstruction( + const std::shared_ptr& cluster_instruction); + + Maybe RunLogicalInstruction(const std::string& instruction_list_proto_str, + const std::string& eager_symbol_list_str); + + Maybe RunPhysicalInstruction(const std::string& instruction_list_proto_str, + const std::string& eager_symbol_list_str); + Maybe RunPhysicalInstruction( + const std::shared_ptr& cluster_instruction); +}; + +} // namespace eager +} // namespace oneflow + +#endif // ONEFLOW_CORE_EAGER_EAGER_ONEFLOW_H_ diff --git a/oneflow/core/eager/eager_util.h b/oneflow/core/eager/eager_util.h deleted file mode 100644 index 5478411b80fa61a070fd496f547d6660f426dfd7..0000000000000000000000000000000000000000 --- a/oneflow/core/eager/eager_util.h +++ /dev/null @@ -1,32 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_EAGER_EAGER_UTIL_H_ -#define ONEFLOW_CORE_EAGER_EAGER_UTIL_H_ - -#include "oneflow/core/common/maybe.h" - -namespace oneflow { -namespace eager { - -Maybe RunPhysicalInstruction(const std::string& instruction_list_proto_str, - const std::string& eager_symbol_list_str); -Maybe RunLogicalInstruction(const std::string& instruction_list_proto_str, - const std::string& eager_symbol_list_str); - -} // namespace eager -} // namespace oneflow - -#endif // ONEFLOW_CORE_EAGER_EAGER_UTIL_H_ diff --git a/oneflow/core/job/cluster.cpp b/oneflow/core/job/cluster.cpp index e611d00fa3ad8b9724c37ae38a88c72f27b4f991..b3aa866ef0118dc6d0d9a6c29d00192c945a50a5 100644 --- a/oneflow/core/job/cluster.cpp +++ b/oneflow/core/job/cluster.cpp @@ -16,33 +16,68 @@ limitations under the License. #include "oneflow/core/job/cluster.h" #include "oneflow/core/job/cluster_instruction.pb.h" #include "oneflow/core/job/cluster_instruction.h" +#include "oneflow/core/eager/eager_oneflow.h" #include "oneflow/core/control/ctrl_client.h" #include "oneflow/core/job/oneflow.h" #include "oneflow/core/job/machine_context.h" #include "oneflow/core/job/session_global_objects_scope.h" #include "oneflow/core/job/env_global_objects_scope.h" #include "oneflow/core/job/job_set.pb.h" +#include "oneflow/core/thread/thread_pool.h" namespace oneflow { -Maybe Cluster::WorkerLoop() { - CHECK_OR_RETURN(!Global::Get()->IsThisMachineMaster()); - ClusterInstructionProto cluster_instruction; - while (ClusterInstruction::WorkerReceiveHalt(&cluster_instruction) == false) { +namespace { + +void AsyncRunLazyJobSet(ThreadPool* lazy_runtime_thread) { + lazy_runtime_thread->AddWork([] { ConfigProto config_proto; Global::Get()->PullKV("config_proto", &config_proto); int32_t machine_num = config_proto.resource().machine_num(); - if (Global::Get()->this_machine_id() >= machine_num) { continue; } + // do nothing if it's not my business + if (Global::Get()->this_machine_id() >= machine_num) { return; } Global::New(); - JUST(Global::Get()->Init(config_proto)); - + CHECK_JUST(Global::Get()->Init(config_proto)); JobSet job_set; Global::Get()->PullKV("session_job_set", &job_set); { Oneflow oneflow; - JUST(oneflow.Init(job_set)); + CHECK_JUST(oneflow.Init(job_set)); } Global::Delete(); + }); +} + +} // namespace + +Maybe Cluster::WorkerLoop() { + // The reason why excluding master machine is that + // eager instruction for compile-time symbol constructing must be done synchronously + CHECK_OR_RETURN(!Global::Get()->IsThisMachineMaster()); + { + // Oneflow::~Oneflow blocking in current thread is not acceptable + // Two reasons why `lazy_runtime_thread` is needed: + // 1. making current thread non-block by + // taking over the execution of Oneflow::~Oneflow + // 2. as a Synchronizing guard for all unfinished Oneflow::~Oneflow + // + // thread_num must be 1. + ThreadPool lazy_runtime_thread(1); + while (true) { + auto mut_cluster_instruction = std::make_shared(); + ClusterInstruction::WorkerReceiveInstruction(mut_cluster_instruction.get()); + if (mut_cluster_instruction->has_cluster_ctrl_halt()) { + break; + } else if (mut_cluster_instruction->has_cluster_ctrl_session_start()) { + ClusterInstruction::NewSessionBarrier(); + AsyncRunLazyJobSet(&lazy_runtime_thread); + } else if (mut_cluster_instruction->has_eager_instruction()) { + Global::Get()->RunPhysicalInstruction( + std::const_pointer_cast(mut_cluster_instruction)); + } else { + OF_UNIMPLEMENTED(); + } + } } ClusterInstruction::HaltBarrier(); Global::Delete(); diff --git a/oneflow/core/job/cluster_instruction.cpp b/oneflow/core/job/cluster_instruction.cpp index 60f75b8f04435bcd700650bd9e76c0061fc8e514..dd037c07ffb17d250186551632098ed2c1ec4337 100644 --- a/oneflow/core/job/cluster_instruction.cpp +++ b/oneflow/core/job/cluster_instruction.cpp @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "oneflow/core/job/cluster_instruction.h" #include "oneflow/core/job/cluster_instruction.pb.h" #include "oneflow/core/control/ctrl_server.h" @@ -24,46 +25,103 @@ namespace oneflow { namespace { -void BarrierClear() { - OF_BARRIER_ALL(); - Global::Get()->Clear(); - OF_BARRIER_ALL(); -} - std::string GetHaltAckCtrlKey(int64_t machine_id) { return "HaltAckCtrlKey/" + std::to_string(machine_id); } // return unique sequential key // because ctrl key is not allowed to push/pull twice -std::string GetHaltOrSessionStartCtrlKey() { +std::string GetClusterInstructionKey() { static int64_t seq = 0; - return "HaltOrSessionStart/" + std::to_string(seq++); + return "ClusterInstructionKey/" + std::to_string(seq++); +} + +class ObsoleteCtrlKeys { + public: + ObsoleteCtrlKeys() = default; + ~ObsoleteCtrlKeys() = default; + + template + void ForEach(const CallbackT& Callback) const { + std::unique_lock lck(mutex_); + for (const std::string& k : keys_) { Callback(k); } + } + + void Clear() { + std::unique_lock lck(mutex_); + keys_.clear(); + } + void Add(const std::string& key) { + std::unique_lock lck(mutex_); + keys_.push_back(key); + } + + private: + mutable std::mutex mutex_; + std::vector keys_; +}; + +COMMAND(Global::SetAllocated(new ObsoleteCtrlKeys())); + +void OccasionallyClearCtrlKV(const std::string& key) { + static std::atomic seq(0LL); + const static int64_t interval = 65536; + Global::Get()->Add(key); + // 1 instead of 0 is better for avoid clearing no ctrl kv + if ((seq++) % interval == 1) { + OF_BARRIER_ALL(); + if (Global::Get()->IsThisMachineMaster()) { + Global::Get()->ForEach( + [](const std::string& k) { Global::Get()->ClearMasterKV(k); }); + } + Global::Get()->Clear(); + OF_BARRIER_ALL(); + } +} + +void PushClusterInstruction(const ClusterInstructionProto& cluster_instruction) { + const std::string& key = GetClusterInstructionKey(); + Global::Get()->PushMasterKV(key, cluster_instruction); + OccasionallyClearCtrlKV(key); +} + +void PullClusterInstruction(ClusterInstructionProto* cluster_instruction) { + const std::string& key = GetClusterInstructionKey(); + Global::Get()->PullMasterKV(key, cluster_instruction); + OccasionallyClearCtrlKV(key); } } // namespace +void ClusterInstruction::NewSessionBarrier() { + OF_BARRIER_ALL(); + Global::Get()->Clear(); + Global::Get()->Clear(); + OF_BARRIER_ALL(); +} + void ClusterInstruction::MasterSendSessionStart() { - BarrierClear(); ClusterInstructionProto cluster_instruction; cluster_instruction.mutable_cluster_ctrl_session_start(); - Global::Get()->PushKV(GetHaltOrSessionStartCtrlKey(), cluster_instruction); + PushClusterInstruction(cluster_instruction); + NewSessionBarrier(); } void ClusterInstruction::MasterSendHalt() { - BarrierClear(); ClusterInstructionProto cluster_instruction; cluster_instruction.mutable_cluster_ctrl_halt(); - Global::Get()->PushKV(GetHaltOrSessionStartCtrlKey(), cluster_instruction); + PushClusterInstruction(cluster_instruction); HaltBarrier(); } -bool ClusterInstruction::WorkerReceiveHalt(ClusterInstructionProto* cluster_instruction) { - BarrierClear(); - Global::Get()->PullKV(GetHaltOrSessionStartCtrlKey(), cluster_instruction); - if (cluster_instruction->has_cluster_ctrl_halt()) { return true; } - CHECK(cluster_instruction->has_cluster_ctrl_session_start()); - return false; +void ClusterInstruction::MasterSendEagerInstruction( + const ClusterInstructionProto& cluster_instruction) { + CHECK(cluster_instruction.has_eager_instruction()); + PushClusterInstruction(cluster_instruction); +} + +void ClusterInstruction::WorkerReceiveInstruction(ClusterInstructionProto* cluster_instruction) { + PullClusterInstruction(cluster_instruction); } void ClusterInstruction::HaltBarrier() { OF_BARRIER_ALL(); } diff --git a/oneflow/core/job/cluster_instruction.h b/oneflow/core/job/cluster_instruction.h index c14d99852f5c355d1261a565d326fb4cce0b7272..cadc399819466aaa9a712252d6328cfa97dcdf9d 100644 --- a/oneflow/core/job/cluster_instruction.h +++ b/oneflow/core/job/cluster_instruction.h @@ -22,8 +22,10 @@ namespace oneflow { struct ClusterInstruction final { static void MasterSendSessionStart(); - static bool WorkerReceiveHalt(ClusterInstructionProto* cluster_instruction); static void MasterSendHalt(); + static void MasterSendEagerInstruction(const ClusterInstructionProto& cluster_instruction); + static void WorkerReceiveInstruction(ClusterInstructionProto* cluster_instruction); + static void NewSessionBarrier(); static void HaltBarrier(); }; diff --git a/oneflow/core/job/cluster_instruction.proto b/oneflow/core/job/cluster_instruction.proto index 9b564acb790d760bd8d37ffd1023f0d7c9f86fdd..c313225e4f32d846ce909d45aa796ff2c1cfc13c 100644 --- a/oneflow/core/job/cluster_instruction.proto +++ b/oneflow/core/job/cluster_instruction.proto @@ -1,6 +1,8 @@ syntax = "proto2"; package oneflow; +import "oneflow/core/eager/eager_instruction.proto"; + message ClusterCtrlSessionStart {} message ClusterCtrlHalt {} @@ -8,5 +10,6 @@ message ClusterInstructionProto { oneof instruction_type { ClusterCtrlSessionStart cluster_ctrl_session_start = 1; ClusterCtrlHalt cluster_ctrl_halt = 2; + eager.EagerInstruction eager_instruction = 3; } } diff --git a/oneflow/python/oneflow_internal_helper.h b/oneflow/python/oneflow_internal_helper.h index ed7e43ca1a0013a860728f5c75915afcbcec781a..e0d461de192d02fb4dbbd6a36d0fa01305a64ae1 100644 --- a/oneflow/python/oneflow_internal_helper.h +++ b/oneflow/python/oneflow_internal_helper.h @@ -44,7 +44,7 @@ limitations under the License. #include "oneflow/core/vm/instruction.pb.h" #include "oneflow/core/vm/vm_util.h" #include "oneflow/core/vm/id_util.h" -#include "oneflow/core/eager/eager_util.h" +#include "oneflow/core/eager/eager_oneflow.h" #include "oneflow/core/eager/eager_symbol_storage.h" #ifdef WITH_TENSORRT @@ -271,12 +271,14 @@ Maybe GetOpParallelSymbolId(const std::string& op_conf_str) { Maybe RunLogicalInstruction(const std::string& instruction_list_str, const std::string& eager_symbol_list_str) { - return eager::RunLogicalInstruction(instruction_list_str, eager_symbol_list_str); + return Global::Get()->RunLogicalInstruction(instruction_list_str, + eager_symbol_list_str); } Maybe RunPhysicalInstruction(const std::string& instruction_list_str, const std::string& eager_symbol_list_str) { - return eager::RunPhysicalInstruction(instruction_list_str, eager_symbol_list_str); + return Global::Get()->RunPhysicalInstruction(instruction_list_str, + eager_symbol_list_str); } Maybe CurrentMachineId() { diff --git a/oneflow/python/test/ops/test_ccrelu.py b/oneflow/python/test/ops/test_ccrelu.py index 00620beba74cde61d17cc5c3484d2912dee51466..0d6e164680796772dd88f0d631c7d6e74db3a544 100644 --- a/oneflow/python/test/ops/test_ccrelu.py +++ b/oneflow/python/test/ops/test_ccrelu.py @@ -89,7 +89,28 @@ def test_1n2c_mirror_dynamic_ccrelu(test_case): @flow.unittest.num_nodes_required(2) -def test_ccrelu_2n1c(test_case): +def test_ccrelu_2n1c_0(test_case): + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.consistent_view()) + fixed_tensor_def_test(test_case, func_config) + + +@flow.unittest.num_nodes_required(2) +def test_ccrelu_2n1c_1(test_case): + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.consistent_view()) + fixed_tensor_def_test(test_case, func_config) + + +@flow.unittest.num_nodes_required(2) +def test_ccrelu_2n1c_2(test_case): + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.consistent_view()) + fixed_tensor_def_test(test_case, func_config) + + +@flow.unittest.num_nodes_required(2) +def test_ccrelu_2n1c_3(test_case): func_config = flow.FunctionConfig() func_config.default_logical_view(flow.scope.consistent_view()) fixed_tensor_def_test(test_case, func_config)