提交 044b2146 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2490 assign hcom op to one stream

Merge pull request !2490 from gukecai/new-stream-for-commit
......@@ -426,23 +426,25 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
return true;
}
AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance();
AscendStreamMng &stream_manager = AscendStreamMng::GetInstance();
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance();
// the streams' flag not HEAD_STREAM
std::vector<uint32_t> wait_active_stream_list;
assign_instance.GetWaitStreams(&wait_active_stream_list);
std::vector<uint32_t> force_copy_stream_list;
assign_instance.GetHcomStreams(&force_copy_stream_list);
MS_LOG(INFO) << "call DavinciModel total stream num:" << stream_manager.GetCurAllocStreamNum()
<< ", total event num:" << assign_instance.total_event_num()
MS_LOG(INFO) << "call DavinciModel total stream num:" << resource_manager.get_cur_stream_num()
<< ", total event num:" << resource_manager.get_cur_event_num()
<< ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph))
<< ", wait_active_stream_list size:" << wait_active_stream_list.size()
<< ", force_copy_stream_list size:" << force_copy_stream_list.size();
std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list;
std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>(
task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0,
0, 0, 0, 0, 0, stream_manager.GetCurAllocStreamNum(), label_assign_instance.GetLabelNum(NOT_NULL(graph)),
assign_instance.total_event_num(), 0);
0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), label_assign_instance.GetLabelNum(NOT_NULL(graph)),
resource_manager.get_cur_event_num(), 0);
auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model));
if (!ret.second) {
MS_LOG(EXCEPTION) << "Duplicate GraphId! Please check in ascend_session.";
......
......@@ -29,6 +29,7 @@
#include "runtime/rt_model.h"
#include "runtime/stream.h"
#include "session/kernel_graph.h"
#include "utils/contract.h"
namespace mindspore {
namespace device {
......@@ -38,35 +39,59 @@ using std::shared_ptr;
using std::unordered_map;
using std::unordered_set;
using std::vector;
using CnodeKey = void *;
const uint32_t kInvalidStreamId = UINT32_MAX;
class AscendStreamMng {
const uint32_t kInvalidEventId = UINT32_MAX;
class AscendResourceMng {
public:
static AscendStreamMng &GetInstance() {
static AscendStreamMng instance;
static AscendResourceMng &GetInstance() {
static AscendResourceMng instance;
return instance;
}
void Reset() {
cur_stream_id = 0;
cur_stream_num = 0;
void ResetResource() {
cur_stream_num_ = 0;
cur_event_num_ = 0;
}
uint32_t ApplyNewStream() {
if (!cur_stream_num) {
cur_stream_num++;
if (!cur_stream_num_) {
uint32_t cur_stream_id = cur_stream_num_;
cur_stream_num_++;
return cur_stream_id;
}
cur_stream_num++;
cur_stream_id++;
uint32_t cur_stream_id = cur_stream_num_;
cur_stream_num_++;
return cur_stream_id;
}
uint32_t ApplyNewEvent() {
if (!cur_event_num_) {
uint32_t cur_event_id = cur_event_num_;
cur_event_num_++;
return cur_event_id;
}
uint32_t cur_event_id = cur_event_num_;
cur_event_num_++;
return cur_event_id;
}
uint32_t GetCurAllocStream() { return cur_stream_id; }
uint32_t GetCurAllocStreamNum() { return cur_stream_num; }
void DeleteEvent() {
if (!cur_event_num_) {
MS_LOG(WARNING) << "total event num is 0, no event to delete";
} else {
--cur_event_num_;
}
}
uint32_t get_cur_stream_num() { return cur_stream_num_; }
uint32_t GetCurAllocStreamId() {
if (!cur_stream_num_) {
MS_LOG(EXCEPTION) << "stream nums is 0, no stream id should be get";
}
return cur_stream_num_ - 1;
}
uint32_t get_cur_event_num() { return cur_event_num_; }
private:
uint32_t cur_stream_num{0};
uint32_t cur_stream_id{0};
uint32_t cur_stream_num_{0};
uint32_t cur_event_num_{0};
};
class AscendStreamAssign {
......@@ -79,39 +104,42 @@ class AscendStreamAssign {
AscendStreamAssign(const AscendStreamAssign &) = delete;
AscendStreamAssign &operator=(const AscendStreamAssign &) = delete;
uint32_t total_event_num() const { return total_event_num_; }
void AssignStream(const NotNull<KernelGraphPtr> &graph_ptr);
void GetHcomStreams(std::vector<uint32_t> *streams);
void AssignStream(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void GetWaitStreams(vector<uint32_t> *wait_active_stream_list);
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
uint32_t stream_id);
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id,
uint32_t stream_id);
CNodePtr CreateSendApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
CNodePtr CreateRecvApplyKernel(const NotNull<KernelGraphPtr> &graph_ptr, uint32_t event_id, uint32_t stream_id);
private:
AscendStreamAssign() = default;
~AscendStreamAssign() = default;
void Reset();
void CheckStreamAssign(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index,
uint32_t *cur_stream_id);
void CheckResourceAssign(const NotNull<KernelGraphPtr> &graph_ptr);
void CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ptr);
void CheckEventAssign(const NotNull<KernelGraphPtr> &graph_ptr);
void AssignAllNodesStream(const NotNull<KernelGraphPtr> &graph_ptr);
void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr);
void AssignHcomStreamId(const CNodePtr &cur_cnode_ptr);
void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr);
void UpdateAtomicAddrCleanStreamId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void FindHcomParallelStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void InsertStreamActive(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void UpdateStreamSwitch(const std::shared_ptr<session::KernelGraph> &graph_ptr, const CNodePtr &switch_ptr,
const vector<uint32_t> &independent_stream, vector<CNodePtr> *orders);
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void InsertSendRecvForDiffHcom(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr);
void UpdateEventId(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void ReorderIndependentOrders(const std::shared_ptr<session::KernelGraph> &graph_ptr);
void UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr);
void FindHcomParallelStreams(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr);
void UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph_ptr, const CNodePtr &switch_ptr,
vector<CNodePtr> *orders);
void InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr);
void InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr, const map<uint32_t, vector<size_t>> &hcom_index,
uint32_t first_hcom_stream, uint32_t last_hcom_stream);
bool IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>> &hcom_index, const CNodePtr &node_ptr, size_t index);
void GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr);
void GetNeedActiveStreams(const NotNull<KernelGraphPtr> &graph_ptr);
void ReorderIndependentOrders(const NotNull<KernelGraphPtr> &graph_ptr);
bool IsTaskSink();
bool IsFusionHcom(const CNodePtr &cur_cnode_ptr);
bool IsHcom(const CNodePtr &cur_cnode_ptr);
bool IsIndependentNode(const CNodePtr &node_ptr);
bool IsProcessedStream(uint32_t stream_id);
......@@ -119,14 +147,13 @@ class AscendStreamAssign {
const CNodePtr &node);
void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t> *parallel_streams);
uint32_t total_event_num_{0};
bool independent_stream_activated_{false};
bool hcom_stream_activated_{false};
std::map<uint32_t, uint32_t> independent_stream_map_{};
std::map<uint32_t, uint32_t> hcom_stream_map_{};
std::map<uint32_t, uint32_t> common_stream_map_{};
std::set<uint32_t> processed_streams_{};
std::set<uint32_t> hcom_stream_list_{};
std::vector<uint32_t> need_first_active_streams_{};
std::vector<std::vector<uint32_t>> inner_parallel_streams_{};
// new policy end
};
} // namespace ascend
......
......@@ -103,8 +103,8 @@ CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::Kern
}
void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
device::ascend::AscendStreamMng &stream_manager = device::ascend::AscendStreamMng::GetInstance();
stream_manager.Reset();
device::ascend::AscendResourceMng &resource_manager = device::ascend::AscendResourceMng::GetInstance();
resource_manager.ResetResource();
if (!NeedInsertSwitch()) {
return;
}
......@@ -135,17 +135,16 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
}
std::vector<CNodePtr> exec_order;
// getnext loop process
// getnext loop stream switch op
CNodePtr getnext_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
MS_EXCEPTION_IF_NULL(getnext_switch_app);
uint32_t getnext_switch_stream_id = stream_manager.ApplyNewStream();
uint32_t getnext_switch_stream_id = resource_manager.ApplyNewStream();
AnfAlgo::SetStreamId(getnext_switch_stream_id, getnext_switch_app.get());
exec_order.push_back(getnext_switch_app);
// getnext op
uint32_t getnext_stream_id = stream_manager.ApplyNewStream();
uint32_t getnext_stream_id = resource_manager.ApplyNewStream();
size_t i = 0;
for (; i < orders.size(); i++) {
auto node = orders[i];
......@@ -160,7 +159,8 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(getnext_stream_id), getnext_switch_app);
// getnext loop send
CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId);
uint32_t getnext_event_id = resource_manager.ApplyNewEvent();
CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, getnext_event_id);
AnfAlgo::SetStreamId(getnext_stream_id, send.get());
exec_order.push_back(send);
......@@ -168,14 +168,14 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
// fpbp loop stream switch
CNodePtr fpbp_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
MS_EXCEPTION_IF_NULL(fpbp_switch_app);
uint32_t fpbp_switch_stream_id = stream_manager.ApplyNewStream();
uint32_t fpbp_switch_stream_id = resource_manager.ApplyNewStream();
AnfAlgo::SetStreamId(fpbp_switch_stream_id, fpbp_switch_app.get());
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), fpbp_switch_app);
exec_order.push_back(fpbp_switch_app);
// fpbp loop recv
CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId);
uint32_t fpbp_stream_id = stream_manager.ApplyNewStream();
CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, getnext_event_id);
uint32_t fpbp_stream_id = resource_manager.ApplyNewStream();
AnfAlgo::SetStreamId(fpbp_stream_id, recv.get());
exec_order.push_back(recv);
......
......@@ -38,12 +38,8 @@ constexpr auto kIterLoopParamName = "iter_loop";
constexpr auto kZeroParamName = "zero";
constexpr auto kOneParamName = "one";
constexpr auto kStreamNeedActivedFirst = "stream_need_active_first";
constexpr uint32_t kSecondStreamSwitchLabel = 2;
const uint32_t kFirstStreamSwitchLabel = 0;
const uint32_t kGetNextLabel = 1;
const uint32_t kSecondStreamSwitchLabel = 2;
const uint32_t kInvalidEventId = UINT32_MAX;
const uint32_t kFirstEventId = kInvalidEventId / 2;
namespace device {
class KernelAdjust {
public:
......
......@@ -305,7 +305,7 @@ GraphId AscendSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
// adjust kernel
AdjustKernel(root_graph);
// assign stream
AssignStream(root_graph);
AssignStream(NOT_NULL(root_graph));
// insert profiling point
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(root_graph.get()));
// build kernel
......@@ -377,7 +377,7 @@ void AscendSession::BuildGraph(GraphId graph_id) {
// adjust execution order because merge child graph and other special operations
AdjustKernel(graph);
// Assign streams for control sink and hccl and so on
AssignStream(graph);
AssignStream(NOT_NULL(graph));
device::KernelAdjust::GetInstance().Profiling(NOT_NULL(graph.get()));
// build kernel if node is cnode
......@@ -647,7 +647,7 @@ void AscendSession::RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel
MS_LOG(INFO) << "Finish!";
}
void AscendSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const {
void AscendSession::AssignStream(NotNull<KernelGraphPtr> kernel_graph) const {
MS_LOG(INFO) << "Start!";
device::ascend::AscendStreamAssign::GetInstance().AssignStream(kernel_graph);
MS_LOG(INFO) << "Finish!";
......
......@@ -76,7 +76,7 @@ class AscendSession : public SessionBasic {
void HardwareOptimize(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void AdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void RunOpAdjustKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void AssignStream(NotNull<KernelGraphPtr> kernel_graph) const;
void AssignLabel(NotNull<KernelGraphPtr> kernel_graph) const;
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void MemoryAlloc(KernelGraph *kernel_graph) const;
......
......@@ -26,7 +26,7 @@ void AscendLabelAssign::AssignLabel(NotNull<std::shared_ptr<session::KernelGraph
uint32_t AscendLabelAssign::GetLabelNum(NotNull<const session::KernelGraph *> graph) { return 1; }
uint32_t AscendLabelAssign::GetLabelNum(NotNull<std::shared_ptr<session::KernelGraph>> graph) { return 1; }
void AscendStreamAssign::AssignStream(const KernelGraphPtr &graph) { return; }
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) { return; }
void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) { return; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册