From 69f4c45bb86c00d0e9eb0ffa513fbf3948d181ee Mon Sep 17 00:00:00 2001 From: chenfei Date: Mon, 27 Jul 2020 15:53:42 +0800 Subject: [PATCH] get real parameters if graph input is a virtual cnode --- .../ccsrc/backend/session/ascend_session.cc | 11 ++-- .../ccsrc/backend/session/session_basic.cc | 63 ++++++++++--------- 2 files changed, 41 insertions(+), 33 deletions(-) diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index e066b62b1..cf2f85ed9 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -885,11 +885,6 @@ void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNu for (auto &child_graph : graph->child_graph_order()) { CreateMultiBranchOutput(NOT_NULL(child_graph), memo); } - // If graph has no output, the graph is the true graph of while and will call condition graph, no need insert assign - // from condition to true graph - if (graph->get_output_null()) { - return; - } std::map need_replace_list; auto node_list = GetCNodes(TopoSort(graph->get_return())); for (auto &node : node_list) { @@ -909,6 +904,11 @@ void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNu auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node); for (auto &child_graph : child_graphs) { MS_EXCEPTION_IF_NULL(child_graph); + // If graph has no output, the graph is the true graph of while and will call condition graph, no need insert + // assign from condition to true graph + if (memo->find(child_graph) != memo->end()) { + continue; + } if (child_graph->get_output_null()) { continue; } @@ -927,6 +927,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNu } } } + memo->erase(graph.get()); } void AscendSession::IrFusionPass(const NotNull graph, NotNull *> memo) { diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 593944969..2d2da4fde 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -475,7 +475,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K cnode_inputs.emplace_back(new_value_node); } continue; - } else if (anf->isa() && AnfAlgo::GetOutputTensorNum(anf) == 1) { + } else if (anf->isa()) { auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); cnode_inputs.push_back(new_parameter); if (GetGraphIdByNode(anf) == kInvalidGraphId) { @@ -818,6 +818,25 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector ¶ } } +namespace { +bool TensorNeedSync(const AnfNodePtr ¶meter, const tensor::TensorPtr &tensor) { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0); + if (ms_context->enable_pynative_infer()) { + return tensor->device_address().get() == nullptr || tensor->device_address() != device_address; + } + if (tensor->is_dirty()) { + return true; + } + if (tensor->device_address() != device_address) { + (void)tensor->data_sync(); + return true; + } + return false; +} +} // namespace + // run graph steps void SessionBasic::LoadInputData(const std::shared_ptr &kernel_graph, const std::vector &inputs_const) const { @@ -827,7 +846,11 @@ void SessionBasic::LoadInputData(const std::shared_ptr &kernel_grap if (kernel_graph->input_ctrl_tensors()) { input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); } - auto input_nodes = kernel_graph->inputs(); + std::vector input_nodes; + for (const auto &input_node : kernel_graph->inputs()) { + auto params = AnfAlgo::GetAllOutput(input_node); + std::copy(params.begin(), params.end(), std::back_inserter(input_nodes)); + } if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) { MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() << ", input_ctrl_size:" << input_ctrl_size; @@ -838,33 +861,17 @@ void SessionBasic::LoadInputData(const std::shared_ptr &kernel_grap auto tensor = inputs[i]; MS_EXCEPTION_IF_NULL(tensor); auto input_node = input_nodes[i]; - MS_EXCEPTION_IF_NULL(input_node); - if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { - auto pk_node = input_node->cast(); - auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); - bool need_sync = false; - if (ms_context->enable_pynative_infer()) { - if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) { - need_sync = true; - } - } else { - if (tensor->is_dirty()) { - need_sync = true; - } else if (tensor->device_address() != device_address) { - (void)tensor->data_sync(); - need_sync = true; - } + if (TensorNeedSync(input_node, tensor) && input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { + auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0); + if (ms_context->execution_mode() == kPynativeMode || + AnfAlgo::IsParameterWeight(input_node->cast())) { + tensor->set_device_address(device_address); } - if (need_sync) { - if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) { - tensor->set_device_address(device_address); - } - MS_EXCEPTION_IF_NULL(device_address); - if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), - LongToSize(tensor->data().nbytes()), tensor->data_type(), - tensor->data_c())) { - MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; - } + MS_EXCEPTION_IF_NULL(device_address); + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), + tensor->data_c())) { + MS_LOG(EXCEPTION) << "SyncHostToDevice failed."; } } tensor->set_dirty(false); -- GitLab