提交 f8208c7c 编写于 作者: G gukecai

Support GetNext Parallel

上级 27a88a6b
...@@ -283,18 +283,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { ...@@ -283,18 +283,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance();
// the streams' flag not HEAD_STREAM // the streams' flag not HEAD_STREAM
std::vector<uint32_t> wait_active_stream_list = assign_instance.GetWaitStreams(); std::vector<uint32_t> wait_active_stream_list;
std::vector<uint32_t> force_copy_stream_list = assign_instance.GetHcomStreams(); assign_instance.GetWaitStreams(&wait_active_stream_list);
auto force_copy_stream_list = assign_instance.hcom_streams();
MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum() MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum()
<< ", total event num:" << assign_instance.GetTotalEventNum() << ", total event num:" << assign_instance.total_event_num()
<< ", wait_active_stream_list size:" << wait_active_stream_list.size() << ", wait_active_stream_list size:" << wait_active_stream_list.size()
<< ", force_copy_stream_list size:" << force_copy_stream_list.size(); << ", force_copy_stream_list size:" << force_copy_stream_list.size();
std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list; std::vector<std::shared_ptr<ge::model_runner::OpInfo>> empty_list;
std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>( std::shared_ptr<ge::model_runner::DavinciModel> model = std::make_shared<ge::model_runner::DavinciModel>(
task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0,
0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.GetTotalEventNum(), 0); 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.total_event_num(), 0);
auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model));
if (!ret.second) { if (!ret.second) {
......
...@@ -25,8 +25,8 @@ ...@@ -25,8 +25,8 @@
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "device/kernel_adjust.h" #include "device/kernel_adjust.h"
#include "predict/generator/utils/ir_model_util.h" #include "predict/generator/utils/ir_model_util.h"
#include "device/kernel_info.h"
#include "pre_activate/common/helper.h" #include "pre_activate/common/helper.h"
#include "utils/utils.h"
namespace mindspore { namespace mindspore {
namespace device { namespace device {
...@@ -54,6 +54,7 @@ void AscendStreamAssign::ResetNew() { ...@@ -54,6 +54,7 @@ void AscendStreamAssign::ResetNew() {
inner_parallel_streams_.clear(); inner_parallel_streams_.clear();
processed_parallel_streams_.clear(); processed_parallel_streams_.clear();
hcom_stream_list_.clear(); hcom_stream_list_.clear();
need_first_active_streams_.clear();
} }
void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t processing_logic_id) { void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t processing_logic_id) {
...@@ -200,13 +201,12 @@ void AscendStreamAssign::AssignAllNodesStream(const shared_ptr<session::KernelGr ...@@ -200,13 +201,12 @@ void AscendStreamAssign::AssignAllNodesStream(const shared_ptr<session::KernelGr
MS_LOG(INFO) << "stream nums:common:" << total_common_stream_num_ << ",independ:" << total_independ_stream_num_; MS_LOG(INFO) << "stream nums:common:" << total_common_stream_num_ << ",independ:" << total_independ_stream_num_;
} }
vector<uint32_t> AscendStreamAssign::TransLogicToPhysic(const vector<uint32_t> &logic_ids) { void AscendStreamAssign::TransLogicToPhysic(const vector<uint32_t> &logic_ids, vector<uint32_t> *physic_ids) {
vector<uint32_t> physic_ids;
for (auto &id : logic_ids) { for (auto &id : logic_ids) {
auto it = logic_to_physic_map_.find(id); auto it = logic_to_physic_map_.find(id);
if (it != logic_to_physic_map_.end()) { if (it != logic_to_physic_map_.end()) {
MS_LOG(INFO) << "logic id[" << id << "] to physic id[" << it->second << "]"; MS_LOG(INFO) << "logic id[" << id << "] to physic id[" << it->second << "]";
physic_ids.push_back(it->second); (*physic_ids).push_back(it->second);
} else { } else {
MS_LOG(EXCEPTION) << "logic id[" << id << "] has no correspond physic id"; MS_LOG(EXCEPTION) << "logic id[" << id << "] has no correspond physic id";
} }
...@@ -214,10 +214,9 @@ vector<uint32_t> AscendStreamAssign::TransLogicToPhysic(const vector<uint32_t> & ...@@ -214,10 +214,9 @@ vector<uint32_t> AscendStreamAssign::TransLogicToPhysic(const vector<uint32_t> &
auto it_independ = logic_to_independent_map_.find(id); auto it_independ = logic_to_independent_map_.find(id);
if (it_independ != logic_to_independent_map_.end()) { if (it_independ != logic_to_independent_map_.end()) {
MS_LOG(INFO) << "logic id[" << id << "] to independent id[" << it_independ->second << "]"; MS_LOG(INFO) << "logic id[" << id << "] to independent id[" << it_independ->second << "]";
physic_ids.push_back(it_independ->second); (*physic_ids).push_back(it_independ->second);
} }
} }
return physic_ids;
} }
void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) { void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) {
...@@ -227,7 +226,8 @@ void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) { ...@@ -227,7 +226,8 @@ void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) {
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
vector<uint32_t> active_logic_ids = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrActiveStreamList)); vector<uint32_t> active_logic_ids = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrActiveStreamList));
// out StreamAcitve active physic stream is not parallel now, if parallel, should deal here. // out StreamAcitve active physic stream is not parallel now, if parallel, should deal here.
vector<uint32_t> active_physic_ids = TransLogicToPhysic(active_logic_ids); vector<uint32_t> active_physic_ids;
TransLogicToPhysic(active_logic_ids, &active_physic_ids);
ValuePtr active_physic_value = MakeValue<std::vector<uint32_t>>(active_physic_ids); ValuePtr active_physic_value = MakeValue<std::vector<uint32_t>>(active_physic_ids);
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, active_physic_value, active_ptr); AnfAlgo::SetNodeAttr(kAttrActiveStreamList, active_physic_value, active_ptr);
} }
...@@ -242,7 +242,8 @@ void AscendStreamAssign::UpdateStreamSwitch(const CNodePtr &switch_ptr, const CN ...@@ -242,7 +242,8 @@ void AscendStreamAssign::UpdateStreamSwitch(const CNodePtr &switch_ptr, const CN
MS_LOG(INFO) << "streamswtich stream id[" << AnfAlgo::GetStreamId(switch_ptr) << "], true_logic_id[" << true_logic_id MS_LOG(INFO) << "streamswtich stream id[" << AnfAlgo::GetStreamId(switch_ptr) << "], true_logic_id[" << true_logic_id
<< "]"; << "]";
vector<uint32_t> logic_ids{true_logic_id}; vector<uint32_t> logic_ids{true_logic_id};
vector<uint32_t> physic_ids = TransLogicToPhysic(logic_ids); vector<uint32_t> physic_ids;
TransLogicToPhysic(logic_ids, &physic_ids);
if (physic_ids.empty()) { if (physic_ids.empty()) {
MS_LOG(EXCEPTION) << "stream switch true logic id[" << true_logic_id << "] has no physical id"; MS_LOG(EXCEPTION) << "stream switch true logic id[" << true_logic_id << "] has no physical id";
} }
...@@ -334,8 +335,8 @@ bool AscendStreamAssign::IsProcessedParallelStream(uint32_t stream_id) { ...@@ -334,8 +335,8 @@ bool AscendStreamAssign::IsProcessedParallelStream(uint32_t stream_id) {
return false; return false;
} }
vector<uint32_t> AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id) { void AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id,
vector<uint32_t> parallel_streams; vector<uint32_t> *parallel_streams) {
for (size_t i = 0; i < inner_parallel_streams_.size(); i++) { for (size_t i = 0; i < inner_parallel_streams_.size(); i++) {
auto cur_parallel_streams = inner_parallel_streams_[i]; auto cur_parallel_streams = inner_parallel_streams_[i];
auto it = std::find(cur_parallel_streams.begin(), cur_parallel_streams.end(), cur_stream_id); auto it = std::find(cur_parallel_streams.begin(), cur_parallel_streams.end(), cur_stream_id);
...@@ -347,17 +348,17 @@ vector<uint32_t> AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, u ...@@ -347,17 +348,17 @@ vector<uint32_t> AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, u
<< "is same with streamacvite stream id" << stream_acitve_id; << "is same with streamacvite stream id" << stream_acitve_id;
continue; continue;
} }
parallel_streams.emplace_back(cur_parallel_streams[j]); (*parallel_streams).emplace_back(cur_parallel_streams[j]);
} }
// record processed parallel streams // record processed parallel streams
(void)std::copy(parallel_streams.begin(), parallel_streams.end(), (void)std::copy((*parallel_streams).begin(), (*parallel_streams).end(),
std::back_inserter(processed_parallel_streams_)); std::back_inserter(processed_parallel_streams_));
return parallel_streams; return;
} }
} }
return vector<uint32_t>{cur_stream_id}; (*parallel_streams).push_back(cur_stream_id);
} }
void AscendStreamAssign::InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr) { void AscendStreamAssign::InsertActiveNew(const std::shared_ptr<session::KernelGraph> &graph_ptr) {
...@@ -379,30 +380,32 @@ void AscendStreamAssign::InsertActiveNew(const std::shared_ptr<session::KernelGr ...@@ -379,30 +380,32 @@ void AscendStreamAssign::InsertActiveNew(const std::shared_ptr<session::KernelGr
} }
bool inner_active = pre_stream_id != cur_stream_id && pre_stream_id < cur_stream_id && bool inner_active = pre_stream_id != cur_stream_id && pre_stream_id < cur_stream_id &&
AnfAlgo::GetCNodeName(pre_cnode_ptr) != "StreamSwitch" && AnfAlgo::GetCNodeName(pre_cnode_ptr) != kStreamSwitchOpName &&
AnfAlgo::GetCNodeName(pre_cnode_ptr) != "StreamActive"; AnfAlgo::GetCNodeName(pre_cnode_ptr) != kStreamActiveOpName &&
AnfAlgo::GetCNodeName(pre_cnode_ptr) != kSendOpName;
bool processed = IsProcessedParallelStream(cur_stream_id); bool processed = IsProcessedParallelStream(cur_stream_id);
// 1)inner stream assign, need insert active op // 1)inner stream assign, need insert active op
if (inner_active && !processed) { if (inner_active && !processed) {
MS_LOG(INFO) << "Inner insert active op, self stream id[" << pre_stream_id << "]"; MS_LOG(INFO) << "Inner insert active op, self stream id[" << pre_stream_id << "]";
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateSteamActiveOp(graph_ptr); CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
update_cnode_list.emplace_back(active_ptr); update_cnode_list.emplace_back(active_ptr);
update_cnode_list.emplace_back(cur_cnode_ptr);
// 1.set stream id // 1.set stream id
AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get()); AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get());
// 2.set active stream ids // 2.set active stream ids
vector<uint32_t> active_index_list = GetParallelStream(cur_stream_id, pre_stream_id); std::vector<uint32_t> active_index_list;
GetParallelStream(cur_stream_id, pre_stream_id, &active_index_list);
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr); AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(active_index_list), active_ptr);
} else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamActive" && }
// inner_active is not a if/else relationship with the next if/else. such as:StreamActive(S7)-->StreamActive(S8)
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName &&
AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) != UINT32_MAX) { AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) != UINT32_MAX) {
// 2)outter stream assign, update active op // 2)outter stream assign, update active op
update_cnode_list.emplace_back(cur_cnode_ptr); update_cnode_list.emplace_back(cur_cnode_ptr);
UpdateStreamActive(cur_cnode_ptr); UpdateStreamActive(cur_cnode_ptr);
} else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamSwitch") { } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
// 3)update switch op // 3)update switch op
MS_LOG(INFO) << "Insert active op after switch"; MS_LOG(INFO) << "Insert active op after switch";
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateSteamActiveOp(graph_ptr); CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
update_cnode_list.emplace_back(cur_cnode_ptr); update_cnode_list.emplace_back(cur_cnode_ptr);
update_cnode_list.emplace_back(active_ptr); update_cnode_list.emplace_back(active_ptr);
UpdateStreamSwitch(cur_cnode_ptr, active_ptr); UpdateStreamSwitch(cur_cnode_ptr, active_ptr);
...@@ -417,6 +420,37 @@ void AscendStreamAssign::InsertActiveNew(const std::shared_ptr<session::KernelGr ...@@ -417,6 +420,37 @@ void AscendStreamAssign::InsertActiveNew(const std::shared_ptr<session::KernelGr
MS_LOG(INFO) << "end"; MS_LOG(INFO) << "end";
} }
void AscendStreamAssign::UpdateEventId(const shared_ptr<session::KernelGraph> &graph_ptr) {
MS_LOG(INFO) << "start";
MS_EXCEPTION_IF_NULL(graph_ptr);
CNodePtr cur_cnode_ptr = nullptr;
// key:virutal event id, value:real event id
std::unordered_map<uint32_t, uint32_t> event_id_map;
uint32_t event_id;
auto cnode_ptr_list = graph_ptr->execution_order();
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) {
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
MS_EXCEPTION_IF_NULL(primitive);
event_id = GetValue<uint32_t>(primitive->GetAttr(kAttrEventId));
// before stream assign, send/recv event_id assign from kFirstEventId
if (event_id < kFirstEventId) {
continue;
}
auto it = event_id_map.find(event_id);
if (it == event_id_map.end()) {
event_id_map.insert(std::make_pair(event_id, total_event_num_));
AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue<uint32_t>(total_event_num_), cur_cnode_ptr);
total_event_num_++;
} else {
AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue<uint32_t>(it->second), cur_cnode_ptr);
}
}
}
}
void AscendStreamAssign::UpdateStreamId(const shared_ptr<session::KernelGraph> &graph_ptr) { void AscendStreamAssign::UpdateStreamId(const shared_ptr<session::KernelGraph> &graph_ptr) {
MS_LOG(INFO) << "start"; MS_LOG(INFO) << "start";
MS_EXCEPTION_IF_NULL(graph_ptr); MS_EXCEPTION_IF_NULL(graph_ptr);
...@@ -427,7 +461,7 @@ void AscendStreamAssign::UpdateStreamId(const shared_ptr<session::KernelGraph> & ...@@ -427,7 +461,7 @@ void AscendStreamAssign::UpdateStreamId(const shared_ptr<session::KernelGraph> &
MS_EXCEPTION_IF_NULL(cur_cnode_ptr); MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
if (cur_stream_id < kIndependFirstStreamId) { if (cur_stream_id < kIndependFirstStreamId) {
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamActive") { if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName) {
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
MS_EXCEPTION_IF_NULL(primitive); MS_EXCEPTION_IF_NULL(primitive);
vector<uint32_t> active_ids = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrActiveStreamList)); vector<uint32_t> active_ids = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrActiveStreamList));
...@@ -471,6 +505,29 @@ void AscendStreamAssign::UpdateStreamId(const shared_ptr<session::KernelGraph> & ...@@ -471,6 +505,29 @@ void AscendStreamAssign::UpdateStreamId(const shared_ptr<session::KernelGraph> &
MS_LOG(INFO) << "end"; MS_LOG(INFO) << "end";
} }
void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr<session::KernelGraph> &graph_ptr) {
MS_EXCEPTION_IF_NULL(graph_ptr);
CNodePtr cur_cnode_ptr = nullptr;
auto cnode_ptr_list = graph_ptr->execution_order();
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
MS_EXCEPTION_IF_NULL(primitive);
auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst);
if (value_ptr == nullptr) {
continue;
}
auto need_active = GetValue<bool>(value_ptr);
if (need_active) {
auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
MS_LOG(INFO) << "stream id:" << stream_id << " is need actived at first";
need_first_active_streams_.push_back(stream_id);
}
}
}
void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> &graph_ptr) { void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> &graph_ptr) {
if (IsTaskSink()) { if (IsTaskSink()) {
ResetNew(); ResetNew();
...@@ -480,13 +537,15 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> ...@@ -480,13 +537,15 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph>
InsertSendRecvForHcomParallel(graph_ptr); InsertSendRecvForHcomParallel(graph_ptr);
InsertSendRecvForIndependent(graph_ptr); InsertSendRecvForIndependent(graph_ptr);
UpdateStreamId(graph_ptr); UpdateStreamId(graph_ptr);
UpdateEventId(graph_ptr);
GetNeedActiveStreams(graph_ptr);
MS_LOG(INFO) << "after finish stream assign"; MS_LOG(INFO) << "after finish stream assign";
PrintGraphExeOrders(graph_ptr); PrintGraphExeOrders(graph_ptr);
// Get info for D Model // Get info for D Model
generator::IRModelUtil::GetInstance().set_event_num(GetTotalEventNum()); generator::IRModelUtil::GetInstance().set_event_num(total_event_num());
generator::IRModelUtil::GetInstance().set_stream_num(GetTotalCommonStreamNum() + GetTotalIndependStreamNum()); generator::IRModelUtil::GetInstance().set_stream_num(total_common_stream_num() + total_independ_stream_num());
// Init to 1,temporarily // Init to 1,temporarily
generator::IRModelUtil::GetInstance().set_batch_num(1); generator::IRModelUtil::GetInstance().set_batch_num(1);
} }
...@@ -495,7 +554,7 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph> ...@@ -495,7 +554,7 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr<session::KernelGraph>
CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,
uint32_t event_id, uint32_t stream_id) { uint32_t event_id, uint32_t stream_id) {
MS_EXCEPTION_IF_NULL(graph_ptr); MS_EXCEPTION_IF_NULL(graph_ptr);
auto send_op = std::make_shared<Primitive>("Send"); auto send_op = std::make_shared<Primitive>(kSendOpName);
MS_EXCEPTION_IF_NULL(send_op); MS_EXCEPTION_IF_NULL(send_op);
auto send_apply = std::make_shared<ValueNode>(send_op); auto send_apply = std::make_shared<ValueNode>(send_op);
MS_EXCEPTION_IF_NULL(send_apply); MS_EXCEPTION_IF_NULL(send_apply);
...@@ -505,7 +564,7 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr<session ...@@ -505,7 +564,7 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr<session
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get());
AnfAlgo::SetNodeAttr("event_id", MakeValue(event_id), send_node_ptr); AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr);
auto abstract_none = std::make_shared<abstract::AbstractNone>(); auto abstract_none = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract_none); MS_EXCEPTION_IF_NULL(abstract_none);
send_node_ptr->set_abstract(abstract_none); send_node_ptr->set_abstract(abstract_none);
...@@ -516,7 +575,7 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr<session ...@@ -516,7 +575,7 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr<session
CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,
uint32_t event_id, uint32_t stream_id) { uint32_t event_id, uint32_t stream_id) {
MS_EXCEPTION_IF_NULL(graph_ptr); MS_EXCEPTION_IF_NULL(graph_ptr);
auto recv_op = std::make_shared<Primitive>("Recv"); auto recv_op = std::make_shared<Primitive>(kRecvOpName);
MS_EXCEPTION_IF_NULL(recv_op); MS_EXCEPTION_IF_NULL(recv_op);
auto recv_apply = std::make_shared<ValueNode>(recv_op); auto recv_apply = std::make_shared<ValueNode>(recv_op);
MS_EXCEPTION_IF_NULL(recv_apply); MS_EXCEPTION_IF_NULL(recv_apply);
...@@ -526,7 +585,7 @@ CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const std::shared_ptr<session ...@@ -526,7 +585,7 @@ CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const std::shared_ptr<session
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get());
AnfAlgo::SetNodeAttr("event_id", MakeValue(event_id), recv_node_ptr); AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr);
AnfAlgo::SetStreamId(stream_id, recv_node_ptr.get()); AnfAlgo::SetStreamId(stream_id, recv_node_ptr.get());
auto abstract_none = std::make_shared<abstract::AbstractNone>(); auto abstract_none = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract_none); MS_EXCEPTION_IF_NULL(abstract_none);
...@@ -605,7 +664,7 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { ...@@ -605,7 +664,7 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) {
return false; return false;
} }
if (AnfAlgo::GetCNodeName(node_ptr) == "GetNext") { if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) {
MS_LOG(INFO) << "GetNext should not be independent node"; MS_LOG(INFO) << "GetNext should not be independent node";
return false; return false;
} }
...@@ -638,20 +697,23 @@ bool AscendStreamAssign::IsTaskSink() { ...@@ -638,20 +697,23 @@ bool AscendStreamAssign::IsTaskSink() {
} }
} }
std::vector<uint32_t> AscendStreamAssign::GetWaitStreams() { void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) {
vector<uint32_t> wait_active_stream_list;
if (total_common_stream_num_ == 0) { if (total_common_stream_num_ == 0) {
MS_LOG(INFO) << "total_common_stream_num is zero"; MS_LOG(INFO) << "total_common_stream_num is zero";
return wait_active_stream_list; return;
} }
// common stream:active first common stream // common stream:active first common stream
MS_LOG(INFO) << "active physic id[" << first_physic_id_ << "]"; MS_LOG(INFO) << "active physic id[" << first_physic_id_ << "]";
for (uint32_t i = first_physic_id_ + 1; i < total_common_stream_num_; i++) { for (uint32_t i = first_physic_id_ + 1; i < total_common_stream_num_; i++) {
auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i);
if (it == need_first_active_streams_.end()) {
MS_LOG(INFO) << "wait common stream id = " << i; MS_LOG(INFO) << "wait common stream id = " << i;
wait_active_stream_list.push_back(i); (*wait_active_stream_list).push_back(i);
}
} }
// all independ stream id before first physical stream id should be actived
auto it = logic_to_independent_map_.find(first_logic_id_); auto it = logic_to_independent_map_.find(first_logic_id_);
if (it != logic_to_independent_map_.end()) { if (it != logic_to_independent_map_.end()) {
uint32_t independent_id = it->second; uint32_t independent_id = it->second;
...@@ -675,16 +737,14 @@ std::vector<uint32_t> AscendStreamAssign::GetWaitStreams() { ...@@ -675,16 +737,14 @@ std::vector<uint32_t> AscendStreamAssign::GetWaitStreams() {
if (i + total_common_stream_num_ <= max_before_physic) { if (i + total_common_stream_num_ <= max_before_physic) {
continue; continue;
} }
// all wait streams should not in need_first_active_streams_
auto iter =
std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i + total_common_stream_num_);
if (iter == need_first_active_streams_.end()) {
MS_LOG(INFO) << "wait independent stream id:" << i + total_common_stream_num_; MS_LOG(INFO) << "wait independent stream id:" << i + total_common_stream_num_;
wait_active_stream_list.push_back(i + total_common_stream_num_); (*wait_active_stream_list).push_back(i + total_common_stream_num_);
}
} }
return wait_active_stream_list;
}
std::vector<uint32_t> AscendStreamAssign::GetHcomStreams() {
MS_LOG(INFO) << "hcom total stream nums:" << hcom_stream_list_.size();
return hcom_stream_list_;
} }
uint32_t AscendStreamAssign::GetTotalStreamNum() const { return total_common_stream_num_ + total_independ_stream_num_; } uint32_t AscendStreamAssign::GetTotalStreamNum() const { return total_common_stream_num_ + total_independ_stream_num_; }
...@@ -695,7 +755,7 @@ void AscendStreamAssign::PrintGraphExeOrders(const shared_ptr<mindspore::session ...@@ -695,7 +755,7 @@ void AscendStreamAssign::PrintGraphExeOrders(const shared_ptr<mindspore::session
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr); MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "Send" || AnfAlgo::GetCNodeName(cur_cnode_ptr) == "Recv") { if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) {
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
MS_LOG(INFO) << "node name[" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "], logic id[" MS_LOG(INFO) << "node name[" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "], logic id["
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id[" << AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
......
...@@ -49,37 +49,35 @@ class AscendStreamAssign { ...@@ -49,37 +49,35 @@ class AscendStreamAssign {
uint32_t GetTotalStreamNum() const; uint32_t GetTotalStreamNum() const;
// new stream policy // new stream policy
uint32_t GetTotalCommonStreamNum() const { return total_common_stream_num_; } uint32_t total_common_stream_num() const { return total_common_stream_num_; }
uint32_t GetTotalIndependStreamNum() const { return total_independ_stream_num_; } uint32_t total_independ_stream_num() const { return total_independ_stream_num_; }
uint32_t GetTotalEventNum() const { return total_event_num_; } uint32_t total_event_num() const { return total_event_num_; }
const uint32_t GetFisrtPhysicId() const { return first_physic_id_; }
const uint32_t GetFirstLogicId() const { return first_logic_id_; }
void InsertActiveNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); void InsertActiveNew(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph>& graph_ptr); void AssignAllNodesStream(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void ResetNew(); void ResetNew();
void AssignStreamNew(const std::shared_ptr<session::KernelGraph>& graph_ptr); void AssignStreamNew(const std::shared_ptr<session::KernelGraph>& graph_ptr);
bool IsIndependentNode(const CNodePtr& node_ptr); bool IsIndependentNode(const CNodePtr& node_ptr);
const std::unordered_map<uint32_t, uint32_t> GetIndependentMap() { return logic_to_independent_map_; } const std::unordered_map<uint32_t, uint32_t>& logic_to_independent_map() { return logic_to_independent_map_; }
const std::unordered_map<uint32_t, uint32_t> GetPhysicMap() { return logic_to_physic_map_; } const std::unordered_map<uint32_t, uint32_t>& logic_to_physic_map() { return logic_to_physic_map_; }
std::vector<uint32_t> GetWaitStreams(); const std::vector<std::vector<uint32_t>>& inner_parallel_streams() { return inner_parallel_streams_; }
std::vector<uint32_t> GetHcomStreams(); void GetWaitStreams(vector<uint32_t>* wait_active_stream_list);
const std::vector<uint32_t>& hcom_streams() { return hcom_stream_list_; }
private:
AscendStreamAssign() = default;
~AscendStreamAssign() = default;
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id,
uint32_t stream_id); uint32_t stream_id);
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id, CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph>& graph_ptr, uint32_t event_id,
uint32_t stream_id); uint32_t stream_id);
private:
AscendStreamAssign() = default;
~AscendStreamAssign() = default;
vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end, vector<CNodePtr>::iterator FindTargetOp(vector<CNodePtr>::iterator begin, vector<CNodePtr>::iterator end,
const CNodePtr& node); const CNodePtr& node);
bool IsHcom(const CNodePtr& apply_kernel); bool IsHcom(const CNodePtr& apply_kernel);
bool IsProcessed(uint32_t logic_id); bool IsProcessed(uint32_t logic_id);
vector<uint32_t> TransLogicToPhysic(const vector<uint32_t>& logic_ids); void TransLogicToPhysic(const vector<uint32_t>& logic_ids, vector<uint32_t>* physic_ids);
void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index, void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index,
uint32_t* cur_stream_id); uint32_t* cur_stream_id);
void RecordIdMap(uint32_t logic_id, uint32_t physic_id); void RecordIdMap(uint32_t logic_id, uint32_t physic_id);
...@@ -88,15 +86,17 @@ class AscendStreamAssign { ...@@ -88,15 +86,17 @@ class AscendStreamAssign {
bool IsTaskSink(); bool IsTaskSink();
void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id); void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id);
void UpdateStreamId(const std::shared_ptr<session::KernelGraph>& graph_ptr); void UpdateStreamId(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void UpdateEventId(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph>& graph_ptr); void PrintGraphExeOrders(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id);
uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr); uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr);
void SetCommonStreamNum(uint32_t cur_stream_id); void SetCommonStreamNum(uint32_t cur_stream_id);
void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); void FindAllReduceParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr);
bool IsProcessedParallelStream(uint32_t stream_id); bool IsProcessedParallelStream(uint32_t stream_id);
vector<uint32_t> GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id); void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector<uint32_t>* parallel_streams);
void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph>& graph_ptr); void InsertSendRecvForIndependent(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr); void InsertSendRecvForHcomParallel(const std::shared_ptr<session::KernelGraph>& graph_ptr);
void GetNeedActiveStreams(const std::shared_ptr<session::KernelGraph>& graph_ptr);
uint32_t total_common_stream_num_{0}; uint32_t total_common_stream_num_{0};
uint32_t total_independ_stream_num_{0}; uint32_t total_independ_stream_num_{0};
...@@ -112,6 +112,7 @@ class AscendStreamAssign { ...@@ -112,6 +112,7 @@ class AscendStreamAssign {
std::vector<std::vector<uint32_t>> inner_parallel_streams_{}; std::vector<std::vector<uint32_t>> inner_parallel_streams_{};
std::vector<uint32_t> processed_parallel_streams_{}; std::vector<uint32_t> processed_parallel_streams_{};
std::vector<uint32_t> hcom_stream_list_{}; std::vector<uint32_t> hcom_stream_list_{};
std::vector<uint32_t> need_first_active_streams_{};
// new policy end // new policy end
}; };
} // namespace ascend } // namespace ascend
......
...@@ -32,16 +32,8 @@ ...@@ -32,16 +32,8 @@
#include "utils/utils.h" #include "utils/utils.h"
#include "device/ascend/profiling/profiling_manager.h" #include "device/ascend/profiling/profiling_manager.h"
#include "device/ascend/kernel_select_ascend.h" #include "device/ascend/kernel_select_ascend.h"
#include "device/kernel_info.h"
#include "runtime/base.h" #include "runtime/base.h"
#include "device/ascend/ascend_stream_assign.h"
constexpr auto kLoopCountParamName = "loop_count";
constexpr auto kIterLoopParamName = "iter_loop";
constexpr auto kZeroParamName = "zero";
constexpr auto kOneParamName = "one";
constexpr auto kStreamSwitch = "StreamSwitch";
constexpr auto kStreamActive = "StreamActive";
constexpr auto kAssignAdd = "AssignAdd";
namespace mindspore { namespace mindspore {
namespace device { namespace device {
using device::ascend::ProfilingUtils; using device::ascend::ProfilingUtils;
...@@ -70,6 +62,63 @@ bool KernelAdjust::NeedInsertSwitch() { ...@@ -70,6 +62,63 @@ bool KernelAdjust::NeedInsertSwitch() {
ConfigManager::GetInstance().iter_num() > 1); ConfigManager::GetInstance().iter_num() > 1);
} }
uint32_t KernelAdjust::FindFirstStreamSwitchLabel(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
auto cnode_ptr_list = kernel_graph_ptr->execution_order();
CNodePtr cur_cnode_ptr = nullptr;
uint32_t label = kInvalidDistincLabel;
for (uint32_t i = 0; i < cnode_ptr_list.size(); ++i) {
cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) {
label = AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get());
break;
}
}
return label;
}
CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,
uint32_t event_id) {
MS_EXCEPTION_IF_NULL(graph_ptr);
auto send_op = std::make_shared<Primitive>(kSendOpName);
MS_EXCEPTION_IF_NULL(send_op);
auto send_apply = std::make_shared<ValueNode>(send_op);
MS_EXCEPTION_IF_NULL(send_apply);
std::vector<AnfNodePtr> send_input_list = {send_apply};
CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list);
MS_EXCEPTION_IF_NULL(send_node_ptr);
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get());
AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr);
auto abstract_none = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract_none);
send_node_ptr->set_abstract(abstract_none);
return send_node_ptr;
}
CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr,
uint32_t event_id) {
MS_EXCEPTION_IF_NULL(graph_ptr);
auto recv_op = std::make_shared<Primitive>(kRecvOpName);
MS_EXCEPTION_IF_NULL(recv_op);
auto recv_apply = std::make_shared<ValueNode>(recv_op);
MS_EXCEPTION_IF_NULL(recv_apply);
std::vector<AnfNodePtr> recv_input_list = {recv_apply};
CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list);
MS_EXCEPTION_IF_NULL(recv_node_ptr);
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get());
AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr);
auto abstract_none = std::make_shared<abstract::AbstractNone>();
MS_EXCEPTION_IF_NULL(abstract_none);
recv_node_ptr->set_abstract(abstract_none);
return recv_node_ptr;
}
void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
if (!NeedInsertSwitch()) { if (!NeedInsertSwitch()) {
return; return;
...@@ -93,21 +142,95 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> ...@@ -93,21 +142,95 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr<session::KernelGraph>
} }
} }
} }
auto orders = kernel_graph_ptr->execution_order();
if (orders.empty()) {
MS_LOG(EXCEPTION) << "graph execution order is empty";
}
uint32_t first_cnode_stream_label = AnfAlgo::GetStreamDistinctionLabel(orders[0].get());
std::vector<CNodePtr> exec_order; std::vector<CNodePtr> exec_order;
CNodePtr stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); CNodePtr first_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
MS_EXCEPTION_IF_NULL(stream_switch_app); MS_EXCEPTION_IF_NULL(first_stream_switch_app);
exec_order.push_back(stream_switch_app); AnfAlgo::SetStreamDistinctionLabel(kFirstStreamSwitchLabel, first_stream_switch_app.get());
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(kGetNextLabel), first_stream_switch_app);
CNodePtr second_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input);
MS_EXCEPTION_IF_NULL(second_stream_switch_app);
AnfAlgo::SetStreamDistinctionLabel(kSecondStreamSwitchLabel, second_stream_switch_app.get());
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_cnode_stream_label), second_stream_switch_app);
// add attr "stream_need_active"
AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue<bool>(true), second_stream_switch_app);
CNodePtr first_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr);
MS_EXCEPTION_IF_NULL(first_stream_active_app);
AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, first_stream_active_app.get());
std::vector<uint32_t> first_active_streams = {kFirstStreamSwitchLabel};
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(first_active_streams),
first_stream_active_app);
CNodePtr second_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr);
MS_EXCEPTION_IF_NULL(second_stream_active_app);
// specific deal for common ctrl stream policy
uint32_t first_common_stream_switch_label = FindFirstStreamSwitchLabel(kernel_graph_ptr);
if (first_common_stream_switch_label == kInvalidDistincLabel) {
AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, second_stream_active_app.get());
} else {
AnfAlgo::SetStreamDistinctionLabel(first_common_stream_switch_label, second_stream_active_app.get());
}
CNodePtr stream_active_switch_app = CreateStreamActiveSwitchOp(kernel_graph_ptr); std::vector<uint32_t> second_active_streams = {kSecondStreamSwitchLabel};
MS_EXCEPTION_IF_NULL(stream_active_switch_app); AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue<std::vector<uint32_t>>(second_active_streams),
second_stream_active_app);
CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input); CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input);
MS_EXCEPTION_IF_NULL(assign_add_one); MS_EXCEPTION_IF_NULL(assign_add_one);
AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, assign_add_one.get());
CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId);
AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, send.get());
CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId);
AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, recv.get());
// reorder graph orders
exec_order.push_back(first_stream_switch_app);
size_t i = 0;
for (; i < orders.size(); i++) {
auto node = orders[i];
exec_order.push_back(node);
AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, exec_order[exec_order.size() - 1].get());
if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) {
break;
}
}
exec_order.push_back(send);
exec_order.push_back(second_stream_switch_app);
exec_order.push_back(recv);
exec_order.push_back(assign_add_one); exec_order.push_back(assign_add_one);
auto original_exec_order = kernel_graph_ptr->execution_order(); std::vector<CNodePtr> memcpy_list;
(void)std::copy(original_exec_order.begin(), original_exec_order.end(), std::back_inserter(exec_order)); std::vector<CNodePtr> before_list;
exec_order.push_back(stream_active_switch_app); std::vector<CNodePtr> after_list;
bool first_memcpy_found = false;
CNodePtr cur_cnode = nullptr;
for (size_t idx = i + 1; idx < orders.size(); idx++) {
cur_cnode = orders[idx];
if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) {
memcpy_list.emplace_back(cur_cnode);
first_memcpy_found = true;
} else if (first_memcpy_found) {
after_list.emplace_back(cur_cnode);
} else {
before_list.emplace_back(cur_cnode);
}
}
(void)std::copy(before_list.begin(), before_list.end(), std::back_inserter(exec_order));
(void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order));
exec_order.push_back(first_stream_active_app);
(void)std::copy(after_list.begin(), after_list.end(), std::back_inserter(exec_order));
exec_order.push_back(second_stream_active_app);
kernel_graph_ptr->set_execution_order(exec_order); kernel_graph_ptr->set_execution_order(exec_order);
} }
...@@ -167,7 +290,7 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::Kerne ...@@ -167,7 +290,7 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::Kerne
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
auto typeNone_abstract = std::make_shared<abstract::AbstractNone>(); auto typeNone_abstract = std::make_shared<abstract::AbstractNone>();
auto stream_switch = std::make_shared<Primitive>(kStreamSwitch); auto stream_switch = std::make_shared<Primitive>(kStreamSwitchOpName);
std::vector<AnfNodePtr> inputs; std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(stream_switch)); inputs.push_back(NewValueNode(stream_switch));
inputs.push_back(switch_loop_input.at(kLoopCountParamName)); inputs.push_back(switch_loop_input.at(kLoopCountParamName));
...@@ -181,76 +304,19 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::Kerne ...@@ -181,76 +304,19 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr<session::Kerne
int condition = static_cast<int>(RT_LESS); int condition = static_cast<int>(RT_LESS);
ValuePtr cond = MakeValue(condition); ValuePtr cond = MakeValue(condition);
AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app); AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app);
// set attr:true branch graph id ,which is same to stream distinction label
if (kernel_graph_ptr->execution_order().empty()) {
MS_LOG(EXCEPTION) << "empty execution order";
}
auto first_node = kernel_graph_ptr->execution_order()[0];
auto first_stream = AnfAlgo::GetStreamDistinctionLabel(first_node.get());
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue<uint32_t>(first_stream), stream_switch_app);
// set attr:data_type // set attr:data_type
int data_type = static_cast<int>(RT_SWITCH_INT64); int data_type = static_cast<int>(RT_SWITCH_INT64);
ValuePtr dt = MakeValue(data_type); ValuePtr dt = MakeValue(data_type);
AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app); AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app);
// set distinction label and graph id // set distinction label and graph id
AnfAlgo::SetGraphId(kInvalidGraphId - 1, stream_switch_app.get());
AnfAlgo::SetStreamDistinctionLabel(kInvalidDistincLabel - 1, stream_switch_app.get());
return stream_switch_app; return stream_switch_app;
} }
CNodePtr KernelAdjust::CreateSteamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
abstract::AbstractBasePtr typeNone_abstract = std::make_shared<abstract::AbstractNone>();
auto stream_active_others = std::make_shared<Primitive>(kStreamActive);
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(stream_active_others));
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
CNodePtr stream_active_others_app = kernel_graph_ptr->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(stream_active_others_app);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get());
stream_active_others_app->set_abstract(typeNone_abstract);
return stream_active_others_app;
}
CNodePtr KernelAdjust::CreateStreamActiveSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
abstract::AbstractBasePtr typeNone_abstract = std::make_shared<abstract::AbstractNone>();
auto stream_active_switch = std::make_shared<Primitive>(kStreamActive);
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(stream_active_switch));
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
CNodePtr stream_active_switch_app = kernel_graph_ptr->NewCNode(inputs);
MS_EXCEPTION_IF_NULL(stream_active_switch_app);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_switch_app.get());
stream_active_switch_app->set_abstract(typeNone_abstract);
// set attr,which stream to active
std::vector<uint32_t> active_index_value = {kInvalidDistincLabel - 1};
auto value = MakeValue<std::vector<uint32_t>>(active_index_value);
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, value, stream_active_switch_app);
// set the distinction label of stream active
if (kernel_graph_ptr->execution_order().empty()) {
MS_LOG(EXCEPTION) << "empty execution order";
}
auto first_node = kernel_graph_ptr->execution_order()[0];
auto label = AnfAlgo::GetStreamDistinctionLabel(first_node.get());
// find the first switch's distinction label
for (auto node : kernel_graph_ptr->execution_order()) {
if (AnfAlgo::GetCNodeName(node) == "StreamSwitch") {
label = AnfAlgo::GetStreamDistinctionLabel(node.get());
break;
}
}
AnfAlgo::SetStreamDistinctionLabel(label, stream_active_switch_app.get());
return stream_active_switch_app;
}
CNodePtr KernelAdjust::CreateStreamActiveOtherOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder(
{kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
abstract::AbstractBasePtr typeNone_abstract = std::make_shared<abstract::AbstractNone>(); abstract::AbstractBasePtr typeNone_abstract = std::make_shared<abstract::AbstractNone>();
auto stream_active_others = std::make_shared<Primitive>(kStreamActive); auto stream_active_others = std::make_shared<Primitive>(kStreamActiveOpName);
std::vector<AnfNodePtr> inputs; std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(stream_active_others)); inputs.push_back(NewValueNode(stream_active_others));
MS_EXCEPTION_IF_NULL(kernel_graph_ptr); MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
...@@ -258,9 +324,6 @@ CNodePtr KernelAdjust::CreateStreamActiveOtherOp(const std::shared_ptr<session:: ...@@ -258,9 +324,6 @@ CNodePtr KernelAdjust::CreateStreamActiveOtherOp(const std::shared_ptr<session::
MS_EXCEPTION_IF_NULL(stream_active_others_app); MS_EXCEPTION_IF_NULL(stream_active_others_app);
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get()); AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get());
stream_active_others_app->set_abstract(typeNone_abstract); stream_active_others_app->set_abstract(typeNone_abstract);
// set attr
ValuePtr active_target = MakeValue(kValueTargetOther);
AnfAlgo::SetNodeAttr(kAttrActiveTarget, active_target, stream_active_others_app);
return stream_active_others_app; return stream_active_others_app;
} }
...@@ -273,7 +336,7 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP( ...@@ -273,7 +336,7 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP(
selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32}); selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32});
// AssignAdd // AssignAdd
auto assign_add = std::make_shared<Primitive>(kAssignAdd); auto assign_add = std::make_shared<Primitive>(kAssignAddOpName);
std::vector<AnfNodePtr> inputs; std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(assign_add)); inputs.push_back(NewValueNode(assign_add));
inputs.push_back(switch_loop_input.at(kLoopCountParamName)); inputs.push_back(switch_loop_input.at(kLoopCountParamName));
...@@ -290,70 +353,9 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP( ...@@ -290,70 +353,9 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP(
selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL); selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName)); MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName));
assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract()); assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract());
// set the distinction label of assign add
if (kernel_graph_ptr->execution_order().empty()) {
MS_LOG(EXCEPTION) << "empty execution order";
}
auto first_node = kernel_graph_ptr->execution_order()[0];
auto label = AnfAlgo::GetStreamDistinctionLabel(first_node.get());
AnfAlgo::SetStreamDistinctionLabel(label, assign_add_one.get());
return assign_add_one; return assign_add_one;
} }
void KernelAdjust::SetStreamActiveOPs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
const std::unordered_set<uint32_t> &ctrl_stream_list,
const std::unordered_set<uint32_t> &comm_stream_list,
const std::unordered_set<uint32_t> &momentum_stream_list) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
for (const auto &cnode_ptr : kernel_graph_ptr->execution_order()) {
MS_EXCEPTION_IF_NULL(cnode_ptr);
if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamActive) {
auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr);
ValuePtr active_target = primitive->GetAttr(kAttrActiveTarget);
std::vector<uint32_t> index_list;
index_list.clear();
if (GetValue<string>(active_target) == kValueTargetSwitch) {
index_list.insert(index_list.end(), ctrl_stream_list.begin(), ctrl_stream_list.end());
} else if (GetValue<string>(active_target) == kValueTargetOther) {
for (uint32_t index : comm_stream_list) {
if (AnfAlgo::GetStreamId(cnode_ptr) == index) {
continue;
}
index_list.emplace_back(index);
}
index_list.insert(index_list.end(), momentum_stream_list.begin(), momentum_stream_list.end());
}
ValuePtr index_list_value = MakeValue(index_list);
AnfAlgo::SetNodeAttr(kAttrActiveStreamList, index_list_value, cnode_ptr);
}
}
}
void KernelAdjust::SetStreamSwitchOps(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
CNodePtr switch_cnode_ptr = nullptr;
uint32_t target_stream_id = 0;
for (const auto &cnode_ptr : kernel_graph_ptr->execution_order()) {
MS_EXCEPTION_IF_NULL(cnode_ptr);
if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamSwitch) {
switch_cnode_ptr = cnode_ptr;
}
if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamActive) {
auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr);
ValuePtr active_target = primitive->GetAttr(kAttrActiveTarget);
if (GetValue<string>(active_target) == kValueTargetOther) {
target_stream_id = AnfAlgo::GetStreamId(cnode_ptr);
}
}
}
if (switch_cnode_ptr != nullptr) {
// set attr:true stream
ValuePtr true_index = MakeValue(target_stream_id);
AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, true_index, switch_cnode_ptr);
MS_LOG(INFO) << "switch to true_index:" << target_stream_id;
}
}
bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context, bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context,
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) { const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr) {
if (!NeedInsertSwitch()) { if (!NeedInsertSwitch()) {
......
...@@ -28,10 +28,22 @@ ...@@ -28,10 +28,22 @@
#include "session/session_context.h" #include "session/session_context.h"
#include "ir/meta_tensor.h" #include "ir/meta_tensor.h"
#include "device/ascend/profiling/profiling_utils.h" #include "device/ascend/profiling/profiling_utils.h"
#include "device/kernel_info.h"
using mindspore::device::ascend::ProfilingTraceInfo; using mindspore::device::ascend::ProfilingTraceInfo;
using mindspore::device::ascend::ProfilingUtils; using mindspore::device::ascend::ProfilingUtils;
namespace mindspore { namespace mindspore {
constexpr auto kLoopCountParamName = "loop_count";
constexpr auto kIterLoopParamName = "iter_loop";
constexpr auto kZeroParamName = "zero";
constexpr auto kOneParamName = "one";
constexpr auto kStreamNeedActivedFirst = "stream_need_active_first";
const uint32_t kFirstStreamSwitchLabel = kInvalidDistincLabel - 1;
const uint32_t kGetNextLabel = kInvalidDistincLabel - 2;
const uint32_t kSecondStreamSwitchLabel = kInvalidDistincLabel - 3;
const uint32_t kInvalidEventId = UINT32_MAX;
const uint32_t kFirstEventId = kInvalidEventId / 2;
namespace device { namespace device {
class KernelAdjust { class KernelAdjust {
public: public:
...@@ -41,26 +53,23 @@ class KernelAdjust { ...@@ -41,26 +53,23 @@ class KernelAdjust {
} }
void Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); void Reorder(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
void InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); void InsertSwitchLoop(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
void SetStreamActiveOPs(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
const std::unordered_set<uint32_t> &ctrl_stream_list,
const std::unordered_set<uint32_t> &comm_stream_list,
const std::unordered_set<uint32_t> &momentum_stream_list);
void SetStreamSwitchOps(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
bool StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context, bool StepLoadCtrlInputs(const std::shared_ptr<session::Context> &context,
const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
void Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr); void Profiling(NotNull<session::KernelGraph *> kernel_graph_ptr);
static bool NeedInsertSwitch(); static bool NeedInsertSwitch();
CNodePtr CreateSteamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr); CNodePtr CreateStreamActiveOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
private: private:
KernelAdjust() = default; KernelAdjust() = default;
~KernelAdjust() = default; ~KernelAdjust() = default;
CNodePtr CreateRecvApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id);
CNodePtr CreateSendApplyKernel(const std::shared_ptr<session::KernelGraph> &graph_ptr, uint32_t event_id);
uint32_t FindFirstStreamSwitchLabel(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
void CreateSwitchOpParameters(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, void CreateSwitchOpParameters(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
std::map<std::string, mindspore::ParameterPtr> *switch_loop_input); std::map<std::string, mindspore::ParameterPtr> *switch_loop_input);
CNodePtr CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, CNodePtr CreateStreamSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input); const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input);
CNodePtr CreateStreamActiveSwitchOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
CNodePtr CreateStreamActiveOtherOp(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr);
CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr, CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr<session::KernelGraph> &kernel_graph_ptr,
const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input); const std::map<std::string, mindspore::ParameterPtr> &switch_loop_input);
kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector<std::string> &formats, kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector<std::string> &formats,
......
...@@ -62,6 +62,7 @@ ...@@ -62,6 +62,7 @@
#include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h"
#include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h"
#include "pre_activate/ascend/ir_fission/addn_fission.h" #include "pre_activate/ascend/ir_fission/addn_fission.h"
#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h"
#include "utils/context/ms_context.h" #include "utils/context/ms_context.h"
#include "utils/config_manager.h" #include "utils/config_manager.h"
#include "debug/anf_ir_dump.h" #include "debug/anf_ir_dump.h"
...@@ -187,6 +188,12 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap ...@@ -187,6 +188,12 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>()); ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>()); ir_fusion_pm->AddPass(std::make_shared<TransposeTransDataFusion>());
} }
if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) {
ir_fusion_pm->AddPass(std::make_shared<InsertMemcpyAsyncForGetNext>());
ir_fusion_pm->AddPass(std::make_shared<GetitemTuple>());
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
}
optimizer->AddPassManager(ir_fusion_pm); optimizer->AddPassManager(ir_fusion_pm);
(void)optimizer->Optimize(kernel_graph); (void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault(); kernel_graph->SetExecOrderByDefault();
......
...@@ -20,8 +20,8 @@ namespace mindspore { ...@@ -20,8 +20,8 @@ namespace mindspore {
namespace memreuse { namespace memreuse {
void StreamReuse::SetStreamReuseResource() { void StreamReuse::SetStreamReuseResource() {
#ifdef ENABLE_D #ifdef ENABLE_D
auto logic_physic_map = device::ascend::AscendStreamAssign::GetInstance().GetPhysicMap(); auto logic_physic_map = device::ascend::AscendStreamAssign::GetInstance().logic_to_physic_map();
auto logic_independent_map = device::ascend::AscendStreamAssign::GetInstance().GetIndependentMap(); auto logic_independent_map = device::ascend::AscendStreamAssign::GetInstance().logic_to_independent_map();
MS_LOG(INFO) << "stream mem reuse for Davici"; MS_LOG(INFO) << "stream mem reuse for Davici";
if (!logic_independent_map.empty() && !logic_physic_map.empty()) { if (!logic_independent_map.empty() && !logic_physic_map.empty()) {
set_logic_physic_map(logic_physic_map); set_logic_physic_map(logic_physic_map);
......
...@@ -610,7 +610,7 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { ...@@ -610,7 +610,7 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) {
if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() &&
ConfigManager::GetInstance().iter_num() > 1) { ConfigManager::GetInstance().iter_num() > 1) {
// insert active in true graph, another active will be inserted in kernel adjust // insert active in true graph, another active will be inserted in kernel adjust
InsertStreamActiveToGraph(true_last_id, kInvalidDistincLabel - 1); InsertStreamActiveToGraph(true_last_id, kSecondStreamSwitchLabel);
} }
break; break;
} }
......
...@@ -114,6 +114,9 @@ constexpr auto kFusedMulAddNOpName = "FusedMulAddN"; ...@@ -114,6 +114,9 @@ constexpr auto kFusedMulAddNOpName = "FusedMulAddN";
constexpr auto kFusedMulApplyMomentumOpName = "FusedMulApplyMomentum"; constexpr auto kFusedMulApplyMomentumOpName = "FusedMulApplyMomentum";
constexpr auto kBiasAddOpName = "BiasAdd"; constexpr auto kBiasAddOpName = "BiasAdd";
constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad"; constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad";
constexpr auto kStreamSwitchOpName = "StreamSwitch";
constexpr auto kStreamActiveOpName = "StreamActive";
constexpr auto kAssignAddOpName = "AssignAdd";
constexpr auto kSendOpName = "Send"; constexpr auto kSendOpName = "Send";
constexpr auto kRecvOpName = "Recv"; constexpr auto kRecvOpName = "Recv";
constexpr auto kReluV2OpName = "ReluV2"; constexpr auto kReluV2OpName = "ReluV2";
......
...@@ -24,9 +24,7 @@ void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; ...@@ -24,9 +24,7 @@ void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return;
uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; } uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; }
std::vector<uint32_t> AscendStreamAssign::GetWaitStreams() { return vector<uint32_t>(); } void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_list) { return; }
std::vector<uint32_t> AscendStreamAssign::GetHcomStreams() { return vector<uint32_t>(); }
namespace tasksink { namespace tasksink {
bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::vector<TaskInfoPtr> *const task_info_list, bool TaskGenerator::GenTasks(const std::vector<CNodePtr> &anf_node_list, std::vector<TaskInfoPtr> *const task_info_list,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册