提交 a3dd5a34 编写于 作者: L Li Xinqi 提交者: GitHub

faster improver (#1628)



Former-commit-id: 2550030088fb6b15f1784a2bd1cfb78eeabe3b0d
上级 88a04c86
......@@ -3,17 +3,17 @@
namespace oneflow {
// Constructs the graph from `regst_descs`: InitNodes is called with the descriptors,
// the ComputeLifetimeActorIds callback and a local `nodes` container to fill, and
// InitEdges is then called on that same container.
//
// ComputeLifetimeActorIds receives a regst descriptor and an output HashSet<int64_t>*;
// it is only forwarded here, never invoked directly by the constructor.
//
// NOTE(review): this text is a rendered commit diff without +/- markers. The paired
// std::list / std::vector lines look like the pre-change and post-change versions of
// the same line (list first, vector second) — confirm against the repository before
// treating this as compilable code.
RegstLifetimeGraph::RegstLifetimeGraph(
const std::list<const RegstDescProto*>& regst_descs,
const std::vector<const RegstDescProto*>& regst_descs,
const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& ComputeLifetimeActorIds) {
// Local scratch container; only its address / a const reference is handed to the helpers.
std::list<RegstLifetimeNode*> nodes;
std::vector<RegstLifetimeNode*> nodes;
InitNodes(regst_descs, ComputeLifetimeActorIds, &nodes);
InitEdges(nodes);
}
void RegstLifetimeGraph::InitNodes(
const std::list<const RegstDescProto*>& regst_descs,
const std::vector<const RegstDescProto*>& regst_descs,
const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& ComputeLifetimeActorIds,
std::list<RegstLifetimeNode*>* nodes) {
std::vector<RegstLifetimeNode*>* nodes) {
for (const RegstDescProto* regst_desc : regst_descs) {
auto lifetime_actor_ids = std::make_unique<HashSet<int64_t>>();
ComputeLifetimeActorIds(regst_desc, lifetime_actor_ids.get());
......@@ -23,7 +23,7 @@ void RegstLifetimeGraph::InitNodes(
}
}
void RegstLifetimeGraph::InitEdges(const std::list<RegstLifetimeNode*>& nodes) {
void RegstLifetimeGraph::InitEdges(const std::vector<RegstLifetimeNode*>& nodes) {
HashMap<int64_t, HashSet<RegstLifetimeNode*>> task_id2intersected_nodes;
for (RegstLifetimeNode* node : nodes) {
for (int64_t task_id : node->lifetime_actor_ids()) {
......@@ -46,7 +46,7 @@ void RegstLifetimeGraph::InitEdges(const std::list<RegstLifetimeNode*>& nodes) {
}
void RegstLifetimeGraph::ForEachSameColoredRegstDescs(
const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) const {
const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) const {
std::vector<const RegstLifetimeNode*> nodes;
ForEachNode([&](const RegstLifetimeNode* node) { nodes.push_back(node); });
std::sort(nodes.begin(), nodes.end(),
......@@ -65,7 +65,7 @@ void RegstLifetimeGraph::ForEachSameColoredRegstDescs(
node2excluded_color_ids[intersected].insert(color_id);
});
}
HashMap<int32_t, std::list<const RegstDescProto*>> color_id2regst_descs;
HashMap<int32_t, std::vector<const RegstDescProto*>> color_id2regst_descs;
for (const auto& pair : node2color_id) {
color_id2regst_descs[pair.second].push_back(&pair.first->regst_desc());
}
......
......@@ -43,19 +43,19 @@ class RegstLifetimeGraph final : public Graph<const RegstLifetimeNode, RegstLife
public:
OF_DISALLOW_COPY_AND_MOVE(RegstLifetimeGraph);
RegstLifetimeGraph(
const std::list<const RegstDescProto*>& regst_descs,
const std::vector<const RegstDescProto*>& regst_descs,
const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& ComputeLifetimeActorIds);
~RegstLifetimeGraph() = default;
void ForEachSameColoredRegstDescs(
const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) const;
const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) const;
private:
void InitNodes(
const std::list<const RegstDescProto*>& regst_descs,
const std::vector<const RegstDescProto*>& regst_descs,
const std::function<void(const RegstDescProto*, HashSet<int64_t>*)>& ComputeLifetimeActorIds,
std::list<RegstLifetimeNode*>* nodes);
void InitEdges(const std::list<RegstLifetimeNode*>& nodes);
std::vector<RegstLifetimeNode*>* nodes);
void InitEdges(const std::vector<RegstLifetimeNode*>& nodes);
};
} // namespace oneflow
......
......@@ -69,11 +69,11 @@ SharableMemBlockGraph::SharableMemBlockGraph(
// Partitions the nodes returned by source_nodes() into groups and reports each group.
//
// GroupBy maps a source node to an int64_t key; nodes with equal keys land in the same
// vector. Handler is then invoked once per distinct key with that key's vector of nodes.
// Within a group, nodes keep the order in which source_nodes() yielded them.
// NOTE(review): the order in which groups reach Handler follows HashMap iteration —
// presumably unspecified if HashMap is an unordered map; confirm before relying on it.
//
// NOTE(review): this text is a rendered commit diff without +/- markers. Each
// `chain_id2source_nodes` line appears to be the pre-change version of the
// `group_key2source_nodes` line directly below it (a rename) — confirm against the
// repository before treating this as compilable code.
void SharableMemBlockGraph::ForEachSourceNodeGroup(
const std::function<int64_t(const SharableMemBlockNode*)>& GroupBy,
const std::function<void(const std::vector<const SharableMemBlockNode*>&)>& Handler) const {
HashMap<int64_t, std::vector<const SharableMemBlockNode*>> chain_id2source_nodes;
HashMap<int64_t, std::vector<const SharableMemBlockNode*>> group_key2source_nodes;
// Bucket every source node under the key computed by GroupBy.
for (const SharableMemBlockNode* source : source_nodes()) {
chain_id2source_nodes[GroupBy(source)].push_back(source);
group_key2source_nodes[GroupBy(source)].push_back(source);
}
// Hand each bucket to the caller; the key itself is not passed along.
for (const auto& pair : chain_id2source_nodes) { Handler(pair.second); }
for (const auto& pair : group_key2source_nodes) { Handler(pair.second); }
}
} // namespace oneflow
......@@ -29,8 +29,9 @@ bool IsConsumersAndProducerInSameChain(const RegstDescProto& regst_desc,
}
void ForEachSharableStreamRegstDescsWithoutConsumer(
const Plan& plan, const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
HashMap<int64_t, std::list<const RegstDescProto*>> global_work_stream_id2regst_descs;
const Plan& plan,
const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) {
HashMap<int64_t, std::vector<const RegstDescProto*>> global_work_stream_id2regst_descs;
for (const auto& task : plan.task()) {
int64_t global_work_stream_id = Global<IDMgr>::Get()->GlobalWorkStreamId4TaskId(task.task_id());
for (const auto& pair : task.produced_regst_desc()) {
......@@ -45,20 +46,21 @@ void ForEachSharableStreamRegstDescsWithoutConsumer(
}
// For every group of regst descriptors that
// ForEachSharableStreamRegstDescsWithoutConsumer(plan, ...) delivers, builds a temporary
// RegstLifetimeGraph over that group and forwards Handler to its
// ForEachSameColoredRegstDescs.
//
// The lifetime callback given to the graph (GetProducerTaskId) reports exactly one id
// per descriptor: the descriptor's producer_task_id().
//
// NOTE(review): this text is a rendered commit diff without +/- markers. The paired
// std::list / std::vector lines look like the pre-change and post-change versions of
// the same line (list first, vector second) — confirm against the repository before
// treating this as compilable code.
void ForEachSameColoredStreamRegstDescWithoutConsumer(
const Plan& plan, const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
const Plan& plan,
const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) {
auto GetProducerTaskId = [](const RegstDescProto* regst_desc, HashSet<int64_t>* ret_actor_ids) {
// Every descriptor reaching this callback is required to have mem sharing enabled.
CHECK(regst_desc->enable_mem_sharing());
ret_actor_ids->insert(regst_desc->producer_task_id());
};
ForEachSharableStreamRegstDescsWithoutConsumer(
plan, [&](const std::list<const RegstDescProto*>& regst_descs) {
plan, [&](const std::vector<const RegstDescProto*>& regst_descs) {
// The graph is a temporary that lives only for this one call.
RegstLifetimeGraph(regst_descs, GetProducerTaskId).ForEachSameColoredRegstDescs(Handler);
});
}
void ForEachSameColoredChainRegstDescWithConsumer(
const PlanTaskGraph& plan_task_graph,
const std::function<void(const std::list<const RegstDescProto*>&)>& Handler) {
const std::function<void(const std::vector<const RegstDescProto*>&)>& Handler) {
// construct SharableMemBlockGraph
auto ChainId4TaskId = [&](int64_t task_id) {
return plan_task_graph.TaskProto4TaskId(task_id)->task_set_info().chain_id();
......@@ -92,7 +94,7 @@ void ForEachSameColoredChainRegstDescWithConsumer(
header2members.emplace(regst_descs.at(0), regst_descs);
}
auto GetRegstDescs = [&](const std::vector<const SharableMemBlockNode*>& sharable_mem_blocks) {
std::list<const RegstDescProto*> ret;
std::vector<const RegstDescProto*> ret;
for (const SharableMemBlockNode* sharable_mem_block : sharable_mem_blocks) {
for (const RegstDescProto* regst_desc : sharable_mem_block->regst_descs()) {
if (header2members.find(regst_desc) != header2members.end()) {
......@@ -111,8 +113,8 @@ void ForEachSameColoredChainRegstDescWithConsumer(
plan_task_graph.ComputeLifetimeSameChainActorIds(member, ret_actor_ids);
}
};
auto AppendGroupMembers = [&](const std::list<const RegstDescProto*>& regst_descs) {
std::list<const RegstDescProto*> members;
auto AppendGroupMembers = [&](const std::vector<const RegstDescProto*>& regst_descs) {
std::vector<const RegstDescProto*> members;
for (const auto* header : regst_descs) {
for (const auto* member : header2members.at(header)) { members.push_back(member); }
}
......@@ -121,6 +123,11 @@ void ForEachSameColoredChainRegstDescWithConsumer(
sharable_mem_block_gph.ForEachSourceNodeGroup(
&SharableMemBlockNode::chain_id,
[&](const std::vector<const SharableMemBlockNode*>& sharable_mem_blocks) {
if (sharable_mem_blocks.size() == 1) {
const auto& regst_descs = sharable_mem_blocks.at(0)->regst_descs();
if (regst_descs.size() > 1) { Handler(regst_descs); }
return;
}
RegstLifetimeGraph(GetRegstDescs(sharable_mem_blocks), ComputeLifetimeSameChainActorIds)
.ForEachSameColoredRegstDescs(AppendGroupMembers);
});
......@@ -128,9 +135,8 @@ void ForEachSameColoredChainRegstDescWithConsumer(
void ForEachImprovedMemSharedId(const PlanTaskGraph& plan_task_graph,
const std::function<void(int64_t, int64_t)>& Handler) {
using RegstDescs = std::list<const RegstDescProto*>;
const Plan& plan = plan_task_graph.plan();
auto HandleMemSharedId = [&](const RegstDescs& regst_descs) {
auto HandleMemSharedId = [&](const std::vector<const RegstDescProto*>& regst_descs) {
int64_t mem_shared_id = Global<IDMgr>::Get()->NewMemSharedId();
for (const RegstDescProto* regst_desc : regst_descs) {
Handler(regst_desc->regst_desc_id(), mem_shared_id);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册