Commit a5a1cea9 authored by willzhang4a58

Singleton Ptr

Parent d39a2f44
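This commit changes the OF_SINGLETON macro so that Singleton() returns a raw pointer to an instance allocated once with new, instead of a reference to a function-local static object; every call site therefore switches from Mgr::Singleton().Foo() to Mgr::Singleton()->Foo(). Below is a minimal sketch of the two forms; only the two macro bodies come from the diff, and the Example class and its Hello() method are hypothetical, added purely to show a call site.

// Old form (before this commit): Singleton() returns a reference to a
// function-local static object, which is destroyed at program exit.
#define OF_SINGLETON(ClassName)   \
  static ClassName& Singleton() { \
    static ClassName obj;         \
    return obj;                   \
  }

#undef OF_SINGLETON  // sketch only: avoid redefining the macro below

// New form (this commit): Singleton() returns a pointer to an instance that
// is allocated once and intentionally never deleted, so call sites use ->.
#define OF_SINGLETON(ClassName)            \
  static ClassName* Singleton() {          \
    static ClassName* ptr = new ClassName; \
    return ptr;                            \
  }

// Hypothetical Example class (not part of this commit) showing a call site.
class Example {
 public:
  OF_SINGLETON(Example);
  void Hello() {}

 private:
  Example() = default;
};

int main() {
  Example::Singleton()->Hello();  // was: Example::Singleton().Hello();
  return 0;
}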
......@@ -16,13 +16,13 @@ void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
exec_kernel_vec_.reserve(task_proto.exec_sequence().exec_node_size());
for (const ExecNodeProto& node : task_proto.exec_sequence().exec_node()) {
ExecKernel ek;
ek.kernel = KernelMgr::Singleton().GetKernelFromOpName(node.op_name());
ek.kernel = KernelMgr::Singleton()->GetKernelFromOpName(node.op_name());
ek.bn_in_op2regst_desc_id = PbMap2HashMap(node.bn_in_op2regst_desc_id());
exec_kernel_vec_.push_back(std::move(ek));
}
// produced_regsts_
for (const auto& pair : task_proto.produced_regst_desc()) {
RegstMgr::Singleton().NewRegsts(pair.second, [this](Regst* regst) {
RegstMgr::Singleton()->NewRegsts(pair.second, [this](Regst* regst) {
produced_regsts_[regst->regst_desc_id()].emplace_back(regst);
});
}
......@@ -88,7 +88,7 @@ void Actor::AsyncSendReadableRegstMsg() {
device_ctx_->AddCallBack([regst]() {
for (int64_t subscriber : regst->subscribers_actor_id()) {
ActorMsg msg = ActorMsg::BuildReadableRegstMsg(subscriber, regst);
ActorMsgBus::Singleton().SendMsg(std::move(msg));
ActorMsgBus::Singleton()->SendMsg(std::move(msg));
}
});
produced_regst2reading_cnt_.at(regst) =
......@@ -106,7 +106,7 @@ void Actor::AsyncSendEORDMsgToSubscribers(int64_t regst_desc_id) {
ActorMsg msg;
msg.set_dst_actor_id(subscriber);
msg.set_actor_cmd(ActorCmd::kEORD);
ActorMsgBus::Singleton().SendMsg(std::move(msg));
ActorMsgBus::Singleton()->SendMsg(std::move(msg));
}
});
}
......@@ -125,7 +125,7 @@ void Actor::AsyncSendRegstMsgToProducer(
const std::shared_ptr<RegstWarpper>& wp) {
ActorMsg msg = ActorMsg::BuildRegstMsgToProducer(wp->producer_actor_id(),
wp->regst_raw_ptr());
AsyncDo([msg]() { ActorMsgBus::Singleton().SendMsg(msg); });
AsyncDo([msg]() { ActorMsgBus::Singleton()->SendMsg(msg); });
}
int Actor::TryUpdtStateAsProducedRegst(Regst* regst) {
......
......@@ -16,8 +16,8 @@ ActorMsg ActorMsg::BuildReadableRegstMsg(int64_t reader_actor_id,
ActorMsg msg;
msg.dst_actor_id_ = reader_actor_id;
msg.msg_type_ = ActorMsgType::kRegstMsg;
if (IDMgr::Singleton().MachineId4ActorId(reader_actor_id)
== RuntimeCtx::Singleton().this_machine_id()) {
if (IDMgr::Singleton()->MachineId4ActorId(reader_actor_id)
== RuntimeCtx::Singleton()->this_machine_id()) {
msg.regst_warpper_.reset(new LocalRegstWarpper(regst_raw_ptr));
} else {
msg.regst_warpper_.reset(new RemoteRegstWarpper(regst_raw_ptr));
......
......@@ -7,11 +7,11 @@ namespace oneflow {
void ActorMsgBus::SendMsg(const ActorMsg& msg) {
int64_t dst_machine_id =
IDMgr::Singleton().MachineId4ActorId(msg.dst_actor_id());
if (dst_machine_id == RuntimeCtx::Singleton().this_machine_id()) {
IDMgr::Singleton()->MachineId4ActorId(msg.dst_actor_id());
if (dst_machine_id == RuntimeCtx::Singleton()->this_machine_id()) {
int64_t thrd_loc_id =
IDMgr::Singleton().ThrdLocId4ActorId(msg.dst_actor_id());
ThreadMgr::Singleton().GetThrd(thrd_loc_id)->GetMsgChannelPtr()->Send(msg);
IDMgr::Singleton()->ThrdLocId4ActorId(msg.dst_actor_id());
ThreadMgr::Singleton()->GetThrd(thrd_loc_id)->GetMsgChannelPtr()->Send(msg);
} else {
TODO();
}
......
......@@ -40,9 +40,9 @@ bool FwDataCompActor::IsReadReady() {
if (model_regst_desc_id_ != -1) {
// Ho Q, Cipar J, Cui H, et al. More effective distributed ml via a stale
// synchronous parallel parameter server
int32_t staleness = JobDesc::Singleton().staleness();
int32_t staleness = JobDesc::Singleton()->staleness();
int32_t num_of_piece_in_batch =
JobDesc::Singleton().num_of_piece_in_batch();
JobDesc::Singleton()->num_of_piece_in_batch();
int64_t cur_iteration = in_.front()->piece_id() / num_of_piece_in_batch;
int64_t stale_version = cur_iteration - staleness;
return model_regst_->model_version_id() >= stale_version;
......@@ -91,7 +91,7 @@ int FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryWardKernelAndSendMsg();
int total_piece_num = JobDesc::Singleton().total_piece_num();
int total_piece_num = JobDesc::Singleton()->total_piece_num();
if ((in_desc_id_ != -1 && in_.empty())
|| expected_piece_id() == total_piece_num) {
if (model_regst_desc_id_ != -1) {
......
......@@ -6,10 +6,10 @@ void MdDiffAccActor::Init(const TaskProto& task_proto,
const ThreadCtx& thread_ctx) {
CompActor::Init(task_proto, thread_ctx);
if (thread_ctx.cpu_stream) {
clear_kernel_ = KernelMgr::Singleton().GetKernelFromOpName("cpu_clear");
clear_kernel_ = KernelMgr::Singleton()->GetKernelFromOpName("cpu_clear");
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
} else {
clear_kernel_ = KernelMgr::Singleton().GetKernelFromOpName("gpu_clear");
clear_kernel_ = KernelMgr::Singleton()->GetKernelFromOpName("gpu_clear");
mut_device_ctx().reset(new CudaDeviceCtx(cuda_handle_.cuda_stream(),
cuda_handle_.cublas_handle(),
cuda_handle_.cudnn_handle()));
......@@ -57,7 +57,7 @@ void MdDiffAccActor::TryWardKernelAndSendMsg() {
KernelCtx ctx = GenDefaultKernelCtx();
ForEachCurWriteableRegst([&](Regst* regst) {
auto diff_cnt = model_diff_acc_cnt_.find(regst);
if (diff_cnt->second != JobDesc::Singleton().num_of_piece_in_batch()) {
if (diff_cnt->second != JobDesc::Singleton()->num_of_piece_in_batch()) {
return;
}
clear_kernel_->Forward(ctx, [&](const std::string& bn_in_op) {
......
......@@ -20,12 +20,12 @@ int MdSaveCompActor::HandleSaveModel(const ActorMsg& actor_msg) {
std::shared_ptr<RegstWarpper> regst_warpper = actor_msg.regst_warpper();
int64_t model_version_id = regst_warpper->model_version_id();
int32_t num_of_batches_in_snapshot =
JobDesc::Singleton().num_of_batches_in_snapshot();
JobDesc::Singleton()->num_of_batches_in_snapshot();
CHECK_GT(num_of_batches_in_snapshot, 0);
if (model_version_id % num_of_batches_in_snapshot == 0) {
int64_t snapshot_id = model_version_id / num_of_batches_in_snapshot;
Snapshot* snapshot =
SnapshotMgr::Singleton().GetWriteableSnapshot(snapshot_id);
SnapshotMgr::Singleton()->GetWriteableSnapshot(snapshot_id);
KernelCtx kernel_ctx = GenDefaultKernelCtx();
std::tuple<Snapshot*, int64_t> save_ctx =
std::make_tuple(snapshot, parallel_id());
......@@ -39,7 +39,7 @@ int MdSaveCompActor::HandleSaveModel(const ActorMsg& actor_msg) {
}
ActorMsg msg = ActorMsg::BuildRegstMsgToProducer(
regst_warpper->producer_actor_id(), regst_warpper->regst_raw_ptr());
AsyncDo([msg]() { ActorMsgBus::Singleton().SendMsg(msg); });
AsyncDo([msg]() { ActorMsgBus::Singleton()->SendMsg(msg); });
} else {
UNEXPECTED_RUN();
}
......
......@@ -28,7 +28,7 @@ int MdUpdtCompActor::HandleBeforeInitializeModel(const ActorMsg& actor_msg) {
HashSet<const Kernel*> kernels;
auto CollectKernelsFromLbn = [&kernels](const std::string& lbn) {
std::string op_name = GetOpNameFromLbn(lbn);
kernels.insert(KernelMgr::Singleton().GetKernelFromOpName(op_name));
kernels.insert(KernelMgr::Singleton()->GetKernelFromOpName(op_name));
};
model_regst->ForEachLbn(CollectKernelsFromLbn);
model_tmp_regst->ForEachLbn(CollectKernelsFromLbn);
......@@ -36,7 +36,7 @@ int MdUpdtCompActor::HandleBeforeInitializeModel(const ActorMsg& actor_msg) {
for (const Kernel* kernel : kernels) {
kernel->InitModelAndModelTmpBlobs(
GenDefaultKernelCtx(), parallel_policy(), parallel_id(), parallel_num(),
SnapshotMgr::Singleton().GetReadableSnapshot(),
SnapshotMgr::Singleton()->GetReadableSnapshot(),
[&](const std::string& bn_in_op) {
const std::string& lbn = kernel->Lbn4BnInOp(bn_in_op);
Blob* ret = model_regst->GetBlobPtrFromLbn(lbn);
......@@ -45,7 +45,7 @@ int MdUpdtCompActor::HandleBeforeInitializeModel(const ActorMsg& actor_msg) {
return ret;
});
}
AsyncDo([]() { RuntimeCtx::Singleton().OneModelInitDone(); });
AsyncDo([]() { RuntimeCtx::Singleton()->OneModelInitDone(); });
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleBeforeSendInitialModel);
return 0;
}
......@@ -55,7 +55,7 @@ int MdUpdtCompActor::HandleBeforeSendInitialModel(const ActorMsg& actor_msg) {
AsyncSendReadableRegstMsg();
SetReadOnlyForRegstDescId(model_tmp_regst_desc_id_);
AsyncSendEORDMsgToSubscribers(model_tmp_regst_desc_id_);
if (JobDesc::Singleton().is_train()) {
if (JobDesc::Singleton()->is_train()) {
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleUpdateModel);
} else {
AsyncSendEORDMsgToSubscribers(model_regst_desc_id_);
......
......@@ -40,10 +40,10 @@ namespace oneflow {
#define TODO() LOG(FATAL) << "TODO";
#define OF_SINGLETON(ClassName) \
static ClassName& Singleton() { \
static ClassName obj; \
return obj; \
#define OF_SINGLETON(ClassName) \
static ClassName* Singleton() { \
static ClassName* ptr = new ClassName; \
return ptr; \
}
template<typename T>
......
......@@ -133,7 +133,7 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
box_conf->set_in_num(sorted_in_edges.size());
box_conf->set_out_num(sorted_out_edges.size());
CompleteBoxOp(box_conf);
return OpMgr::Singleton().ConstructOp(op_conf);
return OpMgr::Singleton()->ConstructOp(op_conf);
};
// lbns
std::vector<std::string> lbns = FindLbnsBetween(in_chain, out_chain);
......
......@@ -16,9 +16,10 @@ std::string CompTaskNode::VisualStr() const {
}
std::string CompTaskNode::device_name() const {
return IDMgr::Singleton().MachineName4MachineId(stage_node()->machine_id())
return IDMgr::Singleton()->MachineName4MachineId(stage_node()->machine_id())
+ ":"
+ std::to_string(IDMgr::Singleton().DevPhyId4ThrdLocId(thrd_loc_id()));
+ std::to_string(
IDMgr::Singleton()->DevPhyId4ThrdLocId(thrd_loc_id()));
}
void SortByParallelId(std::vector<CompTaskNode*>* comp_node_vec) {
......
......@@ -46,14 +46,14 @@ std::shared_ptr<const Operator> CopyHDTaskNode::ConstructOp() const {
op_conf.set_name("copy_hd_" + NewUniqueId());
CopyHdOpConf* copy_hd_conf = op_conf.mutable_copy_hd_conf();
copy_hd_conf->set_type(IsH2D() ? CopyHdOpConf::H2D : CopyHdOpConf::D2H);
return OpMgr::Singleton().ConstructOp(op_conf);
return OpMgr::Singleton()->ConstructOp(op_conf);
}
std::shared_ptr<const Operator> CopyCommNetTaskNode::ConstructOp() const {
OperatorConf op_conf;
op_conf.set_name("comm_net_" + NewUniqueId());
op_conf.mutable_copy_comm_net_conf();
return OpMgr::Singleton().ConstructOp(op_conf);
return OpMgr::Singleton()->ConstructOp(op_conf);
}
} // namespace oneflow
......@@ -24,7 +24,7 @@ void LogicalGraph::NaiveBuildGraphStruct(
const OperatorConf& cur_op_conf = dl_net_conf.op(op_i);
// Construct cur node
LogicalNode* cur_node = NewNode();
cur_node->mut_op() = OpMgr::Singleton().ConstructOp(cur_op_conf);
cur_node->mut_op() = OpMgr::Singleton()->ConstructOp(cur_op_conf);
// Connect input node
for (const std::string& ibn : cur_node->op()->input_bns()) {
const std::string& lbn = cur_node->op()->Lbn4BnInOp(ibn);
......@@ -90,7 +90,7 @@ void LogicalGraph::CollectCloneInfos(
pb_op_conf.set_name("clone_" + lbn);
pb_op_conf.mutable_clone_conf()->set_out_num(edges.size());
pb_op_conf.mutable_clone_conf()->set_lbn(lbn);
auto clone_op = OpMgr::Singleton().ConstructOp(pb_op_conf);
auto clone_op = OpMgr::Singleton()->ConstructOp(pb_op_conf);
// Set clone_info
CloneInfo clone_info;
clone_info.clone_op = clone_op;
......
......@@ -19,7 +19,7 @@ void MdDiffAccTaskGraph::BuildTaskGraph(const ChainNode* data_chain) {
OperatorConf op_conf;
op_conf.set_name("model_diff_acc_" + NewUniqueId());
op_conf.mutable_model_diff_acc_conf();
auto model_diff_acc_op = OpMgr::Singleton().ConstructOp(op_conf);
auto model_diff_acc_op = OpMgr::Singleton()->ConstructOp(op_conf);
// ModelDiffAccChain
auto chain_gph = of_make_unique<ChainGraph>();
ChainNode* diff_acc_chain = chain_gph->NewNode();
......
......@@ -22,7 +22,7 @@ void MdSaveCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
});
ExecNode* exec_node = mut_exec_gph().NewNode();
exec_node->mut_op() = OpMgr::Singleton().ConstructOp(op_conf);
exec_node->mut_op() = OpMgr::Singleton()->ConstructOp(op_conf);
for (const std::string& ibn : exec_node->op()->input_bns()) {
exec_node->BindBnInOpAndRegst(ibn, GetRelatedRegst(SoleInEdge()));
}
......
......@@ -17,7 +17,7 @@ void MdUpdtTaskGraph::BuildTaskGraph() {
OperatorConf op_conf;
op_conf.set_name("model_update_" + NewUniqueId());
op_conf.mutable_model_update_conf();
auto model_updt_op = OpMgr::Singleton().ConstructOp(op_conf);
auto model_updt_op = OpMgr::Singleton()->ConstructOp(op_conf);
ChainNode* updt_chain = chain_gph->NewNode();
ParallelConf updt_pr_conf;
......
......@@ -94,7 +94,7 @@ void TaskGraph::Stage2DeviceCompTaskNodes(
int64_t parallel_idx = stage->parallel_range().begin();
for (auto device_phy_id : stage->SortedDevicePhyIds()) {
int64_t thread_local_id =
IDMgr::Singleton().ThrdLocId4DevPhyId(device_phy_id);
IDMgr::Singleton()->ThrdLocId4DevPhyId(device_phy_id);
// comp_task_node
CompTaskNodeType* comp_task_node = NewTaskNode<CompTaskNodeType>();
comp_task_node->SetFwNode();
......@@ -151,12 +151,12 @@ void TaskGraph::Stage2HostCompTaskNodes(const StageNode* stage,
// Set comp_task_node::thread_local_id
if (stage->SortedDevicePhyIds().empty()) {
comp_task_node->mut_thrd_loc_id() =
IDMgr::Singleton().PersistenceThrdLocId();
IDMgr::Singleton()->PersistenceThrdLocId();
} else {
auto device_id =
stage->SortedDevicePhyIds().at(parallel_idx - parallel_begin);
comp_task_node->mut_thrd_loc_id() =
IDMgr::Singleton().ThrdLocId4DevPhyId(device_id);
IDMgr::Singleton()->ThrdLocId4DevPhyId(device_id);
}
//
task_nodes_in_stage->comp_in_task_nodes.push_back(comp_task_node);
......@@ -186,7 +186,7 @@ void TaskGraph::InitInboxingTaskNode(const StageNode* stage,
InBoxingTaskNode* boxing_task = NewTaskNode<InBoxingTaskNode>();
boxing_task->SetFwNode();
boxing_task->set_stage_node(stage);
boxing_task->mut_thrd_loc_id() = IDMgr::Singleton().BoxingThrdLocId();
boxing_task->mut_thrd_loc_id() = IDMgr::Singleton()->BoxingThrdLocId();
boxing_task->set_task_id();
for (TaskNode* comp_in_task : task_nodes_in_stage->comp_in_task_nodes) {
TaskConnect(boxing_task, NewEdge(), comp_in_task);
......@@ -205,7 +205,7 @@ void TaskGraph::InitOutBoxingTaskNode(const StageNode* stage,
OutBoxingTaskNode* boxing_task = NewTaskNode<OutBoxingTaskNode>();
boxing_task->SetFwNode();
boxing_task->set_stage_node(stage);
boxing_task->mut_thrd_loc_id() = IDMgr::Singleton().BoxingThrdLocId();
boxing_task->mut_thrd_loc_id() = IDMgr::Singleton()->BoxingThrdLocId();
boxing_task->set_task_id();
for (TaskNode* comp_out_task : task_nodes_in_stage->comp_out_task_nodes) {
TaskConnect(comp_out_task, NewEdge(), boxing_task);
......@@ -238,7 +238,7 @@ void TaskGraph::ConnectBoxingTaskNodes(
CopyCommNetTaskNode* comm_net_node = NewTaskNode<CopyCommNetTaskNode>();
comm_net_node->SetFwNode();
comm_net_node->set_stage_node(succ_stage);
comm_net_node->mut_thrd_loc_id() = IDMgr::Singleton().CommNetThrdLocId();
comm_net_node->mut_thrd_loc_id() = IDMgr::Singleton()->CommNetThrdLocId();
comm_net_node->set_task_id();
TaskConnect(out_node, NewEdge(), comm_net_node);
......
......@@ -30,7 +30,7 @@ int64_t& TaskNode::mut_thrd_loc_id() {
void TaskNode::set_task_id() {
int64_t machine_id = stage_node_->machine_id();
task_id_ = IDMgr::Singleton().NewTaskId(machine_id, thrd_loc_id_);
task_id_ = IDMgr::Singleton()->NewTaskId(machine_id, thrd_loc_id_);
}
std::unique_ptr<TaskNode> TaskNode::BuildAndConnectBpNode() {
......@@ -72,7 +72,7 @@ void TaskNode::TakeOverRegstDesc(TaskNode* rhs,
CHECK_EQ(produced_regst2out_edge_.count(rhs_regst_it->second), 0);
this_regst.swap(rhs_regst_it->second);
this_regst->SetProducer(this);
this_regst->set_regst_desc_id(IDMgr::Singleton().NewRegstDescId());
this_regst->set_regst_desc_id(IDMgr::Singleton()->NewRegstDescId());
rhs->produced_regst_descs_.erase(rhs_regst_it);
CHECK(produced_regst_descs_.emplace(regst_desc_name, this_regst).second);
}
......@@ -119,7 +119,7 @@ std::shared_ptr<RegstDesc> TaskNode::NewProducedRegstDesc(
const std::string& regst_desc_name) {
auto regst_desc = std::make_shared<RegstDesc>();
regst_desc->SetProducer(this);
regst_desc->set_regst_desc_id(IDMgr::Singleton().NewRegstDescId());
regst_desc->set_regst_desc_id(IDMgr::Singleton()->NewRegstDescId());
CHECK(produced_regst_descs_.emplace(regst_desc_name, regst_desc).second);
return regst_desc;
}
......
......@@ -17,10 +17,7 @@ class Compiler final {
OF_DISALLOW_COPY_AND_MOVE(Compiler);
~Compiler() = default;
static Compiler& Singleton() {
static Compiler obj;
return obj;
}
OF_SINGLETON(Compiler);
void Compile(const JobConf& job_conf, const std::string& plan_filepath);
......@@ -65,8 +62,8 @@ void Compiler::ForEachTaskNode(std::function<void(TaskNode*)> func) {
// TODO: inference "register_num for each register_desc"
void Compiler::Compile(const JobConf& job_conf,
const std::string& plan_filepath) {
JobDesc::Singleton().InitFromJobConf(job_conf);
IDMgr::Singleton().InitFromResource(JobDesc::Singleton().resource());
JobDesc::Singleton()->InitFromJobConf(job_conf);
IDMgr::Singleton()->InitFromResource(JobDesc::Singleton()->resource());
BuildGraphs();
InferShape4Regsts();
EraseMeaningLessRegsts();
......@@ -78,8 +75,8 @@ void Compiler::BuildGraphs() {
// data graph
LOG(INFO) << "Build DataTaskGraph...";
auto data_task_gph = new DataTaskGraph(
"data", JobDesc::Singleton().train_dlnet_conf(),
JobDesc::Singleton().strategy(), JobDesc::Singleton().is_train());
"data", JobDesc::Singleton()->train_dlnet_conf(),
JobDesc::Singleton()->strategy(), JobDesc::Singleton()->is_train());
ordered_task_gphs_.emplace_back(data_task_gph);
// construct data_chain2sorted_fw_comp_tasks
HashMap<const ChainNode*, std::vector<CompTaskNode*>>
......@@ -110,7 +107,7 @@ void Compiler::BuildModelGraphs(
str_replace(&chain_tag, '/', '_');
ParallelPolicy policy = pair.first->parallel_desc()->policy();
bool is_train = JobDesc::Singleton().is_train();
bool is_train = JobDesc::Singleton()->is_train();
std::vector<CompTaskNode*> sorted_diff_acc_tasks;
if (is_train) {
LOG(INFO) << "Build MdDiffAccTaskGraph... for " << chain_tag;
......@@ -172,13 +169,13 @@ void Compiler::GenPlanFile(const std::string& plan_filepath) {
OperatorConf gpu_clear_op_conf;
gpu_clear_op_conf.set_name("gpu_clear");
gpu_clear_op_conf.mutable_clear_conf();
auto gpu_clear_op = OpMgr::Singleton().ConstructOp(gpu_clear_op_conf);
auto gpu_clear_op = OpMgr::Singleton()->ConstructOp(gpu_clear_op_conf);
OperatorConf cpu_clear_op_conf;
cpu_clear_op_conf.set_name("cpu_clear");
cpu_clear_op_conf.mutable_clear_conf();
auto cpu_clear_op = OpMgr::Singleton().ConstructOp(cpu_clear_op_conf);
OpMgr::Singleton().AllOpToProto(plan.mutable_op());
JobDesc::Singleton().ToProto(plan.mutable_job_desc());
auto cpu_clear_op = OpMgr::Singleton()->ConstructOp(cpu_clear_op_conf);
OpMgr::Singleton()->AllOpToProto(plan.mutable_op());
JobDesc::Singleton()->ToProto(plan.mutable_job_desc());
ConstForEachChainNode([&plan](const ChainNode* node) {
for (std::shared_ptr<const Operator> op : node->op_vec()) {
CHECK(plan.mutable_op_name2device_type()
......@@ -216,7 +213,7 @@ int main(int argc, char** argv) {
LOG(INFO) << "Compiler Starting Up...";
oneflow::JobConf job_conf;
oneflow::ParseProtoFromTextFile(FLAGS_job_conf_filepath, &job_conf);
oneflow::Compiler::Singleton().Compile(job_conf, FLAGS_plan_filepath);
oneflow::Compiler::Singleton()->Compile(job_conf, FLAGS_plan_filepath);
LOG(INFO) << "Compiler Shutting Down...";
return 0;
}
......@@ -20,64 +20,67 @@ Resource GetResource() {
} // namespace
TEST(IDMgr, compile_machine_id_and_name) {
IDMgr::Singleton().InitFromResource(GetResource());
ASSERT_EQ(IDMgr::Singleton().MachineID4MachineName("machine_0"), 0);
ASSERT_EQ(IDMgr::Singleton().MachineID4MachineName("machine_1"), 1);
ASSERT_EQ(IDMgr::Singleton().MachineID4MachineName("machine_5"), 5);
ASSERT_EQ(IDMgr::Singleton().MachineName4MachineId(2), "machine_2");
ASSERT_EQ(IDMgr::Singleton().MachineName4MachineId(3), "machine_3");
ASSERT_EQ(IDMgr::Singleton().MachineName4MachineId(7), "machine_7");
IDMgr::Singleton()->InitFromResource(GetResource());
ASSERT_EQ(IDMgr::Singleton()->MachineID4MachineName("machine_0"), 0);
ASSERT_EQ(IDMgr::Singleton()->MachineID4MachineName("machine_1"), 1);
ASSERT_EQ(IDMgr::Singleton()->MachineID4MachineName("machine_5"), 5);
ASSERT_EQ(IDMgr::Singleton()->MachineName4MachineId(2), "machine_2");
ASSERT_EQ(IDMgr::Singleton()->MachineName4MachineId(3), "machine_3");
ASSERT_EQ(IDMgr::Singleton()->MachineName4MachineId(7), "machine_7");
}
TEST(IDMgr, compile_special_thrd_loc_id) {
IDMgr::Singleton().InitFromResource(GetResource());
ASSERT_EQ(IDMgr::Singleton().PersistenceThrdLocId(), 8);
ASSERT_EQ(IDMgr::Singleton().BoxingThrdLocId(), 9);
ASSERT_EQ(IDMgr::Singleton().CommNetThrdLocId(), 10);
IDMgr::Singleton()->InitFromResource(GetResource());
ASSERT_EQ(IDMgr::Singleton()->PersistenceThrdLocId(), 8);
ASSERT_EQ(IDMgr::Singleton()->BoxingThrdLocId(), 9);
ASSERT_EQ(IDMgr::Singleton()->CommNetThrdLocId(), 10);
}
TEST(IDMgr, compile_task_id) {
IDMgr::Singleton().InitFromResource(GetResource());
IDMgr::Singleton()->InitFromResource(GetResource());
int64_t machine1device2 =
(static_cast<int64_t>(1) << (8 + 39)) + (static_cast<int64_t>(2) << 39);
ASSERT_EQ(IDMgr::Singleton().NewTaskId(1, 2), machine1device2);
ASSERT_EQ(IDMgr::Singleton().NewTaskId(1, 2), machine1device2 + 1);
ASSERT_EQ(IDMgr::Singleton().NewTaskId(1, 2), machine1device2 + 2);
ASSERT_EQ(IDMgr::Singleton()->NewTaskId(1, 2), machine1device2);
ASSERT_EQ(IDMgr::Singleton()->NewTaskId(1, 2), machine1device2 + 1);
ASSERT_EQ(IDMgr::Singleton()->NewTaskId(1, 2), machine1device2 + 2);
int64_t machine3device4 =
(static_cast<int64_t>(3) << (8 + 39)) + (static_cast<int64_t>(4) << 39);
ASSERT_EQ(IDMgr::Singleton().NewTaskId(3, 4), machine3device4);
ASSERT_EQ(IDMgr::Singleton().NewTaskId(3, 4), machine3device4 + 1);
ASSERT_EQ(IDMgr::Singleton().NewTaskId(3, 4), machine3device4 + 2);
ASSERT_EQ(IDMgr::Singleton()->NewTaskId(3, 4), machine3device4);
ASSERT_EQ(IDMgr::Singleton()->NewTaskId(3, 4), machine3device4 + 1);
ASSERT_EQ(IDMgr::Singleton()->NewTaskId(3, 4), machine3device4 + 2);
}
TEST(IDMgr, compile_regst_desc_id) {
IDMgr::Singleton().InitFromResource(GetResource());
ASSERT_EQ(IDMgr::Singleton().NewRegstDescId(), 0);
ASSERT_EQ(IDMgr::Singleton().NewRegstDescId(), 1);
ASSERT_EQ(IDMgr::Singleton().NewRegstDescId(), 2);
IDMgr::Singleton()->InitFromResource(GetResource());
ASSERT_EQ(IDMgr::Singleton()->NewRegstDescId(), 0);
ASSERT_EQ(IDMgr::Singleton()->NewRegstDescId(), 1);
ASSERT_EQ(IDMgr::Singleton()->NewRegstDescId(), 2);
}
TEST(IDMgr, runtime_machine_id) {
IDMgr::Singleton().InitFromResource(GetResource());
IDMgr::Singleton()->InitFromResource(GetResource());
int64_t actor_id5_machine1device3 =
(static_cast<int64_t>(1) << (8 + 39)) // machine_id_1
+ (static_cast<int64_t>(3) << 39) // device_id_3
+ 5; // actor_id_5
ASSERT_EQ(IDMgr::Singleton().MachineId4ActorId(actor_id5_machine1device3), 1);
ASSERT_EQ(IDMgr::Singleton()->MachineId4ActorId(actor_id5_machine1device3),
1);
}
TEST(IDMgr, runtime_thrd_loc_id) {
IDMgr::Singleton().InitFromResource(GetResource());
IDMgr::Singleton()->InitFromResource(GetResource());
int64_t actor_id5_machine1device3 =
(static_cast<int64_t>(1) << (8 + 39)) // machine_id_1
+ (static_cast<int64_t>(3) << 39) // device_id_3
+ 5; // actor_id_5
ASSERT_EQ(IDMgr::Singleton().ThrdLocId4ActorId(actor_id5_machine1device3), 3);
ASSERT_EQ(IDMgr::Singleton()->ThrdLocId4ActorId(actor_id5_machine1device3),
3);
int64_t actor_id6_machine2device4 =
(static_cast<int64_t>(2) << (8 + 39)) // machine_id_2
+ (static_cast<int64_t>(4) << 39) // device_id_4
+ 6; // actor_id_6
ASSERT_EQ(IDMgr::Singleton().ThrdLocId4ActorId(actor_id6_machine2device4), 4);
ASSERT_EQ(IDMgr::Singleton()->ThrdLocId4ActorId(actor_id6_machine2device4),
4);
}
} // namespace oneflow
......@@ -12,14 +12,15 @@ std::pair<std::string, std::string> ParseDeviceNameConf(
ParallelDesc::ParallelDesc(const ParallelConf& user_conf) {
policy_ = user_conf.policy();
device_type_ = JobDesc::Singleton().resource().device_type();
device_type_ = JobDesc::Singleton()->resource().device_type();
for (int64_t i = 0; i < user_conf.device_set().device_name_size(); ++i) {
const std::string& device_name = user_conf.device_set().device_name(i);
std::pair<std::string, std::string> machine_name_device_id =
ParseDeviceNameConf(device_name);
std::string machine_name = machine_name_device_id.first;
std::string device_id_str = machine_name_device_id.second;
int64_t machine_id = IDMgr::Singleton().MachineID4MachineName(machine_name);
int64_t machine_id =
IDMgr::Singleton()->MachineID4MachineName(machine_name);
sorted_machine_ids_.push_back(machine_id);
// if the device_name format is "machine_xxx:0-3", add device_id {0,1,2,3}
int64_t to_symbol_pos = device_id_str.rfind("-");
......
......@@ -22,7 +22,7 @@ class Runtime final {
std::vector<const TaskProto*> source_tasks;
std::vector<const TaskProto*> other_tasks;
for (const TaskProto& task : plan.task()) {
if (task.machine_id() != RuntimeCtx::Singleton().this_machine_id()) {
if (task.machine_id() != RuntimeCtx::Singleton()->this_machine_id()) {
continue;
}
if (task.type() == kMdUpdtCompTask) {
......@@ -35,40 +35,40 @@ class Runtime final {
}
LOG(INFO) << "InitModel";
HandoutTasks(mdupdt_tasks);
RuntimeCtx::Singleton().SetModelInitCnt(mdupdt_tasks.size());
RuntimeCtx::Singleton()->SetModelInitCnt(mdupdt_tasks.size());
SendCmdMsg(mdupdt_tasks, ActorCmd::kInitializeModel);
HandoutTasks(source_tasks);
HandoutTasks(other_tasks);
RuntimeCtx::Singleton().WaitUnitlAllModelInitDone();
RuntimeCtx::Singleton()->WaitUnitlAllModelInitDone();
LOG(INFO) << "InitModel on this machine done";
// TODO: Barrier
LOG(INFO) << "InitModel on all machine done";
SendCmdMsg(mdupdt_tasks, ActorCmd::kSendInitialModel);
SendCmdMsg(source_tasks, ActorCmd::kStart);
ThreadMgr::Singleton().ForEachThread(
ThreadMgr::Singleton()->ForEachThread(
[](Thread* thrd) { thrd->JoinAllActor(); });
ThreadMgr::Singleton().ClearAllThread();
ThreadMgr::Singleton()->ClearAllThread();
}
private:
Runtime() = default;
void InitSingleton(const Plan& plan, const std::string& this_machine_name) {
JobDesc::Singleton().InitFromProto(plan.job_desc());
IDMgr::Singleton().InitFromResource(JobDesc::Singleton().resource());
RuntimeCtx::Singleton().set_this_machine_name(this_machine_name);
KernelMgr::Singleton().InitFromPlan(plan);
JobDesc::Singleton()->InitFromProto(plan.job_desc());
IDMgr::Singleton()->InitFromResource(JobDesc::Singleton()->resource());
RuntimeCtx::Singleton()->set_this_machine_name(this_machine_name);
KernelMgr::Singleton()->InitFromPlan(plan);
}
void HandoutTasks(const std::vector<const TaskProto*>& tasks) {
for (const TaskProto* task : tasks) {
ThreadMgr::Singleton().GetThrd(task->thrd_local_id())->AddTask(*task);
ThreadMgr::Singleton()->GetThrd(task->thrd_local_id())->AddTask(*task);
}
}
void SendCmdMsg(const std::vector<const TaskProto*>& tasks, ActorCmd cmd) {
for (const TaskProto* task : tasks) {
ActorMsg msg;
msg.set_dst_actor_id(IDMgr::Singleton().ActorId4TaskId(task->id()));
msg.set_dst_actor_id(IDMgr::Singleton()->ActorId4TaskId(task->id()));
msg.set_actor_cmd(cmd);
ActorMsgBus::Singleton().SendMsg(msg);
ActorMsgBus::Singleton()->SendMsg(msg);
}
}
};
......@@ -84,7 +84,7 @@ int main(int argc, char** argv) {
LOG(INFO) << "Runtime Starting Up...";
oneflow::Plan plan;
oneflow::ParseProtoFromTextFile(FLAGS_plan_filepath, &plan);
oneflow::Runtime::Singleton().Run(plan, FLAGS_this_machine_name);
oneflow::Runtime::Singleton()->Run(plan, FLAGS_this_machine_name);
LOG(INFO) << "Runtime Shutting Down...";
return 0;
}
......@@ -4,7 +4,7 @@ namespace oneflow {
void RuntimeCtx::set_this_machine_name(const std::string& name) {
this_machine_name_ = name;
this_machine_id_ = IDMgr::Singleton().MachineID4MachineName(name);
this_machine_id_ = IDMgr::Singleton()->MachineID4MachineName(name);
LOG(INFO) << "this machine name: " << this_machine_name_;
LOG(INFO) << "this machine id: " << this_machine_id_;
}
......
......@@ -72,7 +72,7 @@ Kernel* ConstructCloneKernel(const int out_num, const std::string& lbn) {
CloneOpConf* clone_conf = op_conf.mutable_clone_conf();
clone_conf->set_out_num(out_num);
clone_conf->set_lbn(lbn);
auto clone_op = OpMgr::Singleton().ConstructOp(op_conf);
auto clone_op = OpMgr::Singleton()->ConstructOp(op_conf);
OperatorProto op_proto;
clone_op->ToProto(&op_proto);
......
......@@ -31,7 +31,7 @@ void BuildCopyHdKernel(CopyHdKernel<DeviceType::kGPU, float>* copy_hd_kernel,
op_conf.set_name("copy_hd_test");
CopyHdOpConf* copy_hd_conf = op_conf.mutable_copy_hd_conf();
copy_hd_conf->set_type(hd_type);
auto copy_hd_op = OpMgr::Singleton().ConstructOp(op_conf);
auto copy_hd_op = OpMgr::Singleton()->ConstructOp(op_conf);
OperatorProto op_proto;
copy_hd_op->ToProto(&op_proto);
......
......@@ -24,13 +24,13 @@ void DataLoaderKernel<DeviceType::kCPU, FloatingPointType>::Forward(
const KernelCtx& kernel_ctx,
std::function<Blob*(const std::string&)> BnInOp2BlobPtr) const {
PersistentCircularLineReader* reader =
RuntimeCtx::Singleton().GetDataReader();
RuntimeCtx::Singleton()->GetDataReader();
if (reader == nullptr) {
std::string data_dir = op()->GetStringFromSpecialConf("data_dir");
int64_t parallel_id = reinterpret_cast<int64_t>(kernel_ctx.other);
std::string file_path = data_dir + "part-" + std::to_string(parallel_id);
RuntimeCtx::Singleton().InitDataReader(file_path);
reader = RuntimeCtx::Singleton().GetDataReader();
RuntimeCtx::Singleton()->InitDataReader(file_path);
reader = RuntimeCtx::Singleton()->GetDataReader();
}
TODO();
}
......
......@@ -132,7 +132,7 @@ Kernel* BuildInnerProductKernel(bool has_bias_term) {
inner_product_conf->set_out("ip_out");
inner_product_conf->set_out_num(40);
inner_product_conf->set_has_bias_term(has_bias_term);
auto inner_product_op = OpMgr::Singleton().ConstructOp(op_conf);
auto inner_product_op = OpMgr::Singleton()->ConstructOp(op_conf);
OperatorProto op_proto;
inner_product_op->ToProto(&op_proto);
......
......@@ -72,7 +72,7 @@ void AddGpuDoubleKernelCreator(OperatorConf::OpTypeCase op_type_case,
}
void KernelMgr::InitFromPlan(const Plan& plan) {
int64_t this_machine_id = RuntimeCtx::Singleton().this_machine_id();
int64_t this_machine_id = RuntimeCtx::Singleton()->this_machine_id();
const PbRpf<std::string>& op_names_rpf =
plan.machine_id2op_name_set().at(this_machine_id).op_name();
std::unordered_set<std::string> op_name_set(op_names_rpf.begin(),
......@@ -84,7 +84,7 @@ void KernelMgr::InitFromPlan(const Plan& plan) {
LOG(INFO) << "construct kernel: " << op_name;
std::unique_ptr<Kernel> kernel_ptr(
CreateKernel(op_proto.op_conf().op_type_case(), device_type,
JobDesc::Singleton().floating_point_type()));
JobDesc::Singleton()->floating_point_type()));
kernel_ptr->InitFromOpProto(op_proto);
CHECK(op_name2kernel_ptr_.emplace(op_name, std::move(kernel_ptr)).second);
}
......
......@@ -11,8 +11,7 @@ enum class Location { kHost, kDevice };
template<typename FloatingPointType>
Blob* CreateBlob(const std::vector<int64_t>& dim_vec,
FloatingPointType* data_vec,
Location mem_location) {
FloatingPointType* data_vec, Location mem_location) {
void* dptr;
Shape* shape = new Shape(dim_vec);
......@@ -47,9 +46,12 @@ std::function<Blob*(const std::string&)> BuildBnInOp2BlobPtr() {
auto bn2blob_ptr = new HashMap<std::string, Blob*>;
(*bn2blob_ptr)["model_diff"] = CreateBlob<FloatingPointType>(dim_vec, diff_data, loc);
(*bn2blob_ptr)["model_diff_acc"] = CreateBlob<FloatingPointType>(dim_vec, diff_acc_data, loc);
(*bn2blob_ptr)["expected_acc"] = CreateBlob<FloatingPointType>(dim_vec, expected_data, loc);
(*bn2blob_ptr)["model_diff"] =
CreateBlob<FloatingPointType>(dim_vec, diff_data, loc);
(*bn2blob_ptr)["model_diff_acc"] =
CreateBlob<FloatingPointType>(dim_vec, diff_acc_data, loc);
(*bn2blob_ptr)["expected_acc"] =
CreateBlob<FloatingPointType>(dim_vec, expected_data, loc);
return [bn2blob_ptr](const std::string& bn) { return bn2blob_ptr->at(bn); };
}
......@@ -78,12 +80,13 @@ Kernel* BuildMdDiffAccKernel() {
OperatorConf op_conf;
op_conf.set_name("model_diff_acc");
op_conf.mutable_model_diff_acc_conf();
auto model_diff_acc_op = OpMgr::Singleton().ConstructOp(op_conf);
auto model_diff_acc_op = OpMgr::Singleton()->ConstructOp(op_conf);
OperatorProto op_proto;
model_diff_acc_op->ToProto(&op_proto);
auto model_diff_acc_kernel = new MdDiffAccKernel<device_type, FloatingPointType>();
auto model_diff_acc_kernel =
new MdDiffAccKernel<device_type, FloatingPointType>();
model_diff_acc_kernel->InitFromOpProto(op_proto);
return model_diff_acc_kernel;
......@@ -154,15 +157,18 @@ void TestMdDiffAccKernel() {
auto BnInOp2BlobPtr = BuildBnInOp2BlobPtr<device_type, FloatingPointType>();
auto model_diff_acc_kernel = BuildMdDiffAccKernel<device_type, FloatingPointType>();
auto model_diff_acc_kernel =
BuildMdDiffAccKernel<device_type, FloatingPointType>();
model_diff_acc_kernel->Forward(ctx, BnInOp2BlobPtr);
SyncStream<device_type>(&ctx);
if (device_type == DeviceType::kCPU) {
CheckResult<FloatingPointType>(BnInOp2BlobPtr, BlobCmpCpu<FloatingPointType>);
CheckResult<FloatingPointType>(BnInOp2BlobPtr,
BlobCmpCpu<FloatingPointType>);
} else {
CheckResult<FloatingPointType>(BnInOp2BlobPtr, BlobCmpGpu<FloatingPointType>);
CheckResult<FloatingPointType>(BnInOp2BlobPtr,
BlobCmpGpu<FloatingPointType>);
}
}
} // namespace
......
......@@ -20,7 +20,7 @@ TEST(BoxingOp, box_4_10x5x6x6) {
// test concat_box shape function
boxing_conf->mutable_concat_box()->set_axis(1);
boxing_conf->mutable_data_split_box();
auto boxing_op = OpMgr::Singleton().ConstructOp(op_conf);
auto boxing_op = OpMgr::Singleton()->ConstructOp(op_conf);
HashMap<std::string, Shape*> bn2shape_ptr{
{boxing_op->input_bns()[0], new Shape(input_shape_vec2)},
{boxing_op->input_bns()[1], new Shape(input_shape_vec2)},
......@@ -59,7 +59,7 @@ TEST(BoxingOp, box_4_10x5x6x6) {
boxing_conf->set_out_num(1);
boxing_conf->mutable_add_box();
boxing_conf->mutable_clone_box();
boxing_op = OpMgr::Singleton().ConstructOp(op_conf);
boxing_op = OpMgr::Singleton()->ConstructOp(op_conf);
// do infer shape
boxing_op->InferShape4FwBlobs(fp, kModelParallel, 0, 1);
......
......@@ -7,7 +7,7 @@ TEST(CloneOp, clone_4x3_3_times) {
op_conf.set_name("clone_test");
op_conf.mutable_clone_conf()->set_out_num(3);
op_conf.mutable_clone_conf()->set_lbn("clone_test_lbn");
auto clone_op = OpMgr::Singleton().ConstructOp(op_conf);
auto clone_op = OpMgr::Singleton()->ConstructOp(op_conf);
HashMap<std::string, Shape*> bn2shape_ptr{
{clone_op->SoleIbn(), new Shape({4, 3})}};
......
......@@ -10,7 +10,7 @@ TEST(ConcatOp, concat_two_3x3) {
op_conf.mutable_concat_conf()->add_in("concat/in1");
op_conf.mutable_concat_conf()->set_axis(1);
op_conf.mutable_concat_conf()->set_out("concat_test_lbn");
auto concat_op = OpMgr::Singleton().ConstructOp(op_conf);
auto concat_op = OpMgr::Singleton()->ConstructOp(op_conf);
std::vector<int64_t> shape_vec = {3, 3};
HashMap<std::string, Shape*> bn2shape_ptr{
......
......@@ -15,7 +15,7 @@ TEST(ConvolutionOp, TestForInferShape4FwBlobs) {
op_conf.mutable_convolution_conf()->add_kernel_size(20);
op_conf.mutable_convolution_conf()->add_stride(3);
op_conf.mutable_convolution_conf()->add_stride(3);
auto convolution_op = OpMgr::Singleton().ConstructOp(op_conf);
auto convolution_op = OpMgr::Singleton()->ConstructOp(op_conf);
std::vector<int64_t> input_vec = {100, 64, 256, 256};
HashMap<std::string, Shape*> bn2shape_ptr{
{convolution_op->SoleIbn(), new Shape(input_vec)},
......
......@@ -19,7 +19,7 @@ void DataLoaderOp::InferShape4FwBlobs(
std::function<Shape*(const std::string&)> GetShapePtr4BnInOp,
ParallelPolicy policy, int64_t parallel_id, int64_t parallel_num) const {
// useful vars
int32_t piece_size = JobDesc::Singleton().piece_size();
int32_t piece_size = JobDesc::Singleton()->piece_size();
auto op_conf = static_cast<const DataLoaderOpConf*>(&GetSpecialConf());
// feature shape
Shape feature_shape_of_one_ins(op_conf->shape_of_one_feature_ins());
......
......@@ -14,7 +14,7 @@ void TestModelParallelInnerProductOp(bool has_bias_term) {
op_conf.mutable_innerproduct_conf()->set_out("ip_out");
op_conf.mutable_innerproduct_conf()->set_has_bias_term(has_bias_term);
op_conf.mutable_innerproduct_conf()->set_out_num(40);
auto ip_op = OpMgr::Singleton().ConstructOp(op_conf);
auto ip_op = OpMgr::Singleton()->ConstructOp(op_conf);
std::vector<int64_t> shape_vec = {1000, 3, 256, 256};
HashMap<std::string, Shape*> bn2shape_ptr = {
{ip_op->SoleIbn(), new Shape(shape_vec)},
......@@ -54,7 +54,7 @@ void TestDataParallelInnerProductOp(bool has_bias_term) {
op_conf.mutable_innerproduct_conf()->set_out("ip_out");
op_conf.mutable_innerproduct_conf()->set_has_bias_term(has_bias_term);
op_conf.mutable_innerproduct_conf()->set_out_num(40);
auto ip_op = OpMgr::Singleton().ConstructOp(op_conf);
auto ip_op = OpMgr::Singleton()->ConstructOp(op_conf);
std::vector<int64_t> shape_vec = {1000, 3, 256, 256};
HashMap<std::string, Shape*> bn2shape_ptr = {
......
......@@ -9,7 +9,7 @@ TEST(MultinomialLogisticLossOp, test_loss_op) {
"prediction");
op_conf.mutable_multinomial_logistic_loss_conf()->set_label("label");
op_conf.mutable_multinomial_logistic_loss_conf()->set_loss("loss");
auto loss_op = OpMgr::Singleton().ConstructOp(op_conf);
auto loss_op = OpMgr::Singleton()->ConstructOp(op_conf);
HashMap<std::string, Shape*> bn2shape_ptr{
{loss_op->input_bns().at(0), new Shape({500, 3 * 256 * 256 * 256, 1, 1})},
......
......@@ -20,7 +20,7 @@ TEST(PoolingOp, pool_100x64x11x11) {
pooling_conf->add_kernel_size(2);
pooling_conf->add_stride(2);
pooling_conf->add_stride(2);
auto pooling_op = OpMgr::Singleton().ConstructOp(op_conf);
auto pooling_op = OpMgr::Singleton()->ConstructOp(op_conf);
std::vector<int64_t> input_shape_vec = {100, 64, 11, 11};
HashMap<std::string, Shape*> bn2shape_ptr{
{pooling_op->SoleIbn(), new Shape(input_shape_vec)},
......
......@@ -8,7 +8,7 @@ TEST(ReluOp, relu_3x5x4) {
op_conf.set_name("relu_test");
op_conf.mutable_relu_conf()->set_in("relu_in");
op_conf.mutable_relu_conf()->set_out("relu_out");
auto relu_op = OpMgr::Singleton().ConstructOp(op_conf);
auto relu_op = OpMgr::Singleton()->ConstructOp(op_conf);
std::vector<int64_t> input_shape_vec = {3, 5, 4};
HashMap<std::string, Shape*> bn2shape_ptr{
{relu_op->SoleIbn(), new Shape(input_shape_vec)},
......
......@@ -9,7 +9,7 @@ TEST(SoftmaxOp, softmax_3x4x5) {
op_conf.mutable_softmax_conf()->set_axis(1);
op_conf.mutable_softmax_conf()->set_in("softmax/in");
op_conf.mutable_softmax_conf()->set_out("softmax/out");
auto softmax_op = OpMgr::Singleton().ConstructOp(op_conf);
auto softmax_op = OpMgr::Singleton()->ConstructOp(op_conf);
HashMap<std::string, Shape*> bn2shape_ptr{
{softmax_op->SoleIbn(), new Shape({3, 4, 5})},
{softmax_op->SoleObn(), new Shape}};
......
......@@ -5,7 +5,7 @@
namespace oneflow {
void SnapshotMgr::Init() {
model_save_snapshots_path_ = JobDesc::Singleton().md_save_snapshots_path();
model_save_snapshots_path_ = JobDesc::Singleton()->md_save_snapshots_path();
tensorflow::Env* env = tensorflow::Env::Default();
if (env->IsDirectory(model_save_snapshots_path_).code()
!= tensorflow::error::OK) {
......@@ -14,7 +14,7 @@ void SnapshotMgr::Init() {
std::vector<std::string> result;
TF_CHECK_OK(env->GetChildren(model_save_snapshots_path_, &result));
CHECK_EQ(result.size(), 0);
const std::string& load_path = JobDesc::Singleton().md_load_snapshot_path();
const std::string& load_path = JobDesc::Singleton()->md_load_snapshot_path();
if (load_path != "") {
readable_snapshot_ptr_.reset(new Snapshot(load_path));
}
......
......@@ -10,7 +10,7 @@ namespace {
void SetDeviceCudaMemoryAccordingToThrdLocId(MemoryCase& mem_case,
int64_t thrd_loc_id) {
int64_t device_id = IDMgr::Singleton().DevPhyId4ThrdLocId(thrd_loc_id);
int64_t device_id = IDMgr::Singleton()->DevPhyId4ThrdLocId(thrd_loc_id);
mem_case.mutable_device_cuda_mem()->set_device_id(device_id);
}
......
......@@ -14,7 +14,7 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto,
regst->regst_desc_ = runtime_regst_desc;
size_t elem_size = sizeof(float);
if (JobDesc::Singleton().floating_point_type() == kDouble) {
if (JobDesc::Singleton()->floating_point_type() == kDouble) {
elem_size = sizeof(double);
}
int64_t elem_cnt = 0;
......@@ -28,8 +28,8 @@ void RegstMgr::NewRegsts(const RegstDescProto& regst_desc_proto,
}
std::sort(lbns.begin(), lbns.end());
std::pair<char*, std::function<void()>> allocation =
MemoryAllocator::Singleton().Allocate(regst_desc_proto.mem_case(),
elem_cnt * elem_size);
MemoryAllocator::Singleton()->Allocate(regst_desc_proto.mem_case(),
elem_cnt * elem_size);
int64_t blob_idx = 0;
for (const std::string& lbn : lbns) {
......
......@@ -6,13 +6,14 @@ namespace oneflow {
RtRegstDesc::RtRegstDesc(const RegstDescProto& regst_desc_proto) {
regst_desc_id_ = regst_desc_proto.regst_desc_id();
producer_actor_id_ =
IDMgr::Singleton().ActorId4TaskId(regst_desc_proto.producer_task_id());
IDMgr::Singleton()->ActorId4TaskId(regst_desc_proto.producer_task_id());
register_num_ = regst_desc_proto.register_num();
const auto& subscriber = regst_desc_proto.subscriber_task_id();
subscribers_actor_id_.reserve(subscriber.size());
for (int64_t task_id : subscriber) {
subscribers_actor_id_.push_back(IDMgr::Singleton().ActorId4TaskId(task_id));
subscribers_actor_id_.push_back(
IDMgr::Singleton()->ActorId4TaskId(task_id));
}
for (const auto& pair : regst_desc_proto.lbn2shape()) {
......
......@@ -16,7 +16,7 @@ void Thread::PollMsgChannel(const ThreadCtx& thread_ctx) {
auto actor_it = id2actor_ptr_.find(actor_id);
if (actor_it == id2actor_ptr_.end()) {
std::unique_lock<std::mutex> lck(id2task_mtx_);
int64_t task_id = IDMgr::Singleton().TaskId4ActorId(actor_id);
int64_t task_id = IDMgr::Singleton()->TaskId4ActorId(actor_id);
auto task_it = id2task_.find(task_id);
auto emplace_ret = id2actor_ptr_.emplace(
actor_id, ConstructActor(task_it->second, thread_ctx));
......
......@@ -16,8 +16,8 @@ void ThreadMgr::ForEachThread(std::function<void(Thread*)> func) {
ThreadMgr::ThreadMgr() {
// device thread - device_num_per_machine
int64_t dev_num_per_machine =
JobDesc::Singleton().resource().device_num_per_machine();
int64_t device_type = JobDesc::Singleton().resource().device_type();
JobDesc::Singleton()->resource().device_num_per_machine();
int64_t device_type = JobDesc::Singleton()->resource().device_type();
threads_.reserve(dev_num_per_machine + 3);
for (int64_t dev_phy_id = 0; dev_phy_id < dev_num_per_machine; ++dev_phy_id) {
if (device_type == kGPU) {
......