提交 7951b318 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2697 format device ascend code

Merge pull request !2697 from kisnwang/format-device-ascend-code
...@@ -68,9 +68,9 @@ std::string GetRankId() { ...@@ -68,9 +68,9 @@ std::string GetRankId() {
int rank_offset = std::stoi(offset); int rank_offset = std::stoi(offset);
rank_id += rank_offset; rank_id += rank_offset;
} catch (std::invalid_argument) { } catch (std::invalid_argument) {
MS_LOG(EXCEPTION) << "stoi invalid argument:" << offset; MS_LOG(EXCEPTION) << "Call stoi invalid argument:" << offset;
} catch (std::out_of_range) { } catch (std::out_of_range) {
MS_LOG(EXCEPTION) << "stoi out_of_range:" << offset; MS_LOG(EXCEPTION) << "Call stoi out_of_range:" << offset;
} }
} }
rank_id_str = std::to_string(rank_id); rank_id_str = std::to_string(rank_id);
...@@ -81,7 +81,7 @@ std::string GetRankId() { ...@@ -81,7 +81,7 @@ std::string GetRankId() {
rank_id_str = std::getenv("RANK_ID"); rank_id_str = std::getenv("RANK_ID");
#endif #endif
if (rank_id_str.empty()) { if (rank_id_str.empty()) {
MS_LOG(ERROR) << "get hccl rankid failed, please set env RANK_ID"; MS_LOG(ERROR) << "Get hccl rankid failed, please set env RANK_ID";
} }
return rank_id_str; return rank_id_str;
} }
...@@ -100,7 +100,7 @@ void AscendKernelRuntime::ClearGraphModelMap() { ...@@ -100,7 +100,7 @@ void AscendKernelRuntime::ClearGraphModelMap() {
} }
void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) { void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id) {
MS_LOG(DEBUG) << "clear graph:" << graph_id << " runtime resource"; MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource";
auto iter = graph_model_map_.find(graph_id); auto iter = graph_model_map_.find(graph_id);
if (iter == graph_model_map_.end()) { if (iter == graph_model_map_.end()) {
MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found"; MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found";
...@@ -118,7 +118,7 @@ bool AscendKernelRuntime::NeedDestroyHccl() { ...@@ -118,7 +118,7 @@ bool AscendKernelRuntime::NeedDestroyHccl() {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (!context_ptr->enable_hccl()) { if (!context_ptr->enable_hccl()) {
MS_LOG(INFO) << "hccl is not enabled"; MS_LOG(INFO) << "Hccl is not enabled";
return false; return false;
} }
// Note: make sure hcom_connectivity_detection api never be used. // Note: make sure hcom_connectivity_detection api never be used.
...@@ -126,7 +126,7 @@ bool AscendKernelRuntime::NeedDestroyHccl() { ...@@ -126,7 +126,7 @@ bool AscendKernelRuntime::NeedDestroyHccl() {
} }
void AscendKernelRuntime::ReleaseDeviceRes() { void AscendKernelRuntime::ReleaseDeviceRes() {
MS_LOG(INFO) << "ascend finalize start"; MS_LOG(INFO) << "Ascend finalize start";
// release ge runtime // release ge runtime
ClearGraphModelMap(); ClearGraphModelMap();
...@@ -134,7 +134,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() { ...@@ -134,7 +134,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
auto ret = rtSetDevice(context_ptr->device_id()); auto ret = rtSetDevice(context_ptr->device_id());
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtSetDevice, ret[" << static_cast<int>(ret) << "]"; MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
} }
if (mem_manager_ != nullptr) { if (mem_manager_ != nullptr) {
...@@ -144,7 +144,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() { ...@@ -144,7 +144,7 @@ void AscendKernelRuntime::ReleaseDeviceRes() {
(void)DestroyHccl(); (void)DestroyHccl();
(void)ResetDevice(); (void)ResetDevice();
(void)ProfilingManager::GetInstance().StopProfiling(); (void)ProfilingManager::GetInstance().StopProfiling();
MS_LOG(INFO) << "ascend finalize end"; MS_LOG(INFO) << "Ascend finalize end";
} }
bool AscendKernelRuntime::Init() { bool AscendKernelRuntime::Init() {
...@@ -155,7 +155,7 @@ bool AscendKernelRuntime::Init() { ...@@ -155,7 +155,7 @@ bool AscendKernelRuntime::Init() {
#ifdef ENABLE_DUMP_E2E #ifdef ENABLE_DUMP_E2E
ret = SetDumpConf(); ret = SetDumpConf();
if (!ret) { if (!ret) {
MS_LOG(INFO) << "no dump conf to set!"; MS_LOG(INFO) << "No dump conf to set!";
} }
#endif #endif
...@@ -263,13 +263,13 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p ...@@ -263,13 +263,13 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p
bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
#ifdef ENABLE_DUMP_E2E #ifdef ENABLE_DUMP_E2E
MS_LOG(INFO) << "start dump step"; MS_LOG(INFO) << "Start dump step";
DumpConfPtr dump_conf = GetDumpConf(); DumpConfPtr dump_conf = GetDumpConf();
MS_EXCEPTION_IF_NULL(dump_conf); MS_EXCEPTION_IF_NULL(dump_conf);
dump_conf->UpdataCurIter(); dump_conf->UpdataCurIter();
bool dump_flag = dump_conf->dump_enable(); bool dump_flag = dump_conf->dump_enable();
if (!dump_flag) { if (!dump_flag) {
MS_LOG(INFO) << "dump flag is disable, pass dump step"; MS_LOG(INFO) << "Dump flag is disable, pass dump step";
return true; return true;
} }
uint32_t cur_iter = dump_conf->cur_iter(); uint32_t cur_iter = dump_conf->cur_iter();
...@@ -278,7 +278,7 @@ bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) { ...@@ -278,7 +278,7 @@ bool AscendKernelRuntime::DumpData(mindspore::session::KernelGraph *graph) {
return true; return true;
} }
} }
MS_LOG(INFO) << "cur iter is " << cur_iter; MS_LOG(INFO) << "Cur iter is " << cur_iter;
std::string net_name = dump_conf->dump_net_name(); std::string net_name = dump_conf->dump_net_name();
std::string iterator = to_string(cur_iter); std::string iterator = to_string(cur_iter);
std::string dump_path = dump_conf->dump_path(); std::string dump_path = dump_conf->dump_path();
...@@ -369,9 +369,9 @@ void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger) ...@@ -369,9 +369,9 @@ void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger)
bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) { bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
#ifdef ENABLE_DEBUGGER #ifdef ENABLE_DEBUGGER
MS_LOG(INFO) << "start load step"; MS_LOG(INFO) << "Start load step";
uint32_t cur_iter = 0; uint32_t cur_iter = 0;
MS_LOG(INFO) << "cur iter is " << cur_iter; MS_LOG(INFO) << "Cur iter is " << cur_iter;
// load output // load output
LoadOutput(graph, debugger); LoadOutput(graph, debugger);
// load parameters // load parameters
...@@ -421,7 +421,7 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { ...@@ -421,7 +421,7 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
} }
// Graph may have no compute node, such TensorAddGrad. // Graph may have no compute node, such TensorAddGrad.
if (task_info_list.empty()) { if (task_info_list.empty()) {
MS_LOG(WARNING) << "graph " << graph->graph_id() << " have no compute node"; MS_LOG(WARNING) << "Graph " << graph->graph_id() << " have no compute node";
return true; return true;
} }
AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance();
...@@ -432,7 +432,7 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { ...@@ -432,7 +432,7 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
assign_instance.GetWaitStreams(&wait_active_stream_list); assign_instance.GetWaitStreams(&wait_active_stream_list);
std::vector<uint32_t> force_copy_stream_list; std::vector<uint32_t> force_copy_stream_list;
assign_instance.GetHcomStreams(&force_copy_stream_list); assign_instance.GetHcomStreams(&force_copy_stream_list);
MS_LOG(INFO) << "call DavinciModel total stream num:" << resource_manager.get_cur_stream_num() MS_LOG(INFO) << "Call DavinciModel total stream num:" << resource_manager.get_cur_stream_num()
<< ", total event num:" << resource_manager.get_cur_event_num() << ", total event num:" << resource_manager.get_cur_event_num()
<< ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph))
<< ", wait_active_stream_list size:" << wait_active_stream_list.size() << ", wait_active_stream_list size:" << wait_active_stream_list.size()
...@@ -524,7 +524,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) { ...@@ -524,7 +524,7 @@ bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
bool status = ge::model_runner::ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors); bool status = ge::model_runner::ModelRunner::Instance().RunModel(graph->graph_id(), input_tensors, output_tensors);
if (!status) { if (!status) {
MS_LOG(ERROR) << "run task failed"; MS_LOG(ERROR) << "Run task failed";
DebugTaskIdName(graph->graph_id()); DebugTaskIdName(graph->graph_id());
return false; return false;
} }
...@@ -543,18 +543,18 @@ bool AscendKernelRuntime::InitDevice() { ...@@ -543,18 +543,18 @@ bool AscendKernelRuntime::InitDevice() {
int device_count = 0; int device_count = 0;
auto ret = rtGetDeviceCount(&device_count); auto ret = rtGetDeviceCount(&device_count);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]"; MS_EXCEPTION(DeviceProcessError) << "Call rtGetDeviceCount, ret[" << static_cast<int>(ret) << "]";
} }
ret = rtSetDevice(device_id_); ret = rtSetDevice(device_id_);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtSetDevice, ret[" << static_cast<int>(ret) << "]"; MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << static_cast<int>(ret) << "]";
} }
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (context_ptr == nullptr) { if (context_ptr == nullptr) {
MS_LOG(ERROR) << "get MsContext instance failed"; MS_LOG(ERROR) << "Get MsContext instance failed";
return false; return false;
} }
if (context_ptr->enable_hccl()) { if (context_ptr->enable_hccl()) {
...@@ -566,17 +566,17 @@ bool AscendKernelRuntime::InitDevice() { ...@@ -566,17 +566,17 @@ bool AscendKernelRuntime::InitDevice() {
ret = rtCtxCreate(&rt_context_, 0, device_id_); ret = rtCtxCreate(&rt_context_, 0, device_id_);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtCtxCreate, ret[" << static_cast<int>(ret) << "]"; MS_EXCEPTION(DeviceProcessError) << "Call rtCtxCreate, ret[" << static_cast<int>(ret) << "]";
} }
ret = rtCtxSetCurrent(rt_context_); ret = rtCtxSetCurrent(rt_context_);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtCtxSetCurrent, ret[" << ret << "]"; MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]";
} }
ret = rtStreamCreate(&stream_, 0); ret = rtStreamCreate(&stream_, 0);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "rtStreamCreate, ret[" << ret << "]"; MS_LOG(EXCEPTION) << "Call rtStreamCreate, ret[" << ret << "]";
} }
return true; return true;
...@@ -585,14 +585,14 @@ bool AscendKernelRuntime::InitDevice() { ...@@ -585,14 +585,14 @@ bool AscendKernelRuntime::InitDevice() {
bool AscendKernelRuntime::ResetDevice() { bool AscendKernelRuntime::ResetDevice() {
auto ret = rtCtxSetCurrent(rt_context_); auto ret = rtCtxSetCurrent(rt_context_);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "call rtCtxSetCurrent failed"; MS_LOG(ERROR) << "Call rtCtxSetCurrent failed";
return false; return false;
} }
if (stream_ != nullptr) { if (stream_ != nullptr) {
ret = rtStreamDestroy(stream_); ret = rtStreamDestroy(stream_);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "rtStreamDestroy, ret[" << ret << "]"; MS_LOG(EXCEPTION) << "Call rtStreamDestroy, ret[" << ret << "]";
} }
stream_ = nullptr; stream_ = nullptr;
} }
...@@ -600,7 +600,7 @@ bool AscendKernelRuntime::ResetDevice() { ...@@ -600,7 +600,7 @@ bool AscendKernelRuntime::ResetDevice() {
if (rt_context_ != nullptr) { if (rt_context_ != nullptr) {
ret = rtCtxDestroy(rt_context_); ret = rtCtxDestroy(rt_context_);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_EXCEPTION(DeviceProcessError) << "rtCtxDestroy, ret[" << ret << "]"; MS_EXCEPTION(DeviceProcessError) << "Call rtCtxDestroy, ret[" << ret << "]";
} }
rt_context_ = nullptr; rt_context_ = nullptr;
} }
...@@ -613,30 +613,30 @@ bool AscendKernelRuntime::HcclInit() { ...@@ -613,30 +613,30 @@ bool AscendKernelRuntime::HcclInit() {
if (!context_ptr->IsTsdOpened()) { if (!context_ptr->IsTsdOpened()) {
MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open"; MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open";
} }
MS_LOG(INFO) << "do hcom init"; MS_LOG(INFO) << "Do hcom init";
auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH"); auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH");
if (config_path_str == nullptr) { if (config_path_str == nullptr) {
config_path_str = std::getenv("RANK_TABLE_FILE"); config_path_str = std::getenv("RANK_TABLE_FILE");
if (config_path_str == nullptr) { if (config_path_str == nullptr) {
MS_LOG(ERROR) << "get hccl json config failed, please set env MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE"; MS_LOG(ERROR) << "Get hccl json config failed, please set env MINDSPORE_HCCL_CONFIG_PATH or RANK_TABLE_FILE";
return false; return false;
} }
} }
if (strlen(config_path_str) > PATH_MAX) { if (strlen(config_path_str) > PATH_MAX) {
MS_LOG(ERROR) << "file path oversize"; MS_LOG(ERROR) << "File path oversize";
return false; return false;
} }
std::string rank_id_str = GetRankId(); std::string rank_id_str = GetRankId();
auto full_path = realpath(config_path_str, nullptr); auto full_path = realpath(config_path_str, nullptr);
if (full_path == nullptr) { if (full_path == nullptr) {
MS_LOG(ERROR) << "file path " << config_path_str << " does not exist"; MS_LOG(ERROR) << "File path " << config_path_str << " does not exist";
return false; return false;
} }
MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str; MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str;
hcclResult_t res = hcom_init(full_path, rank_id_str.c_str()); hcclResult_t res = hcom_init(full_path, rank_id_str.c_str());
free(full_path); free(full_path);
if (res != HCCL_SUCCESS) { if (res != HCCL_SUCCESS) {
MS_LOG(ERROR) << "hcom init failed, res is " << static_cast<int>(res); MS_LOG(ERROR) << "Hcom init failed, res is " << static_cast<int>(res);
return false; return false;
} }
return true; return true;
...@@ -646,15 +646,15 @@ bool AscendKernelRuntime::DestroyHccl() { ...@@ -646,15 +646,15 @@ bool AscendKernelRuntime::DestroyHccl() {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
if (!NeedDestroyHccl()) { if (!NeedDestroyHccl()) {
MS_LOG(INFO) << "hccl is not enable, no need to close."; MS_LOG(INFO) << "Hccl is not enable, no need to close.";
return true; return true;
} }
hcclResult_t res = hcom_destroy(); hcclResult_t res = hcom_destroy();
if (res != HCCL_SUCCESS) { if (res != HCCL_SUCCESS) {
MS_LOG(ERROR) << "hccl destroy failed"; MS_LOG(ERROR) << "Hccl destroy failed";
return false; return false;
} }
MS_LOG(INFO) << "hccl destroy successful, status = " << res << "."; MS_LOG(INFO) << "Hccl destroy successful, status = " << res << ".";
context_ptr->set_enable_hccl(false); context_ptr->set_enable_hccl(false);
return true; return true;
} }
......
...@@ -46,7 +46,7 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) ...@@ -46,7 +46,7 @@ void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr)
GetNeedActiveStreams(graph_ptr); GetNeedActiveStreams(graph_ptr);
graph_ptr->PrintGraphExecuteOrder(); graph_ptr->PrintGraphExecuteOrder();
CheckResourceAssign(graph_ptr); CheckResourceAssign(graph_ptr);
MS_LOG(INFO) << "after finish stream assign"; MS_LOG(INFO) << "After finish stream assign";
// Get info for D Model // Get info for D Model
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
...@@ -64,7 +64,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr> ...@@ -64,7 +64,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr>
std::vector<CNodePtr> others; std::vector<CNodePtr> others;
auto cnode_ptr_list = graph_ptr->execution_order(); auto cnode_ptr_list = graph_ptr->execution_order();
MS_LOG(INFO) << "before reorder, graph orders size:" << cnode_ptr_list.size(); MS_LOG(INFO) << "Before reorder, graph orders size:" << cnode_ptr_list.size();
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
auto cur_cnode_ptr = cnode_ptr_list[i]; auto cur_cnode_ptr = cnode_ptr_list[i];
MS_EXCEPTION_IF_NULL(cur_cnode_ptr); MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
...@@ -76,7 +76,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr> ...@@ -76,7 +76,7 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr>
} }
if (others.empty() || independents.empty()) { if (others.empty() || independents.empty()) {
MS_LOG(INFO) << "independent or others is empty, no need reorder"; MS_LOG(INFO) << "Independent or others is empty, no need reorder";
return; return;
} }
...@@ -107,9 +107,9 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr> ...@@ -107,9 +107,9 @@ void AscendStreamAssign::ReorderIndependentOrders(const NotNull<KernelGraphPtr>
} }
} }
MS_LOG(INFO) << "after reorder, graph orders size:" << exe_orders.size(); MS_LOG(INFO) << "After reorder, graph orders size:" << exe_orders.size();
if (processed.size() != independents.size()) { if (processed.size() != independents.size()) {
MS_LOG(WARNING) << "processed independent nodes size is not equal to exiting independent nodes size"; MS_LOG(WARNING) << "Processed independent nodes size is not equal to exiting independent nodes size";
return; return;
} }
...@@ -142,7 +142,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra ...@@ -142,7 +142,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
AssignCommonStreamId(cur_cnode_ptr); AssignCommonStreamId(cur_cnode_ptr);
} }
MS_LOG(INFO) << "common start from 0, common stream nums:" << resource_manager.get_cur_stream_num(); MS_LOG(INFO) << "Common start from 0, common stream nums:" << resource_manager.get_cur_stream_num();
if (exit_hcom) { if (exit_hcom) {
uint32_t first_hcom_stream_id = resource_manager.ApplyNewStream(); uint32_t first_hcom_stream_id = resource_manager.ApplyNewStream();
...@@ -157,7 +157,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra ...@@ -157,7 +157,7 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
AssignHcomStreamId(cur_cnode_ptr); AssignHcomStreamId(cur_cnode_ptr);
} }
} }
MS_LOG(INFO) << "hcom start from :" << first_hcom_stream_id << ", hcom stream nums:" << hcom_stream_map_.size(); MS_LOG(INFO) << "Hcom start from :" << first_hcom_stream_id << ", hcom stream nums:" << hcom_stream_map_.size();
} }
if (exit_independent) { if (exit_independent) {
...@@ -171,10 +171,10 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra ...@@ -171,10 +171,10 @@ void AscendStreamAssign::AssignAllNodesStream(const NotNull<KernelGraphPtr> &gra
AssignIndependentStreamId(cur_cnode_ptr); AssignIndependentStreamId(cur_cnode_ptr);
} }
} }
MS_LOG(INFO) << "independ start from:" << first_independ << ", stream nums:" << independent_stream_map_.size(); MS_LOG(INFO) << "Independ start from:" << first_independ << ", stream nums:" << independent_stream_map_.size();
} }
MS_LOG(INFO) << "after stream assign, total stream nums:" << resource_manager.get_cur_stream_num(); MS_LOG(INFO) << "After stream assign, total stream nums:" << resource_manager.get_cur_stream_num();
} }
void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) { void AscendStreamAssign::AssignCommonStreamId(const CNodePtr &cur_cnode_ptr) {
...@@ -257,7 +257,7 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { ...@@ -257,7 +257,7 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) {
uint32_t input_nums = AnfAlgo::GetInputTensorNum(node_ptr); uint32_t input_nums = AnfAlgo::GetInputTensorNum(node_ptr);
if (input_nums == 0) { if (input_nums == 0) {
MS_LOG(INFO) << "node " << node_ptr->fullname_with_scope() << " is independent, as inputs nums is zero"; MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs nums is zero";
return true; return true;
} }
...@@ -267,13 +267,13 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { ...@@ -267,13 +267,13 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) {
return false; return false;
} }
} }
MS_LOG(INFO) << "node " << node_ptr->fullname_with_scope() << " is independent, as inputs is all value node"; MS_LOG(INFO) << "Node " << node_ptr->fullname_with_scope() << " is independent, as inputs is all value node";
return true; return true;
} }
// section 3: // section 3:
void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr) { void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraphPtr> &graph_ptr) {
MS_LOG(INFO) << "start"; MS_LOG(INFO) << "Start";
auto cnode_ptr_list = graph_ptr->execution_order(); auto cnode_ptr_list = graph_ptr->execution_order();
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
CNodePtr cur_cnode_ptr = cnode_ptr_list[i]; CNodePtr cur_cnode_ptr = cnode_ptr_list[i];
...@@ -283,12 +283,12 @@ void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraph ...@@ -283,12 +283,12 @@ void AscendStreamAssign::UpdateAtomicAddrCleanStreamId(const NotNull<KernelGraph
AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get()); AnfAlgo::SetStreamId(AnfAlgo::GetStreamId(cur_cnode_ptr), cnode_ptr_list[i - 1].get());
} }
} }
MS_LOG(INFO) << "end"; MS_LOG(INFO) << "End";
} }
// section 4 // section 4
void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr) { void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph_ptr) {
MS_LOG(INFO) << "start"; MS_LOG(INFO) << "Start";
GetProcessedStream(graph_ptr); GetProcessedStream(graph_ptr);
std::vector<CNodePtr> update_cnode_list; std::vector<CNodePtr> update_cnode_list;
CNodePtr cur_cnode_ptr = nullptr; CNodePtr cur_cnode_ptr = nullptr;
...@@ -314,7 +314,7 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph ...@@ -314,7 +314,7 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph
bool processed = IsProcessedStream(cur_stream_id); bool processed = IsProcessedStream(cur_stream_id);
// 1)inner stream assign, need insert active op // 1)inner stream assign, need insert active op
if (!processed) { if (!processed) {
MS_LOG(INFO) << "common stream active info:" << pre_stream_id << "->active" << cur_stream_id; MS_LOG(INFO) << "Common stream active info:" << pre_stream_id << "->active" << cur_stream_id;
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
// 1.set stream id // 1.set stream id
AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get()); AnfAlgo::SetStreamId(pre_stream_id, active_ptr.get());
...@@ -336,7 +336,7 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph ...@@ -336,7 +336,7 @@ void AscendStreamAssign::InsertStreamActive(const NotNull<KernelGraphPtr> &graph
pre_cnode_ptr = cur_cnode_ptr; pre_cnode_ptr = cur_cnode_ptr;
} }
graph_ptr->set_execution_order(update_cnode_list); graph_ptr->set_execution_order(update_cnode_list);
MS_LOG(INFO) << "end"; MS_LOG(INFO) << "End";
} }
void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr) { void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph_ptr) {
...@@ -364,7 +364,7 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph ...@@ -364,7 +364,7 @@ void AscendStreamAssign::GetProcessedStream(const NotNull<KernelGraphPtr> &graph
} }
} }
for (const auto &item : processed_streams_) { for (const auto &item : processed_streams_) {
MS_LOG(INFO) << "before active:" << item << " is been processed"; MS_LOG(INFO) << "Before active:" << item << " is been processed";
} }
} }
...@@ -385,7 +385,7 @@ void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph ...@@ -385,7 +385,7 @@ void AscendStreamAssign::UpdateStreamSwitch(const NotNull<KernelGraphPtr> &graph
MS_EXCEPTION_IF_NULL(switch_ptr); MS_EXCEPTION_IF_NULL(switch_ptr);
auto true_stream_id = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream)); auto true_stream_id = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream));
MS_LOG(INFO) << "streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr) MS_LOG(INFO) << "Streamswtich stream id:" << AnfAlgo::GetStreamId(switch_ptr)
<< "; active stream id:" << true_stream_id; << "; active stream id:" << true_stream_id;
CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr);
...@@ -425,11 +425,11 @@ bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) { ...@@ -425,11 +425,11 @@ bool AscendStreamAssign::IsProcessedStream(uint32_t stream_id) {
// section5 // section5
void AscendStreamAssign::InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr) { void AscendStreamAssign::InsertEventForHcomParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
MS_LOG(INFO) << "start"; MS_LOG(INFO) << "Start";
InsertEventCommonDependHcom(graph_ptr); InsertEventCommonDependHcom(graph_ptr);
InsertEventHcomDependCommon(graph_ptr); InsertEventHcomDependCommon(graph_ptr);
InsertEventHcomDependHcom(graph_ptr); InsertEventHcomDependHcom(graph_ptr);
MS_LOG(INFO) << "end"; MS_LOG(INFO) << "End";
} }
void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) { void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
...@@ -447,7 +447,7 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt ...@@ -447,7 +447,7 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); auto target = FindTargetOp(it, cnodes.end(), *(it - 1));
if (target == cnodes.end()) { if (target == cnodes.end()) {
MS_LOG(WARNING) << "hcom node:" << (*(it - 1))->fullname_with_scope() MS_LOG(WARNING) << "Hcom node:" << (*(it - 1))->fullname_with_scope()
<< ", can't find target for insert recv op, no insert send/recv"; << ", can't find target for insert recv op, no insert send/recv";
it = cnodes.erase(it); it = cnodes.erase(it);
continue; continue;
...@@ -469,7 +469,7 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt ...@@ -469,7 +469,7 @@ void AscendStreamAssign::InsertEventCommonDependHcom(const NotNull<KernelGraphPt
// one event allocated additional, should delete // one event allocated additional, should delete
resource_manager.DeleteEvent(); resource_manager.DeleteEvent();
graph_ptr->set_execution_order(cnodes); graph_ptr->set_execution_order(cnodes);
MS_LOG(INFO) << "after common depend hcom, total event nums:" << resource_manager.get_cur_event_num(); MS_LOG(INFO) << "After common depend hcom, total event nums:" << resource_manager.get_cur_event_num();
} }
void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr) { void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPtr> &graph_ptr) {
...@@ -512,7 +512,7 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt ...@@ -512,7 +512,7 @@ void AscendStreamAssign::InsertEventHcomDependCommon(const NotNull<KernelGraphPt
} }
graph_ptr->set_execution_order(cnodes); graph_ptr->set_execution_order(cnodes);
MS_LOG(INFO) << "after hcom depend common, total event nums:" << resource_manager.get_cur_event_num(); MS_LOG(INFO) << "After hcom depend common, total event nums:" << resource_manager.get_cur_event_num();
} }
void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) { void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> &graph_ptr) {
...@@ -547,11 +547,11 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr> ...@@ -547,11 +547,11 @@ void AscendStreamAssign::InsertEventHcomDependHcom(const NotNull<KernelGraphPtr>
} }
if (hcom_index.size() < 2) { if (hcom_index.size() < 2) {
MS_LOG(INFO) << "different stream hcom size is less than 2, no need insert event between them"; MS_LOG(INFO) << "Different stream hcom size is less than 2, no need insert event between them";
return; return;
} }
InsertEventBetweenHcom(graph_ptr, hcom_index, first_hcom_stream, last_hcom_stream); InsertEventBetweenHcom(graph_ptr, hcom_index, first_hcom_stream, last_hcom_stream);
MS_LOG(INFO) << "after hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num(); MS_LOG(INFO) << "After hcom depend hcom, total event nums:" << resource_manager.get_cur_event_num();
} }
void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr, void AscendStreamAssign::InsertEventBetweenHcom(const NotNull<KernelGraphPtr> &graph_ptr,
...@@ -630,7 +630,7 @@ bool AscendStreamAssign::IsSatisfiedHcom(const std::map<uint32_t, vector<size_t> ...@@ -630,7 +630,7 @@ bool AscendStreamAssign::IsSatisfiedHcom(const std::map<uint32_t, vector<size_t>
// section6 // section6
void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr) { void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelGraphPtr> &graph_ptr) {
MS_LOG(INFO) << "start"; MS_LOG(INFO) << "Start";
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
auto cnode_ptr_list = graph_ptr->execution_order(); auto cnode_ptr_list = graph_ptr->execution_order();
vector<CNodePtr> cnodes = cnode_ptr_list; vector<CNodePtr> cnodes = cnode_ptr_list;
...@@ -639,13 +639,13 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG ...@@ -639,13 +639,13 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
while (it != cnodes.end()) { while (it != cnodes.end()) {
MS_EXCEPTION_IF_NULL(*it); MS_EXCEPTION_IF_NULL(*it);
if (IsIndependentNode(*it)) { if (IsIndependentNode(*it)) {
MS_LOG(INFO) << "deal independent op[" << (*it)->DebugString() << "]"; MS_LOG(INFO) << "Deal independent op[" << (*it)->DebugString() << "]";
CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it)); CNodePtr send_cnode_ptr = CreateSendApplyKernel(graph_ptr, cur_event_id, AnfAlgo::GetStreamId(*it));
it = cnodes.insert(it + 1, send_cnode_ptr); it = cnodes.insert(it + 1, send_cnode_ptr);
auto target = FindTargetOp(it, cnodes.end(), *(it - 1)); auto target = FindTargetOp(it, cnodes.end(), *(it - 1));
if (target == cnodes.end()) { if (target == cnodes.end()) {
MS_LOG(DEBUG) << "independ node[" << (*(it - 1))->fullname_with_scope() MS_LOG(DEBUG) << "Independ node[" << (*(it - 1))->fullname_with_scope()
<< "] can't find target for insert recv op, no insert send/recv"; << "] can't find target for insert recv op, no insert send/recv";
it = cnodes.erase(it); it = cnodes.erase(it);
continue; continue;
...@@ -662,8 +662,8 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG ...@@ -662,8 +662,8 @@ void AscendStreamAssign::InsertEventForIndependentParallel(const NotNull<KernelG
// one event allocated additional, should delete // one event allocated additional, should delete
resource_manager.DeleteEvent(); resource_manager.DeleteEvent();
graph_ptr->set_execution_order(cnodes); graph_ptr->set_execution_order(cnodes);
MS_LOG(INFO) << "after independent parallel, total event nums:" << resource_manager.get_cur_event_num(); MS_LOG(INFO) << "After independent parallel, total event nums:" << resource_manager.get_cur_event_num();
MS_LOG(INFO) << "end"; MS_LOG(INFO) << "End";
} }
// section7 // section7
...@@ -687,7 +687,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra ...@@ -687,7 +687,7 @@ void AscendStreamAssign::GetNeedActiveStreams(const NotNull<KernelGraphPtr> &gra
auto need_active = GetValue<bool>(value_ptr); auto need_active = GetValue<bool>(value_ptr);
if (need_active) { if (need_active) {
auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
MS_LOG(INFO) << "stream id:" << stream_id << " is need actived at first"; MS_LOG(INFO) << "Stream id:" << stream_id << " is need actived at first";
need_first_active_streams_.push_back(stream_id); need_first_active_streams_.push_back(stream_id);
} }
} }
...@@ -724,7 +724,7 @@ void AscendStreamAssign::CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ ...@@ -724,7 +724,7 @@ void AscendStreamAssign::CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_
MS_EXCEPTION_IF_NULL(cur_cnode_ptr); MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
uint32_t stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); uint32_t stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr);
if (stream_id == kInvalidStreamId) { if (stream_id == kInvalidStreamId) {
MS_LOG(EXCEPTION) << "node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "had not been assigned stream"; MS_LOG(EXCEPTION) << "Node:" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "had not been assigned stream";
} }
(void)streams.emplace(stream_id); (void)streams.emplace(stream_id);
...@@ -739,11 +739,11 @@ void AscendStreamAssign::CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_ ...@@ -739,11 +739,11 @@ void AscendStreamAssign::CheckStreamAssign(const NotNull<KernelGraphPtr> &graph_
// check stream assign // check stream assign
if (!streams.empty()) { if (!streams.empty()) {
if (min_stream != 0) { if (min_stream != 0) {
MS_LOG(EXCEPTION) << "stream should start from 0, now is from " << min_stream; MS_LOG(EXCEPTION) << "Stream should start from 0, now is from " << min_stream;
} }
uint32_t assigned_stream_num = resource_manager.get_cur_stream_num(); uint32_t assigned_stream_num = resource_manager.get_cur_stream_num();
if ((max_stream != assigned_stream_num - 1) || (streams.size() != assigned_stream_num)) { if ((max_stream != assigned_stream_num - 1) || (streams.size() != assigned_stream_num)) {
MS_LOG(EXCEPTION) << "stream should be consecutive, max stream id:" << max_stream MS_LOG(EXCEPTION) << "Stream should be consecutive, max stream id:" << max_stream
<< "; alloc stream nums:" << assigned_stream_num << "; streams size:" << streams.size(); << "; alloc stream nums:" << assigned_stream_num << "; streams size:" << streams.size();
} }
} }
...@@ -779,20 +779,20 @@ void AscendStreamAssign::CheckEventAssign(const NotNull<KernelGraphPtr> &graph_p ...@@ -779,20 +779,20 @@ void AscendStreamAssign::CheckEventAssign(const NotNull<KernelGraphPtr> &graph_p
// check event assign // check event assign
if (!event_map.empty()) { if (!event_map.empty()) {
if (min_event_id != 0) { if (min_event_id != 0) {
MS_LOG(EXCEPTION) << "event should start from 0, now is from " << min_event_id; MS_LOG(EXCEPTION) << "Event should start from 0, now is from " << min_event_id;
} }
uint32_t assigned_event_num = resource_manager.get_cur_event_num(); uint32_t assigned_event_num = resource_manager.get_cur_event_num();
if ((max_event_id != assigned_event_num - 1) || (event_map.size() != assigned_event_num)) { if ((max_event_id != assigned_event_num - 1) || (event_map.size() != assigned_event_num)) {
MS_LOG(EXCEPTION) << "event should be consecutive"; MS_LOG(EXCEPTION) << "Event should be consecutive";
} }
for (const auto &item : event_map) { for (const auto &item : event_map) {
if (item.second.size() != 2) { if (item.second.size() != 2) {
MS_LOG(EXCEPTION) << "send/recv should be in pair and share one event id"; MS_LOG(EXCEPTION) << "Send/recv should be in pair and share one event id";
} }
auto first_name = AnfAlgo::GetCNodeName(item.second[0]); auto first_name = AnfAlgo::GetCNodeName(item.second[0]);
auto second_name = AnfAlgo::GetCNodeName(item.second[1]); auto second_name = AnfAlgo::GetCNodeName(item.second[1]);
if (!(first_name == kSendOpName && second_name == kRecvOpName)) { if (!(first_name == kSendOpName && second_name == kRecvOpName)) {
MS_LOG(EXCEPTION) << "send should be before recv"; MS_LOG(EXCEPTION) << "Send should be before recv";
} }
} }
} }
...@@ -858,7 +858,7 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it ...@@ -858,7 +858,7 @@ vector<CNodePtr>::iterator AscendStreamAssign::FindTargetOp(vector<CNodePtr>::it
} else { } else {
auto real_input = AnfAlgo::VisitKernel(input, 0); auto real_input = AnfAlgo::VisitKernel(input, 0);
if (node == real_input.first) { if (node == real_input.first) {
MS_LOG(INFO) << "find target op[" << (*begin)->DebugString() << "]"; MS_LOG(INFO) << "Find target op[" << (*begin)->DebugString() << "]";
return begin; return begin;
} }
} }
...@@ -872,10 +872,10 @@ bool AscendStreamAssign::IsTaskSink() { ...@@ -872,10 +872,10 @@ bool AscendStreamAssign::IsTaskSink() {
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context); MS_EXCEPTION_IF_NULL(ms_context);
if (!ms_context->enable_task_sink()) { if (!ms_context->enable_task_sink()) {
MS_LOG(INFO) << "task sink mode is not enable"; MS_LOG(INFO) << "Task sink mode is not enable";
return false; return false;
} else { } else {
MS_LOG(INFO) << "task sink mode is enable"; MS_LOG(INFO) << "Task sink mode is enable";
return true; return true;
} }
} }
...@@ -885,7 +885,7 @@ void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_lis ...@@ -885,7 +885,7 @@ void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_lis
AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); AscendResourceMng &resource_manager = AscendResourceMng::GetInstance();
uint32_t total_stream_num = resource_manager.get_cur_stream_num(); uint32_t total_stream_num = resource_manager.get_cur_stream_num();
if (total_stream_num == 0) { if (total_stream_num == 0) {
MS_LOG(INFO) << "total_common_stream_num is zero"; MS_LOG(INFO) << "The total_common_stream_num is zero";
return; return;
} }
...@@ -893,7 +893,7 @@ void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_lis ...@@ -893,7 +893,7 @@ void AscendStreamAssign::GetWaitStreams(vector<uint32_t> *wait_active_stream_lis
for (uint32_t i = 0; i < total_stream_num; i++) { for (uint32_t i = 0; i < total_stream_num; i++) {
auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i); auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i);
if (it == need_first_active_streams_.end()) { if (it == need_first_active_streams_.end()) {
MS_LOG(INFO) << "wait common stream id = " << i; MS_LOG(INFO) << "Wait common stream id = " << i;
wait_active_stream_list->push_back(i); wait_active_stream_list->push_back(i);
} }
} }
......
...@@ -142,6 +142,37 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t ...@@ -142,6 +142,37 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t
return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id); return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id);
} }
tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, size_t index,
                                                         std::set<DeviceAddressPtr> *bound_addresses,
                                                         std::vector<tensor::TensorPtr> *need_sync_outputs) {
  // Create a host tensor mirroring output `index` of `node`.
  // If the output's device address is already bound to a previously created
  // tensor, the new tensor only records the address and is queued in
  // `need_sync_outputs` so its data gets copied after kernel execution.
  // Otherwise the device address is pointed directly at the tensor's host
  // buffer (zero-copy) and remembered in `bound_addresses`.
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(bound_addresses);
  MS_EXCEPTION_IF_NULL(need_sync_outputs);
  size_t output_size = AnfAlgo::GetOutputTensorNum(node);
  if (index >= output_size) {
    // `index` selects an output of `node`; report it as such (was "input").
    MS_LOG(EXCEPTION) << "Invalid output index " << index << ", node has " << output_size << " outputs";
  }
  auto address = AnfAlgo::GetMutableOutputAddr(node, index);
  MS_EXCEPTION_IF_NULL(address);
  auto shape = AnfAlgo::GetOutputInferShape(node, index);
  // Narrow the inferred shape (size_t dims) to the int dims Tensor expects.
  std::vector<int> temp_shape(shape.begin(), shape.end());
  TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
  type_id = GetCPUSupportOutputTypeId(type_id);
  tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  MS_EXCEPTION_IF_NULL(tensor);
  if (bound_addresses->find(address) != bound_addresses->end()) {
    // Address already backs another tensor: share it, sync data afterwards.
    tensor->set_device_address(address);
    need_sync_outputs->emplace_back(tensor);
  } else {
    // First binding: let the kernel write straight into the tensor's memory.
    address->ptr_ = tensor->data_c();
    address->ref_count_ = INIT_NODE_REF;
    (void)bound_addresses->insert(address);
  }
  tensor->set_dirty(false);
  return tensor;
}
BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index,
const std::unordered_map<AnfNode *, tensor::TensorPtr> &input_map, const std::unordered_map<AnfNode *, tensor::TensorPtr> &input_map,
std::set<DeviceAddressPtr> *bound_addresses, std::set<DeviceAddressPtr> *bound_addresses,
...@@ -161,29 +192,7 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &k ...@@ -161,29 +192,7 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &k
} }
return ret; return ret;
} }
size_t output_size = AnfAlgo::GetOutputTensorNum(node); return CreatTensorForOutput(node, index, bound_addresses, need_sync_outputs);
if (index >= output_size) {
MS_LOG(EXCEPTION) << "Invalid input index " << index;
}
auto address = AnfAlgo::GetMutableOutputAddr(node, index);
MS_EXCEPTION_IF_NULL(address);
auto shape = AnfAlgo::GetOutputInferShape(node, index);
std::vector<int> temp_shape;
(void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
type_id = GetCPUSupportOutputTypeId(type_id);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
MS_EXCEPTION_IF_NULL(tensor);
if (bound_addresses->find(address) != bound_addresses->end()) {
tensor->set_device_address(address);
need_sync_outputs->emplace_back(tensor);
} else {
address->ptr_ = tensor->data_c();
address->ref_count_ = INIT_NODE_REF;
(void)bound_addresses->insert(address);
}
tensor->set_dirty(false);
return tensor;
} else if (input_node->isa<Parameter>() || input_node->isa<ValueNode>()) { } else if (input_node->isa<Parameter>() || input_node->isa<ValueNode>()) {
auto iter = input_map.find(input_node.get()); auto iter = input_map.find(input_node.get());
if (iter != input_map.end()) { if (iter != input_map.end()) {
...@@ -247,6 +256,7 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, ...@@ -247,6 +256,7 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vector<kernel::AddressPtr> *input_list) { void CPUKernelRuntime::AddRuntimeAddress(DeviceAddress *address, std::vector<kernel::AddressPtr> *input_list) {
MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(address);
MS_EXCEPTION_IF_NULL(input_list);
kernel::AddressPtr input = std::make_shared<kernel::Address>(); kernel::AddressPtr input = std::make_shared<kernel::Address>();
MS_EXCEPTION_IF_NULL(input); MS_EXCEPTION_IF_NULL(input);
if (address->ptr_ == nullptr) { if (address->ptr_ == nullptr) {
......
...@@ -49,6 +49,10 @@ class CPUKernelRuntime : public KernelRuntime { ...@@ -49,6 +49,10 @@ class CPUKernelRuntime : public KernelRuntime {
TypeId type_id) override; TypeId type_id) override;
private: private:
tensor::TensorPtr CreatTensorForOutput(const CNodePtr &node, size_t index,
std::set<DeviceAddressPtr> *bound_addresses,
std::vector<tensor::TensorPtr> *need_sync_outputs);
BaseRef CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, BaseRef CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index,
const std::unordered_map<AnfNode *, tensor::TensorPtr> &input_map, const std::unordered_map<AnfNode *, tensor::TensorPtr> &input_map,
std::set<DeviceAddressPtr> *bound_addresses, std::set<DeviceAddressPtr> *bound_addresses,
......
...@@ -56,7 +56,13 @@ void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) { ...@@ -56,7 +56,13 @@ void CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) {
graph_mem_size_[graph] = total_mem_size; graph_mem_size_[graph] = total_mem_size;
} }
size_t CPUSimpleMemPlan::GetGraphMemSize(const session::KernelGraph *graph) { return graph_mem_size_[graph]; } size_t CPUSimpleMemPlan::GetGraphMemSize(const session::KernelGraph *graph) const {
auto iter = graph_mem_size_.find(graph);
if (iter != graph_mem_size_.end()) {
return iter->second;
}
return 0;
}
void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr) { void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
......
...@@ -31,7 +31,7 @@ class CPUSimpleMemPlan { ...@@ -31,7 +31,7 @@ class CPUSimpleMemPlan {
void MemPlan(const session::KernelGraph *graph); void MemPlan(const session::KernelGraph *graph);
void MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr); void MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr);
size_t GetGraphMemSize(const session::KernelGraph *graph); size_t GetGraphMemSize(const session::KernelGraph *graph) const;
private: private:
std::unordered_map<const session::KernelGraph *, size_t> graph_mem_size_; std::unordered_map<const session::KernelGraph *, size_t> graph_mem_size_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册