提交 da518775 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4813 parallel control stream

Merge pull request !4813 from gukecai/parallel-ctrl
......@@ -211,8 +211,11 @@ bool CommunicationOpFusion::DoFusion(const FuncGraphPtr &func_graph, const Commu
start_index = end_index + 1;
continue;
}
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
auto graph_id = kernel_graph->graph_id();
AnfNodePtr new_communication_op =
CreateFusedCommunicationOp(func_graph, communication_op_info, start_index, end_index);
AnfAlgo::SetGraphId(graph_id, new_communication_op.get());
// replace old communication op with new communication op
for (auto idx = start_index; idx <= end_index; ++idx) {
std::vector<AnfNodePtr> tuple_getitem_input;
......
......@@ -123,11 +123,18 @@ class AscendStreamAssign {
void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr);
void AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr);
void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr);
void AssignHcomStreamId(const CNodePtr &cur_cnode_ptr);
void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr);
uint32_t AssignHcomStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph);
uint32_t AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, bool new_graph);
void UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr);
void FindHcomParallelStreams(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertStreamActiveForCommon(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertStreamActiveForIndependent(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertStreamActiveForParallel(const NotNull<KernelGraphPtr> &graph_ptr);
void ActiveRootGraphHcom(const NotNull<KernelGraphPtr> &graph_ptr, const std::set<uint32_t> &hcom_streams);
void ActiveRootGraphIndependent(const NotNull<KernelGraphPtr> &graph_ptr, std::set<uint32_t> independent_streams);
void ActiveOtherGraphParallel(const NotNull<KernelGraphPtr> &graph_ptr,
std::map<uint32_t, std::set<uint32_t>> other_graph);
void UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr,
vector<CNodePtr> *orders);
void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr);
......@@ -135,9 +142,11 @@ class AscendStreamAssign {
void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventHcomDependCommonBak(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr, const map<uint32_t, vector<size_t>> &hcom_index,
uint32_t first_hcom_stream, uint32_t last_hcom_stream);
CNodePtr GetLastInputCnode(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &cur_cnode_ptr);
bool IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>> &hcom_index, const CNodePtr &node_ptr, size_t index);
void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr);
......@@ -155,6 +164,7 @@ class AscendStreamAssign {
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
const CNodePtr &node);
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
void SetLoopSink();
// functions for memory reuse
void GetStreamRelations();
......@@ -172,17 +182,23 @@ class AscendStreamAssign {
bool independent_stream_activated_{false};
bool hcom_stream_activated_{false};
bool loop_sink_{false};
// key: stream id, value: number of tasks
std::map<uint32_t, uint32_t> independent_stream_map_{};
std::map<uint32_t, uint32_t> hcom_stream_map_{};
std::map<uint32_t, uint32_t> common_stream_map_{};
std::set<uint32_t> processed_streams_{};
std::vector<uint32_t> need_first_active_streams_{};
std::set<CNodeKey> independent_targets_;
// key:graph id, value:stream set
std::map<uint32_t, std::set<uint32_t>> hcom_graph_map_;
std::map<uint32_t, std::set<uint32_t>> independent_graph_map_;
// attr for memory copy reuse
std::map<uint32_t, std::vector<uint32_t>> stream_relations_{};
std::vector<std::vector<uint32_t>> stream_groups_{};
std::map<CNodePtr, CNodePtr> event_map_;
std::map<CNodePtr, CNodePtr> event_map_{};
std::set<uint32_t> middle_active_streams_{};
// new policy end
};
} // namespace ascend
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册