Commit 0d7afa55 authored by chengtbf, committed by Jinhui Yuan

Dev ctrl edge copy reduce (#995)

* add ctrl edge between CopyHd and MdUpdt

* fix bug of add ctrl regst

* hack ctrl regst max regst num

* test undo

* after experiment

* use GetTaskType() instead of dynamic_cast

* fix for review

* remove hack regst

* init ctrl regst min num

* fix ctrl regst num between copy and mdupdt after improver

* MdUpdt actor returns k ctrl regsts per act (k = num of pieces in batch)

* add returned regst num of ctrl regst

* CHECK invariant


Former-commit-id: 339d56ee
Parent 245b26b5
......@@ -451,9 +451,18 @@ void Actor::AsyncSendCtrlRegstMsg() {
for (auto& pair : consumed_ctrl_regst_) {
CHECK(!pair.second.empty());
Regst* regst = pair.second.front();
- AsyncSendMsg(ActorMsg::BuildRegstMsgToProducer(actor_id_, regst->producer_actor_id(), regst));
- pair.second.pop_front();
- if (pair.second.empty()) { --readable_ctrl_regst_desc_cnt_; }
int32_t returned_regst_num =
regst->regst_desc()->regst_desc_type().ctrl_regst_desc().returned_regst_num();
CHECK_GE(returned_regst_num, 1);
if (!pair.second.empty()) {
CHECK_GE(pair.second.size(), returned_regst_num);
while (returned_regst_num--) {
AsyncSendMsg(
ActorMsg::BuildRegstMsgToProducer(actor_id_, regst->producer_actor_id(), regst));
pair.second.pop_front();
}
if (pair.second.empty()) { --readable_ctrl_regst_desc_cnt_; }
}
}
for (auto& pair : writeable_produced_ctrl_regst_) {
CHECK(!pair.second.empty());
......
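The change above makes a consumer hand back `returned_regst_num` ctrl regsts to their producer in a single act instead of one at a time. A minimal standalone sketch of that control flow, assuming a `std::deque<int>` in place of the queued `Regst*` list and a hypothetical `ReturnConsumedCtrlRegsts` helper in place of the actor machinery:

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <iostream>

// Sketch only: drain up to `returned_regst_num` queued ctrl regsts in one act and
// decrement the readable-desc counter once the queue runs empty, mirroring the loop
// added in Actor::AsyncSendCtrlRegstMsg().
void ReturnConsumedCtrlRegsts(std::deque<int>& queued_regsts, int returned_regst_num,
                              int& readable_ctrl_regst_desc_cnt) {
  assert(returned_regst_num >= 1);
  if (queued_regsts.empty()) { return; }
  assert(queued_regsts.size() >= static_cast<std::size_t>(returned_regst_num));
  while (returned_regst_num--) {
    std::cout << "return regst " << queued_regsts.front() << " to its producer\n";
    queued_regsts.pop_front();
  }
  if (queued_regsts.empty()) { --readable_ctrl_regst_desc_cnt; }
}

int main() {
  std::deque<int> queued = {101, 102, 103, 104};  // e.g. four pieces of one batch
  int readable_cnt = 1;
  ReturnConsumedCtrlRegsts(queued, /*returned_regst_num=*/4, readable_cnt);
  std::cout << "readable_ctrl_regst_desc_cnt = " << readable_cnt << "\n";  // prints 0
  return 0;
}
```

With `returned_regst_num = k` (pieces per batch; the proto default stays 1 otherwise), the MdUpdt actor releases a whole batch's worth of ctrl regsts back to the CopyHd actor in one act, which is what the last commit-message bullets describe.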
#include "oneflow/core/graph/task_graph.h"
#include "oneflow/core/graph/normal_forward_compute_task_node.h"
#include "oneflow/core/graph/normal_model_update_compute_task_node.h"
#include "oneflow/core/graph/chain_graph.h"
#include "oneflow/core/graph/boxing_task_node.h"
#include "oneflow/core/common/balanced_splitter.h"
......@@ -184,7 +186,60 @@ bool TaskGraph::IsEndingTaskType<ReduceGatherCompTaskNode>(TaskType type) {
void TaskGraph::AddMutexCtrlEdgeInSameChain() { UNIMPLEMENTED(); }
- void TaskGraph::AddOrderCtrlEdgeBetweenCopyAndMdUpdt() { UNIMPLEMENTED(); }
void TaskGraph::AddOrderCtrlEdgeBetweenCopyAndMdUpdt() {
for (TaskNode* task_node : ordered_task_nodes_) {
auto copy_hd_task_node = dynamic_cast<CopyHdTaskNode*>(task_node);
if (copy_hd_task_node == nullptr) { continue; }
if (copy_hd_task_node->copy_type() != CopyHdOpConf::H2D) { continue; }
if (copy_hd_task_node->area_id() != static_cast<int64_t>(kDataForwardArea)
&& copy_hd_task_node->area_id() != static_cast<int64_t>(kBoundaryArea)) {
continue;
}
std::vector<TaskNode*> candidate_nodes;
auto ForEachNextNode = [&](TaskNode* node,
const std::function<void(TaskNode*)>& TryPushNodeToQueue) {
auto fw_task_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
if (fw_task_node != nullptr && fw_task_node->logical_node()->HasOpWithModelBlob()) { return; }
node->ForEachNodeOnOutEdge([&](TaskNode* node_on_out_edge) {
if (IsForwardTaskType(node_on_out_edge->GetTaskType())) {
TryPushNodeToQueue(node_on_out_edge);
}
});
};
auto HandlerAddCandidate = [&](TaskNode* node) {
auto fw_task_node = dynamic_cast<NormalForwardCompTaskNode*>(node);
if (fw_task_node != nullptr && fw_task_node->logical_node()->HasOpWithModelBlob()
&& fw_task_node->parallel_ctx()->parallel_num() > 1
&& fw_task_node->parallel_ctx()->policy() == kDataParallel) {
candidate_nodes.push_back(node);
}
};
BfsForEachNode({task_node}, ForEachNextNode, HandlerAddCandidate);
std::sort(candidate_nodes.begin(), candidate_nodes.end(),
[](const TaskNode* a, const TaskNode* b) {
return a->order_in_graph() < b->order_in_graph();
});
int64_t last_chain_id = -1;
for (TaskNode* candidate_node : candidate_nodes) {
if (candidate_node->chain_id() != last_chain_id) {
last_chain_id = candidate_node->chain_id();
candidate_node->ForEachNodeOnInEdge([&](TaskNode* node_on_in_edge) {
if (IsMdUpdtTaskType(node_on_in_edge->GetTaskType())) {
RegstDesc* ctrl_regst = task_node->BuildCtrlRegstDesc(node_on_in_edge);
RegstDesc* copy_out_regst = copy_hd_task_node->GetProducedRegst("copy_out").get();
int64_t piece_num_in_batch = Global<JobDesc>::Get()->NumOfPiecesInBatch();
ctrl_regst->UpdtMinRegstNumIfNeed(copy_out_regst->min_register_num()
+ piece_num_in_batch - 1);
CtrlRegstDesc* ctrl_regst_desc =
ctrl_regst->mut_regst_desc_type()->mutable_ctrl_regst_desc();
ctrl_regst_desc->set_reliant_regst_desc_id(copy_out_regst->regst_desc_id());
ctrl_regst_desc->set_returned_regst_num(piece_num_in_batch);
}
});
}
}
}
}
void TaskGraph::CollectAncestorsForEachNode() {
std::vector<TaskNode*> ordered_nodes;
......
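The register-count arithmetic above is the heart of the new ctrl edge: the ctrl regst produced by the CopyHd node gets a minimum of `copy_out min regst num + k - 1` registers, where k is the number of pieces in a batch, and its `returned_regst_num` is set to k so the MdUpdt side releases them batch-wise. A minimal sketch of that calculation with plain integers standing in for the `RegstDesc`/`JobDesc` state; the struct and function names here are hypothetical:

```cpp
#include <cstdint>
#include <iostream>

// Hypothetical mirror of the two settings made on the ctrl regst in
// TaskGraph::AddOrderCtrlEdgeBetweenCopyAndMdUpdt(): its minimum register
// count and the number of regsts returned per MdUpdt act.
struct CtrlRegstPlan {
  int64_t min_regst_num;
  int32_t returned_regst_num;
};

CtrlRegstPlan PlanCtrlRegstBetweenCopyAndMdUpdt(int64_t copy_out_min_regst_num,
                                                int64_t piece_num_in_batch) {
  CtrlRegstPlan plan;
  plan.min_regst_num = copy_out_min_regst_num + piece_num_in_batch - 1;
  plan.returned_regst_num = static_cast<int32_t>(piece_num_in_batch);
  return plan;
}

int main() {
  // e.g. copy_out keeps 1 register and a batch is split into 4 pieces
  CtrlRegstPlan plan = PlanCtrlRegstBetweenCopyAndMdUpdt(1, 4);
  std::cout << "ctrl min_regst_num = " << plan.min_regst_num          // 4
            << ", returned_regst_num = " << plan.returned_regst_num   // 4
            << "\n";
  return 0;
}
```

The same `+ k - 1` formula is re-applied after the Improver has chosen the final `copy_out` register count (see `FixReliantCtrlRegstNum` further down), the apparent intent being that the ctrl regst does not throttle CopyHd below what `copy_out` already allows within one batch.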
......@@ -148,12 +148,17 @@ void TaskNode::BuildCtrlRegstDescIfNeed(TaskNode* dst_node) {
}
const auto& dst_ancestors = dst_node->ancestors();
if (dst_ancestors.find(this) != dst_ancestors.end()) return;
BuildCtrlRegstDesc(dst_node);
}
RegstDesc* TaskNode::BuildCtrlRegstDesc(TaskNode* dst_node) {
RegstDescTypeProto regst_desc_type;
regst_desc_type.mutable_ctrl_regst_desc();
auto regst = NewProducedRegst(false, 1, kMaxRegisterNum, regst_desc_type);
std::string name = "out_ctrl_" + std::to_string(regst->regst_desc_id());
CHECK(produced_regsts_.emplace(name, regst).second);
dst_node->ConsumeRegst("in_ctrl", regst);
return regst.get();
}
void TaskNode::BindEdgeWithProducedRegst(TaskEdge* edge, const std::string& name) {
......
......@@ -75,6 +75,7 @@ class TaskNode : public Node<TaskNode, TaskEdge> {
void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name);
virtual int64_t MemZoneId121() const; // TODO: there is bug for reduce task node
void BuildCtrlRegstDescIfNeed(TaskNode* dst_node);
RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node);
protected:
std::shared_ptr<RegstDesc> ProduceRegst(const std::string& name, bool enable_mem_sharing);
......
......@@ -63,6 +63,7 @@ Plan Compiler::DoCompile() {
if (Global<JobDesc>::Get()->other_conf().use_ordered_allreduce_in_mdupdt()) {
task_gph->AddCtrlEdgeInReduceStruct();
}
if (job_desc->IsTrain()) { task_gph->AddOrderCtrlEdgeBetweenCopyAndMdUpdt(); }
Plan plan;
task_gph->ForEachNode([&](TaskNode* task_node) {
if (task_node->IsMeaningLess()) { return; }
......
......@@ -182,6 +182,13 @@ std::shared_ptr<HashMap<int64_t, RegstDescProto*>> MakeRegstDescId2RegstDesc(Pla
return regst_desc_id2regst_desc;
}
std::function<uint64_t(int64_t)> MakeGetterGetPlanRegstNum(Plan* plan) {
auto regst_desc_id2regst_desc = MakeRegstDescId2RegstDesc(plan);
return [regst_desc_id2regst_desc](int64_t regst_desc_id) {
return regst_desc_id2regst_desc->at(regst_desc_id)->register_num();
};
}
std::function<void(int64_t, uint64_t)> MakeSetterSetPlanRegstNum(Plan* plan) {
auto regst_desc_id2regst_desc = MakeRegstDescId2RegstDesc(plan);
return [regst_desc_id2regst_desc](int64_t regst_desc_id, uint64_t num) {
......@@ -374,6 +381,24 @@ void ForEachMemSharingCriticalSection(
}
}
void FixReliantCtrlRegstNum(const Plan& plan, const std::function<uint64_t(int64_t)>& GetRegstNum,
const std::function<void(int64_t, uint64_t)>& SetRegstNum) {
for (const auto& task_proto : plan.task()) {
for (const auto& pair : task_proto.produced_regst_desc()) {
const RegstDescProto& regst = pair.second;
const RegstDescTypeProto& regst_type = regst.regst_desc_type();
if (regst_type.has_ctrl_regst_desc()
&& regst_type.ctrl_regst_desc().has_reliant_regst_desc_id()) {
// set ctrl regst num between copyHd and MdUpdt
CHECK(task_proto.task_type() == kCopyHd);
uint64_t regst_num = GetRegstNum(regst_type.ctrl_regst_desc().reliant_regst_desc_id())
+ Global<JobDesc>::Get()->NumOfPiecesInBatch() - 1;
SetRegstNum(regst.regst_desc_id(), regst_num);
}
}
}
}
} // namespace
uint64_t Improver::AvailableMemSize(int64_t machine_id, int64_t memory_zone_id) const {
......@@ -509,6 +534,7 @@ Plan Improver::Improve(const AvailableMemDesc& amd, const Plan& naive_plan,
Plan plan(mem_shared_plan);
ForEachImprovedRegstNum(act_graph, mem_shared_plan, true, PathDurations4RegstDescId,
PathIIScales4RegstDescId, MakeSetterSetPlanRegstNum(&plan));
FixReliantCtrlRegstNum(plan, MakeGetterGetPlanRegstNum(&plan), MakeSetterSetPlanRegstNum(&plan));
return plan;
}
......
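`FixReliantCtrlRegstNum` reuses the Improver's closure-over-a-lookup-table pattern, pairing the new `MakeGetterGetPlanRegstNum` with the existing `MakeSetterSetPlanRegstNum`. A standalone sketch of that pattern with a `std::map` standing in for the plan's regst-desc table; `MakeGetter`/`MakeSetter` are illustrative names, not the real API:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <memory>

// Stand-in for the regst_desc_id -> register_num table built from the plan.
using RegstNumTable = std::map<int64_t, uint64_t>;

std::function<uint64_t(int64_t)> MakeGetter(std::shared_ptr<RegstNumTable> table) {
  return [table](int64_t regst_desc_id) { return table->at(regst_desc_id); };
}

std::function<void(int64_t, uint64_t)> MakeSetter(std::shared_ptr<RegstNumTable> table) {
  return [table](int64_t regst_desc_id, uint64_t num) { (*table)[regst_desc_id] = num; };
}

int main() {
  auto table = std::make_shared<RegstNumTable>();
  (*table)[100] = 2;  // e.g. the improved register_num of the reliant copy_out regst
  auto GetRegstNum = MakeGetter(table);
  auto SetRegstNum = MakeSetter(table);
  // Mirror of the fix-up: ctrl regst num = reliant regst num + pieces_in_batch - 1.
  const int64_t pieces_in_batch = 4;
  SetRegstNum(/*ctrl regst_desc_id=*/200, GetRegstNum(100) + pieces_in_batch - 1);
  std::cout << "ctrl regst num = " << GetRegstNum(200) << "\n";  // prints 5
  return 0;
}
```

Passing both closures into the fix-up keeps it decoupled from `Plan`'s internals, in the same style as the `ForEachImprovedRegstNum` call above.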
......@@ -16,6 +16,8 @@ message DataRegstDesc {
}
message CtrlRegstDesc {
optional int64 reliant_regst_desc_id = 1;
optional int32 returned_regst_num = 2 [default = 1];
}
message RegstDescTypeProto {
......
......@@ -11,6 +11,7 @@ RtRegstDesc::RtRegstDesc(const RegstDescProto& proto) {
consumers_actor_id_ = PbRf2StdVec(proto.consumer_task_id());
register_num_ = proto.register_num();
mem_case_ = proto.mem_case();
regst_desc_type_ = proto.regst_desc_type();
if (proto.regst_desc_type().has_data_regst_desc()) {
const DataRegstDesc& data_regst_desc = proto.regst_desc_type().data_regst_desc();
for (const LbiBlobDescPair& pair : data_regst_desc.lbi2blob_desc()) {
......
......@@ -20,6 +20,7 @@ class RtRegstDesc {
const std::vector<int64_t>& consumers_actor_id() const { return consumers_actor_id_; }
int64_t register_num() const { return register_num_; }
const MemoryCase& mem_case() const { return mem_case_; }
const RegstDescTypeProto& regst_desc_type() const { return regst_desc_type_; }
const BlobDesc* GetBlobDescFromLbi(const LogicalBlobId& lbi) const;
const BlobDesc* packed_blob_desc() const { return &packed_blob_desc_; }
......@@ -32,6 +33,7 @@ class RtRegstDesc {
int64_t producer_actor_id_;
std::vector<int64_t> consumers_actor_id_;
int64_t register_num_;
RegstDescTypeProto regst_desc_type_;
MemoryCase mem_case_;
HashMap<LogicalBlobId, std::unique_ptr<BlobDesc>> lbi2blob_desc_;
BlobDesc packed_blob_desc_;
......