提交 55b46427 编写于 作者: S strickland12 提交者: Jinhui Yuan

multi thread build chain_act_sub_graph (#1155)

上级 400e277e
#include "oneflow/core/graph/chain_act_graph.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/thread/thread_pool.h"
#include "oneflow/core/common/blocking_counter.h"
namespace oneflow {
......@@ -61,72 +63,7 @@ void ChainActNode::AddProducedRegstAct(std::unique_ptr<RegstAct>&& regst_act) {
produced_regst_acts_.push_back(std::move(regst_act));
}
// Reports the mean path duration for every (regst_desc_id, consumer_actor_id) pair.
//
// All samples emitted by ForEachRegstActConsumerPathDuration are summed and counted per
// pair; Handler is then invoked once per pair with (regst_desc_id, consumer_actor_id,
// sum / sample_count). Pairs are visited in the key order of the std::map accumulator.
void ChainActGraph::ForEachRegstDescConsumerPathMeanDuration(
    const std::function<void(int64_t, int64_t, double)>& Handler) const {
  using RegstConsumerKey = std::pair<int64_t, int64_t>;
  std::map<RegstConsumerKey, double> key2duration_sum;
  std::map<RegstConsumerKey, int> key2sample_cnt;
  auto Accumulate = [&](int64_t regst_desc_id, int64_t consumer_actor_id, double duration) {
    const RegstConsumerKey key(regst_desc_id, consumer_actor_id);
    key2duration_sum[key] += duration;
    key2sample_cnt[key] += 1;
  };
  ForEachRegstActConsumerPathDuration(Accumulate);
  for (auto it = key2duration_sum.begin(); it != key2duration_sum.end(); ++it) {
    const RegstConsumerKey& key = it->first;
    // Every key in the sum map was counted at least once, so at() cannot miss.
    const double mean_duration = it->second / key2sample_cnt.at(key);
    Handler(key.first, key.second, mean_duration);
  }
}
// Reports an "II scale" factor for every (regst_desc_id, consumer_actor_id) pair.
//
// While walking each node's last-consumed regst acts, two counters are maintained:
//   - produced_cnt: how many regst acts were visited per regst_desc_id;
//   - used_cnt: how many consumer act events were seen per (regst_desc_id, actor_id).
// max_cnt is the largest value any of those counters reached. Handler is invoked once per
// pair with (regst_desc_id, consumer_actor_id, max_cnt / min(produced_cnt, used_cnt)),
// in the key order of the std::map accumulator.
void ChainActGraph::ForEachRegstDescConsumerPathIIScale(
    const std::function<void(int64_t, int64_t, double)>& Handler) const {
  std::map<std::pair<int64_t, int64_t>, uint64_t> regst_desc_id_consumed2used_cnt;
  std::map<int64_t, uint64_t> regst_desc_id2produced_cnt;
  uint64_t max_cnt = 0;
  ForEachNode([&](const ChainActNode* node) {
    node->ForEachLastConsumedRegstAct([&](const RegstAct* regst_act) {
      // Counters are kept unsigned end to end so that the comparison with max_cnt does not
      // mix signed and unsigned operands.
      const uint64_t produced_cnt = ++regst_desc_id2produced_cnt[regst_act->regst_desc_id];
      max_cnt = std::max(max_cnt, produced_cnt);
      for (const ActEvent* act_event : regst_act->consumer_act_events) {
        const std::pair<int64_t, int64_t> regst_desc_id_consumed(regst_act->regst_desc_id,
                                                                 act_event->actor_id());
        const uint64_t used_cnt = ++regst_desc_id_consumed2used_cnt[regst_desc_id_consumed];
        max_cnt = std::max(max_cnt, used_cnt);
      }
    });
  });
  for (const auto& pair : regst_desc_id_consumed2used_cnt) {
    // A used_cnt entry is only created after the matching produced_cnt entry was
    // incremented, so both operands of std::min are >= 1 and the division is safe.
    const uint64_t produced_cnt = regst_desc_id2produced_cnt.at(pair.first.first);
    Handler(pair.first.first, pair.first.second,
            1.0 * max_cnt / std::min(produced_cnt, pair.second));
  }
}
double ChainActGraph::CalcBaseII() const {
int64_t max_act_cnt = 0;
HashMap<int64_t, int64_t> actor_id2outputed_act_cnt;
ForEachActEvent([&](const ActEvent* act_event) {
int64_t actor_id = act_event->actor_id();
if (IsActEventWithConsumer(act_event)) {
++actor_id2outputed_act_cnt[actor_id];
max_act_cnt = std::max(max_act_cnt, actor_id2outputed_act_cnt[actor_id]);
}
});
HashMap<int64_t, double> stream_id2total_calc_time;
ForEachActEvent([&](const ActEvent* act_event) {
int64_t actor_id = act_event->actor_id();
auto frequence_it = actor_id2outputed_act_cnt.find(actor_id);
if (frequence_it == actor_id2outputed_act_cnt.end()) { return; }
int64_t stream_id = act_event->work_stream_id();
stream_id2total_calc_time[stream_id] += Duration4ActEvent(*act_event);
});
double base_ii = 0;
for (const auto& pair : stream_id2total_calc_time) {
base_ii = std::max(base_ii, pair.second / max_act_cnt);
}
return base_ii;
}
void ChainActGraph::ForEachRegstActConsumerPathDuration(
void ChainActSubGraph::ForEachRegstActConsumerPathDuration(
const std::function<void(int64_t, int64_t, double)>& Handler) const {
HashSet<std::shared_ptr<RegstActGroupCtx>> ctx_window;
HashMap<const RegstAct*, std::shared_ptr<RegstActGroupCtx>> regst_act2ctx;
......@@ -157,8 +94,8 @@ void ChainActGraph::ForEachRegstActConsumerPathDuration(
});
}
void ChainActGraph::CalcRegstActNodePathDuration(RegstActGroupCtx* regst_act_group_ctx,
const ChainActNode* node) const {
void ChainActSubGraph::CalcRegstActNodePathDuration(RegstActGroupCtx* regst_act_group_ctx,
const ChainActNode* node) const {
double duration = 0;
node->ForEachInEdge([&](const ChainActEdge* in_edge) {
const ChainActNode* in_node = in_edge->src_node();
......@@ -172,13 +109,7 @@ void ChainActGraph::CalcRegstActNodePathDuration(RegstActGroupCtx* regst_act_gro
if (duration > 0) { regst_act_group_ctx->node2duration_to_producer[node] = duration; }
}
// Indexes every TaskProto of the plan by its task id, storing a pointer to the proto.
// CHECK-fails if two tasks share the same id.
void ChainActGraph::InitTaskId2TaskProto() {
  const auto& tasks = plan_->task();
  for (auto it = tasks.begin(); it != tasks.end(); ++it) {
    const auto& task_proto = *it;
    CHECK(task_id2task_proto_.emplace(task_proto.task_id(), &task_proto).second);
  }
}
void ChainActGraph::InitNodes(
void ChainActSubGraph::InitNodes(
std::list<std::unique_ptr<ActEvent>>&& act_events,
HashMap<std::pair<int64_t, int64_t>, const ActEvent*>* regst_uid2producer_act_event) {
HashMap<std::pair<int64_t, int64_t>, std::list<std::unique_ptr<ActEvent>>>
......@@ -208,7 +139,7 @@ void ChainActGraph::InitNodes(
}
}
void ChainActGraph::InitEdges(
void ChainActSubGraph::InitEdges(
const HashMap<std::pair<int64_t, int64_t>, const ActEvent*>& regst_uid2producer_act_event,
HashMap<std::pair<int64_t, int64_t>, std::list<const ActEvent*>>*
regst_uid2consumer_act_events) {
......@@ -238,7 +169,7 @@ void ChainActGraph::InitEdges(
});
}
void ChainActGraph::InitNodeProducedRegstAct(
void ChainActSubGraph::InitNodeProducedRegstAct(
const HashMap<std::pair<int64_t, int64_t>, const ActEvent*>& regst_uid2producer_act_event,
const HashMap<std::pair<int64_t, int64_t>, std::list<const ActEvent*>>&
regst_uid2consumer_act_events) const {
......@@ -257,7 +188,7 @@ void ChainActGraph::InitNodeProducedRegstAct(
}
}
void ChainActGraph::InitNodeLastConsumedRegstActGroup() const {
void ChainActSubGraph::InitNodeLastConsumedRegstActGroup() const {
auto TopoOrderValue4Node = MakeGetterTopoOrderValue4Node();
auto ForEachConsumer = [](const std::list<const RegstAct*>& regst_act_group,
const std::function<void(const ActEvent*)>& Handler) {
......@@ -281,7 +212,8 @@ void ChainActGraph::InitNodeLastConsumedRegstActGroup() const {
});
}
std::function<int64_t(const ChainActNode*)> ChainActGraph::MakeGetterTopoOrderValue4Node() const {
std::function<int64_t(const ChainActNode*)> ChainActSubGraph::MakeGetterTopoOrderValue4Node()
const {
auto node2topo_order_value = std::make_shared<HashMap<const ChainActNode*, int64_t>>();
int64_t topo_order_value = -1;
TopoForEachChainActNode([&](const ChainActNode* chain_act_node) {
......@@ -291,11 +223,11 @@ std::function<int64_t(const ChainActNode*)> ChainActGraph::MakeGetterTopoOrderVa
[node2topo_order_value](const ChainActNode* node) { return node2topo_order_value->at(node); };
}
const ChainActNode* ChainActGraph::Node4ActEvent(const ActEvent* act_event) const {
// Returns the chain node registered for the given act event.
// The lookup goes through at(), so the event is expected to be present in
// act_event2chain_node_.
const ChainActNode* ChainActSubGraph::Node4ActEvent(const ActEvent* act_event) const {
  const ChainActNode* chain_node = act_event2chain_node_.at(act_event);
  return chain_node;
}
void ChainActGraph::TopoForEachChainActNode(
void ChainActSubGraph::TopoForEachChainActNode(
const std::function<void(const ChainActNode*)>& Handler) const {
std::list<const ChainActNode*> starts;
ForEachNode([&](const ChainActNode* node) {
......@@ -308,19 +240,19 @@ void ChainActGraph::TopoForEachChainActNode(
&ChainActNode::ForEachNodeOnOutEdge, Handler);
}
void ChainActGraph::ForEachActEvent(const std::function<void(const ActEvent*)>& Handler) const {
void ChainActSubGraph::ForEachActEvent(const std::function<void(const ActEvent*)>& Handler) const {
ForEachNode([&](const ChainActNode* node) { node->ForEachActEvent(Handler); });
}
bool ChainActGraph::IsActEventWithConsumer(const ActEvent* act_event) const {
// Tells whether the act event is recorded in act_event_with_consumer_.
bool ChainActSubGraph::IsActEventWithConsumer(const ActEvent* act_event) const {
  const auto found_it = act_event_with_consumer_.find(act_event);
  const bool is_recorded = (found_it != act_event_with_consumer_.end());
  return is_recorded;
}
ChainActGraph::ChainActGraph(const Plan& plan, std::list<std::unique_ptr<ActEvent>>&& act_events)
: plan_(&plan) {
ChainActSubGraph::ChainActSubGraph(const HashMap<int64_t, const TaskProto&>& task_id2task_proto,
std::list<std::unique_ptr<ActEvent>>&& act_events)
: task_id2task_proto_(task_id2task_proto) {
HashMap<std::pair<int64_t, int64_t>, const ActEvent*> regst_uid2producer_act_event;
HashMap<std::pair<int64_t, int64_t>, std::list<const ActEvent*>> regst_uid2consumer_act_events;
InitTaskId2TaskProto();
InitNodes(std::move(act_events), &regst_uid2producer_act_event);
InitEdges(regst_uid2producer_act_event, &regst_uid2consumer_act_events);
InitNodeProducedRegstAct(regst_uid2producer_act_event, regst_uid2consumer_act_events);
......@@ -328,4 +260,127 @@ ChainActGraph::ChainActGraph(const Plan& plan, std::list<std::unique_ptr<ActEven
// ToDotWithAutoFilePath();
}
void ChainActGraph::ForEachChainActSubGraph(
const std::function<void(const ChainActSubGraph*)>& Handler) const {
for (auto& sub_graph : sub_graphs_) { Handler(sub_graph.get()); }
}
// Reports an "II scale" factor for every (regst_desc_id, consumer_actor_id) pair,
// aggregated over all sub graphs.
//
// While walking each node's last-consumed regst acts, two counters are maintained:
//   - produced_cnt: how many regst acts were visited per regst_desc_id;
//   - used_cnt: how many consumer act events were seen per (regst_desc_id, actor_id).
// max_cnt is the largest value any of those counters reached. Handler is invoked once per
// pair with (regst_desc_id, consumer_actor_id, max_cnt / min(produced_cnt, used_cnt)),
// in the key order of the std::map accumulator.
void ChainActGraph::ForEachRegstDescConsumerPathIIScale(
    const std::function<void(int64_t, int64_t, double)>& Handler) const {
  std::map<std::pair<int64_t, int64_t>, uint64_t> regst_desc_id_consumed2used_cnt;
  std::map<int64_t, uint64_t> regst_desc_id2produced_cnt;
  uint64_t max_cnt = 0;
  ForEachChainActSubGraph([&](const ChainActSubGraph* sub_graph) {
    sub_graph->ForEachNode([&](const ChainActNode* node) {
      node->ForEachLastConsumedRegstAct([&](const RegstAct* regst_act) {
        // Counters are kept unsigned end to end so that the comparison with max_cnt does
        // not mix signed and unsigned operands.
        const uint64_t produced_cnt = ++regst_desc_id2produced_cnt[regst_act->regst_desc_id];
        max_cnt = std::max(max_cnt, produced_cnt);
        for (const ActEvent* act_event : regst_act->consumer_act_events) {
          const std::pair<int64_t, int64_t> regst_desc_id_consumed(regst_act->regst_desc_id,
                                                                   act_event->actor_id());
          const uint64_t used_cnt = ++regst_desc_id_consumed2used_cnt[regst_desc_id_consumed];
          max_cnt = std::max(max_cnt, used_cnt);
        }
      });
    });
  });
  for (const auto& pair : regst_desc_id_consumed2used_cnt) {
    // A used_cnt entry is only created after the matching produced_cnt entry was
    // incremented, so both operands of std::min are >= 1 and the division is safe.
    const uint64_t produced_cnt = regst_desc_id2produced_cnt.at(pair.first.first);
    Handler(pair.first.first, pair.first.second,
            1.0 * max_cnt / std::min(produced_cnt, pair.second));
  }
}
// Reports the mean path duration for every (regst_desc_id, consumer_actor_id) pair,
// aggregated over all sub graphs.
//
// Every sample emitted by each sub graph's ForEachRegstActConsumerPathDuration is summed
// and counted per pair; Handler is then invoked once per pair with
// (regst_desc_id, consumer_actor_id, sum / sample_count), in the key order of the
// std::map accumulator.
void ChainActGraph::ForEachRegstDescConsumerPathMeanDuration(
    const std::function<void(int64_t, int64_t, double)>& Handler) const {
  using RegstConsumerKey = std::pair<int64_t, int64_t>;
  std::map<RegstConsumerKey, double> key2duration_sum;
  std::map<RegstConsumerKey, int> key2sample_cnt;
  auto Accumulate = [&](int64_t regst_desc_id, int64_t consumer_actor_id, double duration) {
    const RegstConsumerKey key(regst_desc_id, consumer_actor_id);
    key2duration_sum[key] += duration;
    key2sample_cnt[key] += 1;
  };
  ForEachChainActSubGraph([&](const ChainActSubGraph* sub_graph) {
    sub_graph->ForEachRegstActConsumerPathDuration(Accumulate);
  });
  for (auto it = key2duration_sum.begin(); it != key2duration_sum.end(); ++it) {
    const RegstConsumerKey& key = it->first;
    // Every key in the sum map was counted at least once, so at() cannot miss.
    const double mean_duration = it->second / key2sample_cnt.at(key);
    Handler(key.first, key.second, mean_duration);
  }
}
double ChainActGraph::CalcBaseII() const {
int64_t max_act_cnt = 0;
HashMap<int64_t, int64_t> actor_id2outputed_act_cnt;
ForEachChainActSubGraph([&](const ChainActSubGraph* sub_graph) {
sub_graph->ForEachActEvent([&](const ActEvent* act_event) {
int64_t actor_id = act_event->actor_id();
if (sub_graph->IsActEventWithConsumer(act_event)) {
++actor_id2outputed_act_cnt[actor_id];
max_act_cnt = std::max(max_act_cnt, actor_id2outputed_act_cnt[actor_id]);
}
});
});
HashMap<int64_t, double> stream_id2total_calc_time;
ForEachChainActSubGraph([&](const ChainActSubGraph* sub_graph) {
sub_graph->ForEachActEvent([&](const ActEvent* act_event) {
int64_t actor_id = act_event->actor_id();
auto frequence_it = actor_id2outputed_act_cnt.find(actor_id);
if (frequence_it == actor_id2outputed_act_cnt.end()) { return; }
int64_t stream_id = act_event->work_stream_id();
stream_id2total_calc_time[stream_id] += Duration4ActEvent(*act_event);
});
});
double base_ii = 0;
for (const auto& pair : stream_id2total_calc_time) {
base_ii = std::max(base_ii, pair.second / max_act_cnt);
}
return base_ii;
}
void ChainActGraph::InitTaskId2TaskProto() {
for (const TaskProto& task_proto : plan_->task()) {
CHECK(task_id2task_proto_.emplace(task_proto.task_id(), task_proto).second);
}
}
void ChainActGraph::GroupActEventByActId(
std::list<std::unique_ptr<ActEvent>>&& act_events,
HashMap<int64_t, std::list<std::unique_ptr<ActEvent>>>* act_id2act_event_group) const {
// not considering RNN yet
for (auto& act_event : act_events) {
(*act_id2act_event_group)[act_event->act_id()].push_back(std::move(act_event));
}
}
// Builds one ChainActSubGraph per act-event group on a thread pool and appends the
// results to sub_graphs_.
//
// act_id2act_event_group: groups keyed by act_id; the event list of every group is moved
//   into the sub graph built from it, so the lists are left moved-from on return.
//
// The call blocks until every sub graph has been built. Sub graphs are appended as the
// workers finish, so the order of sub_graphs_ follows completion order rather than the
// iteration order of the map.
void ChainActGraph::MultiThreadBuildChainActSubGraph(
    HashMap<int64_t, std::list<std::unique_ptr<ActEvent>>>* act_id2act_event_group) {
  const int64_t sub_graph_num = act_id2act_event_group->size();
  // Nothing to build: avoid creating a zero-thread pool and a zero-count barrier.
  if (sub_graph_num == 0) { return; }
  // hardware_concurrency() is allowed to return 0 when the value is not computable;
  // clamp to 1 so the pool always has a worker, otherwise the wait below never returns.
  const int64_t cpu_num = std::max<int64_t>(1, std::thread::hardware_concurrency());
  const int64_t thread_pool_size = std::min(sub_graph_num, cpu_num);
  BlockingCounter counter(sub_graph_num);
  std::mutex sub_graph_mtx;
  ThreadPool thread_pool(thread_pool_size);
  for (auto& pair : *act_id2act_event_group) {
    // Capture a pointer to this group's event list by value so the task does not depend
    // on the loop variable, which is rebound on every iteration.
    std::list<std::unique_ptr<ActEvent>>* act_event_group = &pair.second;
    thread_pool.AddWork([this, act_event_group, &counter, &sub_graph_mtx]() {
      // Construct outside the lock; only the list append is serialized.
      auto sub_graph =
          std::make_unique<ChainActSubGraph>(task_id2task_proto_, std::move(*act_event_group));
      {
        std::unique_lock<std::mutex> guard(sub_graph_mtx);
        sub_graphs_.push_back(std::move(sub_graph));
      }
      counter.Decrease();
    });
  }
  // counter, sub_graph_mtx and the event groups are referenced by the tasks, so they must
  // stay alive until every task has signalled completion.
  counter.WaitUntilCntEqualZero();
}
// Builds the graph from a plan and the recorded act events.
//
// plan: only its address is stored; the task index built here refers back into it.
// act_events: ownership of the events is taken; they are grouped by act_id and each group
//   becomes one ChainActSubGraph, built on a thread pool.
ChainActGraph::ChainActGraph(const Plan& plan, std::list<std::unique_ptr<ActEvent>>&& act_events)
    : plan_(&plan) {
  // The task index must exist before any sub graph is constructed, because every sub
  // graph receives task_id2task_proto_.
  InitTaskId2TaskProto();
  HashMap<int64_t, std::list<std::unique_ptr<ActEvent>>> act_events_by_act_id;
  GroupActEventByActId(std::move(act_events), &act_events_by_act_id);
  MultiThreadBuildChainActSubGraph(&act_events_by_act_id);
}
} // namespace oneflow
......@@ -89,33 +89,27 @@ class ChainActNode final : public Node<ChainActNode, ChainActEdge> {
std::list<std::list<const RegstAct*>> last_consumed_regst_act_groups_;
};
class ChainActGraph final : public Graph<const ChainActNode, const ChainActEdge> {
class ChainActSubGraph final : public Graph<const ChainActNode, const ChainActEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(ChainActGraph);
ChainActGraph() = delete;
~ChainActGraph() = default;
ChainActGraph(const Plan& plan, std::list<std::unique_ptr<ActEvent>>&& act_events);
const char* TypeName() const override { return "ChainActGraph"; }
OF_DISALLOW_COPY_AND_MOVE(ChainActSubGraph);
ChainActSubGraph() = delete;
~ChainActSubGraph() = default;
ChainActSubGraph(const HashMap<int64_t, const TaskProto&>& task_id2task_proto,
std::list<std::unique_ptr<ActEvent>>&& act_events);
const char* TypeName() const override { return "ChainActSubGraph"; }
// Getters
const TaskProto& GetTaskProto(int64_t actor_id) const {
return *(task_id2task_proto_.at(actor_id));
}
const TaskProto& GetTaskProto(int64_t actor_id) const { return task_id2task_proto_.at(actor_id); }
bool IsActEventWithConsumer(const ActEvent* act_event) const;
// ForEach
void ForEachRegstDescConsumerPathMeanDuration(
const std::function<void(int64_t, int64_t, double)>& Handler) const;
void ForEachRegstDescConsumerPathIIScale(
void ForEachActEvent(const std::function<void(const ActEvent*)>& Handler) const;
void ForEachRegstActConsumerPathDuration(
const std::function<void(int64_t, int64_t, double)>& Handler) const;
double CalcBaseII() const;
private:
bool IsActEventWithConsumer(const ActEvent* act_event) const;
const ChainActNode* Node4ActEvent(const ActEvent* act_event) const;
void ForEachActEvent(const std::function<void(const ActEvent*)>& Handler) const;
std::function<int64_t(const ChainActNode*)> MakeGetterTopoOrderValue4Node() const;
const ChainActNode* Node4ActEvent(const ActEvent* act_event) const;
void InitNodes(
std::list<std::unique_ptr<ActEvent>>&& act_events,
HashMap<std::pair<int64_t, int64_t>, const ActEvent*>* regst_uid2producer_act_event);
......@@ -127,20 +121,48 @@ class ChainActGraph final : public Graph<const ChainActNode, const ChainActEdge>
const HashMap<std::pair<int64_t, int64_t>, const ActEvent*>& regst_uid2producer_act_event,
const HashMap<std::pair<int64_t, int64_t>, std::list<const ActEvent*>>&
regst_uid2consumer_act_events) const;
void InitTaskId2TaskProto();
void InitNodeLastConsumedRegstActGroup() const;
void TopoForEachChainActNode(const std::function<void(const ChainActNode*)>& Handler) const;
void ForEachRegstActConsumerPathDuration(
const std::function<void(int64_t, int64_t, double)>& Handler) const;
void CalcRegstActNodePathDuration(RegstActGroupCtx* regst_act_group_ctx,
const ChainActNode* node) const;
const Plan* plan_;
HashMap<int64_t, const TaskProto*> task_id2task_proto_;
HashSet<const ActEvent*> act_event_with_consumer_;
const HashMap<int64_t, const TaskProto&>& task_id2task_proto_;
HashMap<const ActEvent*, ChainActNode*> act_event2chain_node_;
};
// Aggregates a set of ChainActSubGraph instances, one per act_id group of the act events
// handed to the constructor, and exposes statistics computed across all of them.
class ChainActGraph final {
public:
OF_DISALLOW_COPY_AND_MOVE(ChainActGraph);
ChainActGraph() = delete;
~ChainActGraph() = default;
// Stores the address of `plan` and takes ownership of `act_events`; all sub graphs are
// built before the constructor returns.
// NOTE(review): since only a pointer/references into `plan` are kept, the plan presumably
// has to outlive this object - confirm against callers.
ChainActGraph(const Plan& plan, std::list<std::unique_ptr<ActEvent>>&& act_events);
// Getter
// Looks up the TaskProto in task_id2task_proto_ using actor_id as the key; the lookup
// goes through at(), so the id is expected to be present.
const TaskProto& GetTaskProto(int64_t actor_id) const { return task_id2task_proto_.at(actor_id); }
// ForEach
// Calls Handler(regst_desc_id, consumer_actor_id, mean_duration) once per pair.
void ForEachRegstDescConsumerPathMeanDuration(
const std::function<void(int64_t, int64_t, double)>& Handler) const;
// Calls Handler(regst_desc_id, consumer_actor_id, ii_scale) once per pair.
void ForEachRegstDescConsumerPathIIScale(
const std::function<void(int64_t, int64_t, double)>& Handler) const;
// Returns the largest per-stream (total act duration / max per-actor act count).
double CalcBaseII() const;
private:
// Fills task_id2task_proto_ from plan_->task().
void InitTaskId2TaskProto();
// Visits every element of sub_graphs_ in storage order.
void ForEachChainActSubGraph(const std::function<void(const ChainActSubGraph*)>& Handler) const;
// Moves each act event into the bucket of its act_id.
void GroupActEventByActId(
std::list<std::unique_ptr<ActEvent>>&& act_events,
HashMap<int64_t, std::list<std::unique_ptr<ActEvent>>>* act_id2act_event_group) const;
// Builds one sub graph per bucket on a thread pool and appends it to sub_graphs_.
void MultiThreadBuildChainActSubGraph(
HashMap<int64_t, std::list<std::unique_ptr<ActEvent>>>* act_id2act_event_group);
// Non-owning pointer to the plan passed to the constructor.
const Plan* plan_;
// task id -> reference to the corresponding TaskProto inside *plan_.
HashMap<int64_t, const TaskProto&> task_id2task_proto_;
// Owned sub graphs; appended by worker threads as they finish building.
std::list<std::unique_ptr<ChainActSubGraph>> sub_graphs_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_CHAIN_ACT_GRAPH_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册