提交 c8b21ac0 编写于 作者: L Li Xinqi 提交者: GitHub

Bugfix split config proto and session job set (#3637)

* rename OF_BARRIAER

* add eager_2node_test.py

* InitLazyGlobalSession if eager execution not enabled

* remove Global<LbiDiffWatcherInfo>

* add TODO() comments for OF_SESSION_BARRIER under directory core/comm_network/

* remove eager_2node_test.py

Former-commit-id: 4d3e2d33
上级 d125c86b
......@@ -78,7 +78,8 @@ EpollCommNet::~EpollCommNet() {
LOG(INFO) << "CommNet Thread " << i << " finish";
pollers_[i]->Stop();
}
OF_BARRIER();
// TODO(chengcheng): change to OF_ENV_BARRIER
OF_SESSION_BARRIER();
for (IOEventPoller* poller : pollers_) { delete poller; }
for (auto& pair : sockfd2helper_) { delete pair.second; }
}
......
......@@ -66,7 +66,8 @@ void IBVerbsCommNet::RegisterMemoryDone() {
.second);
}
}
OF_BARRIER();
// TODO(chengcheng): change to OF_ENV_BARRIER
OF_SESSION_BARRIER();
Global<CtrlClient>::Get()->ClearKV(GenTokensMsgKey(this_machine_id));
}
......@@ -111,14 +112,17 @@ IBVerbsCommNet::IBVerbsCommNet(const Plan& plan)
Global<CtrlClient>::Get()->PullKV(GenConnInfoKey(peer_id, this_machine_id), &conn_info);
qp_vec_.at(peer_id)->Connect(conn_info);
}
OF_BARRIER();
// TODO(chengcheng): change to OF_ENV_BARRIER
OF_SESSION_BARRIER();
for (int64_t peer_id : peer_machine_id()) {
qp_vec_.at(peer_id)->PostAllRecvRequest();
Global<CtrlClient>::Get()->ClearKV(GenConnInfoKey(this_machine_id, peer_id));
}
OF_BARRIER();
// TODO(chengcheng): change to OF_ENV_BARRIER
OF_SESSION_BARRIER();
poll_thread_ = std::thread(&IBVerbsCommNet::PollCQ, this);
OF_BARRIER();
// TODO(chengcheng): change to OF_ENV_BARRIER
OF_SESSION_BARRIER();
}
void IBVerbsCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token,
......
......@@ -86,8 +86,8 @@ class CtrlClient final {
#define FILE_LINE_STR __FILE__ ":" OF_PP_STRINGIZE(__LINE__)
#define OF_BARRIER_ALL() Global<CtrlClient>::Get()->Barrier(FILE_LINE_STR)
#define OF_BARRIER() \
#define OF_ENV_BARRIER() Global<CtrlClient>::Get()->Barrier(FILE_LINE_STR)
#define OF_SESSION_BARRIER() \
Global<CtrlClient>::Get()->Barrier(FILE_LINE_STR, \
Global<ResourceDesc, ForSession>::Get()->TotalMachineNum())
......
......@@ -90,7 +90,7 @@ TEST(CtrlServer, new_delete) {
Global<ResourceDesc, ForSession>::New(GetResource());
// do test
// OF_BARRIER_ALL();
// OF_ENV_BARRIER();
Global<ResourceDesc, ForSession>::Delete();
Global<ResourceDesc, ForEnv>::Delete();
......
......@@ -69,13 +69,13 @@ void OccasionallyClearCtrlKV(const std::string& key) {
Global<ObsoleteCtrlKeys>::Get()->Add(key);
// 1 instead of 0 is better for avoid clearing no ctrl kv
if ((seq++) % interval == 1) {
OF_BARRIER_ALL();
OF_ENV_BARRIER();
if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
Global<ObsoleteCtrlKeys>::Get()->ForEach(
[](const std::string& k) { Global<CtrlClient>::Get()->ClearMasterKV(k); });
}
Global<ObsoleteCtrlKeys>::Get()->Clear();
OF_BARRIER_ALL();
OF_ENV_BARRIER();
}
}
......@@ -94,10 +94,10 @@ void PullClusterInstruction(ClusterInstructionProto* cluster_instruction) {
} // namespace
void ClusterInstruction::NewSessionBarrier() {
OF_BARRIER_ALL();
OF_ENV_BARRIER();
Global<CtrlClient>::Get()->Clear();
Global<ObsoleteCtrlKeys>::Get()->Clear();
OF_BARRIER_ALL();
OF_ENV_BARRIER();
}
void ClusterInstruction::MasterSendSessionStart() {
......@@ -124,6 +124,6 @@ void ClusterInstruction::WorkerReceiveInstruction(ClusterInstructionProto* clust
PullClusterInstruction(cluster_instruction);
}
void ClusterInstruction::HaltBarrier() { OF_BARRIER_ALL(); }
void ClusterInstruction::HaltBarrier() { OF_ENV_BARRIER(); }
} // namespace oneflow
......@@ -11,6 +11,7 @@ import "oneflow/core/register/blob_desc.proto";
import "oneflow/core/operator/op_conf.proto";
import "oneflow/core/common/shape.proto";
import "oneflow/core/job/sbp_parallel.proto";
import "oneflow/core/job/lbi_diff_watcher_info.proto";
message JobParallelViewConf {
map<string, SbpSignature> op_name2sbp_signature_conf = 1;
......@@ -30,6 +31,7 @@ message JobHelperConf {
map<string, int64> lbn2logical_object_id = 5;
map<string, OptInt64> lbn2batch_axis = 6;
optional OpBlobArgPairs identical_sbp_oba_pairs = 7;
optional LbiDiffWatcherInfo lbi_diff_watcher_info = 8;
}
message Job {
......
......@@ -437,6 +437,25 @@ Maybe<void> EagerJobBuildAndInferCtx::CheckAllInputsWithSameParallelNum(
return Maybe<void>::Ok();
}
Maybe<void> JobBuildAndInferCtx::AddLbiAndDiffWatcherUuidPair(
const LbiAndDiffWatcherUuidPair& lbi_uuid_pair) {
const auto& job_name = job_->job_conf().job_name();
auto* job_helper = job_->mutable_helper();
auto* job_name2pairs =
job_helper->mutable_lbi_diff_watcher_info()->mutable_job_name2lbi_and_watcher_uuids();
LbiAndDiffWatcherUuidPairList* pairs = &(*job_name2pairs)[job_name];
auto PairFoundCond = [&](const LbiAndDiffWatcherUuidPair& x) {
return x.lbi() == lbi_uuid_pair.lbi() && x.watcher_uuid() == lbi_uuid_pair.watcher_uuid();
};
auto found_iter = std::find_if(pairs->lbi_and_uuid_pair().begin(),
pairs->lbi_and_uuid_pair().end(), PairFoundCond);
CHECK_OR_RETURN(found_iter == pairs->lbi_and_uuid_pair().end())
<< "diff blob has been watched. (logical_blob_name: "
<< GenLogicalBlobName(lbi_uuid_pair.lbi()) << ", job_name: " << job_name << ")";
*pairs->mutable_lbi_and_uuid_pair()->Add() = lbi_uuid_pair;
return Maybe<void>::Ok();
}
Maybe<OpAttribute> JobBuildAndInferCtx::AddAndInferMirroredOp(const OperatorConf& op_conf) {
CHECK_OR_RETURN(op_conf.has_scope_symbol_id());
const auto& scope = Global<vm::SymbolStorage<Scope>>::Get()->Get(op_conf.scope_symbol_id());
......
......@@ -34,6 +34,7 @@ class JobBuildAndInferCtx {
virtual ~JobBuildAndInferCtx() = default;
Maybe<void> SetJobConf(const JobConfigProto& job_conf);
Maybe<void> AddLbiAndDiffWatcherUuidPair(const LbiAndDiffWatcherUuidPair& lbi_uuid_pair);
Maybe<OpAttribute> AddAndInferConsistentOp(const OperatorConf& op_conf);
Maybe<OpAttribute> AddAndInferMirroredOp(const OperatorConf& op_conf);
Maybe<void> AddLossLogicalBlobName(const std::string& lbn);
......
......@@ -59,24 +59,6 @@ Maybe<std::string> JobBuildAndInferCtxMgr::GetCurrentJobName() const {
return cur_job_name_;
}
Maybe<void> JobBuildAndInferCtxMgr::AddLbiAndDiffWatcherUuidPair(
const LbiAndDiffWatcherUuidPair& lbi_uuid_pair) const {
auto* job_name2pairs =
Global<LbiDiffWatcherInfo>::Get()->mutable_job_name2lbi_and_watcher_uuids();
const auto& job_name = JUST(GetCurrentJobName());
LbiAndDiffWatcherUuidPairList* pairs = &(*job_name2pairs)[*job_name];
auto PairFoundCond = [&](const LbiAndDiffWatcherUuidPair& x) {
return x.lbi() == lbi_uuid_pair.lbi() && x.watcher_uuid() == lbi_uuid_pair.watcher_uuid();
};
auto found_iter = std::find_if(pairs->lbi_and_uuid_pair().begin(),
pairs->lbi_and_uuid_pair().end(), PairFoundCond);
CHECK_OR_RETURN(found_iter == pairs->lbi_and_uuid_pair().end())
<< "diff blob has been watched. (logical_blob_name: "
<< GenLogicalBlobName(lbi_uuid_pair.lbi()) << ", job_name: " << *job_name << ")";
*pairs->mutable_lbi_and_uuid_pair()->Add() = lbi_uuid_pair;
return Maybe<void>::Ok();
}
Maybe<void> JobBuildAndInferCtxMgr::CloseCurrentJobBuildAndInferCtx() {
VirtualCloseJob();
if (!has_cur_job_) { return Maybe<void>::Ok(); }
......
......@@ -34,7 +34,6 @@ class JobBuildAndInferCtxMgr {
Maybe<JobBuildAndInferCtx*> FindJobBuildAndInferCtx(const std::string& job_name);
Maybe<std::string> GetCurrentJobName() const;
Maybe<void> CloseCurrentJobBuildAndInferCtx();
Maybe<void> AddLbiAndDiffWatcherUuidPair(const LbiAndDiffWatcherUuidPair& lbi_uuid_pair) const;
const JobSet& job_set() const { return job_set_; }
std::string structure_graph() const;
......
......@@ -315,7 +315,7 @@ Maybe<void> CompileCurJobOnMaster(Job* job, Plan* improved_plan, bool need_job_c
} else {
PullPlan("complete_plan", &complete_plan);
}
OF_BARRIER();
OF_SESSION_BARRIER();
// Experiment Runtime
{ Runtime experiment_run(complete_plan, job_desc.piece_num_of_experiment_phase(), true); }
// Improve
......@@ -325,7 +325,7 @@ Maybe<void> CompileCurJobOnMaster(Job* job, Plan* improved_plan, bool need_job_c
*improved_plan = *JUST(Improver().Improve(
*Global<AvailableMemDesc>::Get(), naive_plan,
JoinPath(FLAGS_log_dir, ActEventLogger::experiment_act_event_bin_filename())));
OF_BARRIER();
OF_SESSION_BARRIER();
TeePersistentLogStream::Create("improved_plan")->Write(*improved_plan);
}
} else {
......@@ -972,7 +972,7 @@ Maybe<void> CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan)
TeePersistentLogStream::Create("merged_plan")->Write(*plan);
}
}
OF_BARRIER();
OF_SESSION_BARRIER();
return Maybe<void>::Ok();
}
......
......@@ -80,17 +80,17 @@ Runtime::Runtime(const Plan& plan, size_t total_piece_num, bool is_experiment_ph
HandoutTasks(other_tasks);
runtime_ctx->WaitUntilCntEqualZero("constructing_actor_cnt");
LOG(INFO) << "Actors on this machine constructed";
OF_BARRIER();
OF_SESSION_BARRIER();
LOG(INFO) << "Actors on every machine constructed";
if (Global<CommNet>::Get()) { Global<CommNet>::Get()->RegisterMemoryDone(); }
OF_BARRIER();
OF_SESSION_BARRIER();
runtime_ctx->NewCounter("running_actor_cnt", this_machine_task_num);
SendCmdMsg(source_tasks, ActorCmd::kStart);
}
Runtime::~Runtime() {
Global<RuntimeCtx>::Get()->WaitUntilCntEqualZero("running_actor_cnt");
OF_BARRIER();
OF_SESSION_BARRIER();
DeleteAllGlobal();
}
......
......@@ -30,7 +30,6 @@ limitations under the License.
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/critical_section_desc.h"
#include "oneflow/core/job/job_build_and_infer_ctx_mgr.h"
#include "oneflow/core/job/lbi_diff_watcher_info.pb.h"
#include "oneflow/core/job/job_set_compile_ctx.h"
#include "oneflow/core/job/runtime_buffer_managers_scope.h"
#include "oneflow/core/framework/load_library.h"
......@@ -89,7 +88,6 @@ Maybe<void> SessionGlobalObjectsScope::Init(const ConfigProto& config_proto) {
Global<CriticalSectionDesc>::New();
Global<InterUserJobInfo>::New();
Global<LazyJobBuildAndInferCtxMgr>::New();
Global<LbiDiffWatcherInfo>::New();
Global<JobSetCompileCtx>::New();
Global<RuntimeBufferManagersScope>::New();
}
......@@ -101,7 +99,6 @@ SessionGlobalObjectsScope::~SessionGlobalObjectsScope() {
if (Global<MachineCtx>::Get()->IsThisMachineMaster()) {
Global<RuntimeBufferManagersScope>::Delete();
Global<JobSetCompileCtx>::Delete();
Global<LbiDiffWatcherInfo>::Delete();
Global<LazyJobBuildAndInferCtxMgr>::Delete();
Global<InterUserJobInfo>::Delete();
Global<CriticalSectionDesc>::Delete();
......
......@@ -29,7 +29,7 @@ class AddLbiDiffWatcherOpConfs final : public OpGraphPass {
Maybe<void> AddLbiDiffWatcherOpConfs::Apply(Job* job) const {
JobBuilder job_builder(job);
const auto& map = Global<LbiDiffWatcherInfo>::Get()->job_name2lbi_and_watcher_uuids();
const auto& map = job->helper().lbi_diff_watcher_info().job_name2lbi_and_watcher_uuids();
if (map.find(GlobalJobDesc().job_name()) == map.end()) { return Maybe<void>::Ok(); }
const auto& tag2lbi_relations = job->helper().tag2lbi_relations();
const auto& conf_iter = tag2lbi_relations.find(kProducedLbi2ConsumedDiffLbi);
......
......@@ -307,9 +307,9 @@ Maybe<void> TestTransportOn2Machine(const std::string& first_machine_ip,
Global<EpollCommNet>::New();
Global<Transport>::New();
// OF_BARRIER Must call before test,
// OF_ENV_BARRIER Must call before test,
// to ensure that the Global<Transport> on each machine is created
OF_BARRIER_ALL();
OF_ENV_BARRIER();
// Test for correctness
// Each machine will send and receive 100 messages (50 send and 50 recv) alternately.
......@@ -320,7 +320,7 @@ Maybe<void> TestTransportOn2Machine(const std::string& first_machine_ip,
TestThroughput();
OF_BARRIER_ALL();
OF_ENV_BARRIER();
std::cout << "Deleting all global..." << std::endl;
Global<Transport>::Delete();
Global<EpollCommNet>::Delete();
......
......@@ -114,31 +114,31 @@ def IsSessionInited():
return oneflow_internal.IsSessionInited()
def InitGlobalSession(config_proto):
def InitLazyGlobalSession(config_proto):
assert type(config_proto) is job_set_pb.ConfigProto
config_proto_str = text_format.MessageToString(config_proto)
error_str = oneflow_internal.InitGlobalSession(config_proto_str)
error_str = oneflow_internal.InitLazyGlobalSession(config_proto_str)
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"):
raise JobBuildAndInferError(error)
def DestroyGlobalSession():
error_str = oneflow_internal.DestroyGlobalSession()
def DestroyLazyGlobalSession():
error_str = oneflow_internal.DestroyLazyGlobalSession()
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"):
raise JobBuildAndInferError(error)
def StartGlobalSession():
error_str = oneflow_internal.StartGlobalSession()
def StartLazyGlobalSession():
error_str = oneflow_internal.StartLazyGlobalSession()
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"):
raise JobBuildAndInferError(error)
def StopGlobalSession():
error_str = oneflow_internal.StopGlobalSession()
def StopLazyGlobalSession():
error_str = oneflow_internal.StopLazyGlobalSession()
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"):
raise JobBuildAndInferError(error)
......
......@@ -58,9 +58,7 @@ def InterpretScope(session, function_desc, config_proto):
job_conf.job_name = function_desc.job_func.__name__
placement_scope = function_desc.function_attribute.default_placement_scope
if placement_scope is None:
tag_and_dev_ids = placement_util.GetDefaultMachineDeviceIds(
oneflow.env.current_resource()
)
tag_and_dev_ids = placement_util.GetDefaultMachineDeviceIds(session.resource)
placement_scope = placement_util.GetPlacementScope(*tag_and_dev_ids)
distribute_strategy = function_desc.function_attribute.default_distribute_strategy
if distribute_strategy is None:
......
......@@ -60,6 +60,7 @@ class Session(object):
self.inter_user_job_info_ = None
self.uuid2watch_handler_ = {}
self.config_proto_ = None
self.resource_ = None
self.placement_scope_stack_ = []
self.is_mirrored_strategy_enabled_stack_ = []
self.function_flag_name2default_val_ = {}
......@@ -93,6 +94,13 @@ class Session(object):
self.config_proto_ = _GetDefaultConfigProto()
return self.config_proto_
@property
def resource(self):
if self.resource_ is None:
return oneflow.env.current_resource()
else:
return self.resource_
@property
def uuid2watch_handler(self):
return self.uuid2watch_handler_
......@@ -210,14 +218,15 @@ class Session(object):
if not c_api_util.IsEnvInited():
oneflow.env.init()
_TryCompleteConfigProto(self.config_proto)
c_api_util.InitGlobalSession(self.config_proto)
self.resource_ = self.config_proto.resource
if not c_api_util.EagerExecutionEnabled():
c_api_util.InitLazyGlobalSession(self.config_proto)
for job_name, func_desc in self.job_name2function_desc_.items():
compiler.Compile(self, func_desc, self.config_proto)
self.existed_module_names_ = set()
self.job_name2var_name2var_blob_ = dict()
assert len(self.job_name2function_desc_.items()) > 0
c_api_util.StartGlobalSession()
c_api_util.StartLazyGlobalSession()
self.inter_user_job_info_ = c_api_util.GetInterUserJobInfo()
return self
......@@ -232,9 +241,10 @@ class Session(object):
del self.var_name2var_blob_
del self.job_name2module_name2module_
self.ForceReleaseEagerBlobs()
c_api_util.StopGlobalSession()
c_api_util.DestroyGlobalSession()
c_api_util.StopLazyGlobalSession()
c_api_util.DestroyLazyGlobalSession()
self.status_ = SessionStatus.CLOSED
self.resource_ = None
def AddJob(self, function_desc):
assert self.status_ is SessionStatus.OPEN
......
......@@ -101,11 +101,11 @@ Maybe<std::string> CurJobBuildAndInferCtx_AddAndInferConsistentOp(const std::str
Maybe<void> CurJobBuildAndInferCtx_AddLbiAndDiffWatcherUuidPair(
const std::string& lbi_uuid_pair_str) {
auto* mgr = JUST(GlobalJobBuildAndInferCtxMgr());
auto* ctx = JUST(GetCurInferCtx());
LbiAndDiffWatcherUuidPair lbi_uuid_pair;
CHECK_OR_RETURN(TxtString2PbMessage(lbi_uuid_pair_str, &lbi_uuid_pair))
<< "LbiAndDiffWatcherUuidPair parse failed";
return mgr->AddLbiAndDiffWatcherUuidPair(lbi_uuid_pair);
return ctx->AddLbiAndDiffWatcherUuidPair(lbi_uuid_pair);
}
Maybe<std::string> JobBuildAndInferCtx_GetSerializedIdListAsStaticShape(const std::string& job_name,
......
......@@ -67,21 +67,21 @@ bool IsSessionInited() {
return Global<SessionGlobalObjectsScope>::Get() != nullptr;
}
void InitGlobalSession(const std::string& config_proto_str, std::string* error_str) {
void InitLazyGlobalSession(const std::string& config_proto_str, std::string* error_str) {
using namespace oneflow;
return InitGlobalSession(config_proto_str).GetDataAndSerializedErrorProto(error_str);
return InitLazyGlobalSession(config_proto_str).GetDataAndSerializedErrorProto(error_str);
}
void DestroyGlobalSession(std::string* error_str) {
return oneflow::DestroyGlobalSession().GetDataAndSerializedErrorProto(error_str);
void DestroyLazyGlobalSession(std::string* error_str) {
return oneflow::DestroyLazyGlobalSession().GetDataAndSerializedErrorProto(error_str);
}
void StartGlobalSession(std::string* error_str) {
return oneflow::StartGlobalSession().GetDataAndSerializedErrorProto(error_str);
void StartLazyGlobalSession(std::string* error_str) {
return oneflow::StartLazyGlobalSession().GetDataAndSerializedErrorProto(error_str);
}
void StopGlobalSession(std::string* error_str) {
return oneflow::StopGlobalSession().GetDataAndSerializedErrorProto(error_str);
void StopLazyGlobalSession(std::string* error_str) {
return oneflow::StopLazyGlobalSession().GetDataAndSerializedErrorProto(error_str);
}
std::string GetSerializedInterUserJobInfo(std::string* error_str) {
......
......@@ -115,7 +115,7 @@ void FixCpuDeviceNum(ConfigProto* config_proto) {
config_proto->mutable_resource()->set_cpu_device_num(std::thread::hardware_concurrency());
}
Maybe<void> InitGlobalSession(const std::string& config_proto_str) {
Maybe<void> InitLazyGlobalSession(const std::string& config_proto_str) {
CHECK_NOTNULL_OR_RETURN(Global<EnvDesc>::Get()) << "env not found";
CHECK_OR_RETURN(Global<MachineCtx>::Get()->IsThisMachineMaster());
......@@ -134,14 +134,14 @@ Maybe<void> InitGlobalSession(const std::string& config_proto_str) {
return Maybe<void>::Ok();
}
Maybe<void> DestroyGlobalSession() {
Maybe<void> DestroyLazyGlobalSession() {
if (Global<SessionGlobalObjectsScope>::Get() == nullptr) { return Maybe<void>::Ok(); }
CHECK_OR_RETURN(Global<MachineCtx>::Get()->IsThisMachineMaster());
Global<SessionGlobalObjectsScope>::Delete();
return Maybe<void>::Ok();
}
Maybe<void> StartGlobalSession() {
Maybe<void> StartLazyGlobalSession() {
CHECK_NOTNULL_OR_RETURN(Global<SessionGlobalObjectsScope>::Get()) << "session not found";
CHECK_OR_RETURN(Global<MachineCtx>::Get()->IsThisMachineMaster());
const JobSet& job_set = Global<LazyJobBuildAndInferCtxMgr>::Get()->job_set();
......@@ -163,7 +163,7 @@ Maybe<std::string> GetSerializedStructureGraph() {
return job_ctx_mgr->structure_graph();
}
Maybe<void> StopGlobalSession() {
Maybe<void> StopLazyGlobalSession() {
if (Global<Oneflow>::Get() == nullptr) { return Maybe<void>::Ok(); }
CHECK_OR_RETURN(Global<MachineCtx>::Get()->IsThisMachineMaster());
CHECK_NOTNULL_OR_RETURN(Global<Oneflow>::Get());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册