提交 69f4c45b 编写于 作者: C chenfei

get real parameters if graph input is a virtual cnode

上级 1f1a07e6
...@@ -885,11 +885,6 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu ...@@ -885,11 +885,6 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
for (auto &child_graph : graph->child_graph_order()) { for (auto &child_graph : graph->child_graph_order()) {
CreateMultiBranchOutput(NOT_NULL(child_graph), memo); CreateMultiBranchOutput(NOT_NULL(child_graph), memo);
} }
// If graph has no output, the graph is the true graph of while and will call condition graph, no need insert assign
// from condition to true graph
if (graph->get_output_null()) {
return;
}
std::map<AnfNodePtr, AnfNodePtr> need_replace_list; std::map<AnfNodePtr, AnfNodePtr> need_replace_list;
auto node_list = GetCNodes(TopoSort(graph->get_return())); auto node_list = GetCNodes(TopoSort(graph->get_return()));
for (auto &node : node_list) { for (auto &node : node_list) {
...@@ -909,6 +904,11 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu ...@@ -909,6 +904,11 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node); auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node);
for (auto &child_graph : child_graphs) { for (auto &child_graph : child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph); MS_EXCEPTION_IF_NULL(child_graph);
// Skip child graphs that are already recorded in memo. A child graph that has no output (checked below) is the
// true graph of while and will call condition graph, so no assign from condition to true graph needs inserting.
if (memo->find(child_graph) != memo->end()) {
continue;
}
if (child_graph->get_output_null()) { if (child_graph->get_output_null()) {
continue; continue;
} }
...@@ -927,6 +927,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu ...@@ -927,6 +927,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
} }
} }
} }
memo->erase(graph.get());
} }
void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) { void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
......
...@@ -475,7 +475,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K ...@@ -475,7 +475,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
cnode_inputs.emplace_back(new_value_node); cnode_inputs.emplace_back(new_value_node);
} }
continue; continue;
} else if (anf->isa<Parameter>() && AnfAlgo::GetOutputTensorNum(anf) == 1) { } else if (anf->isa<Parameter>()) {
auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph); auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
cnode_inputs.push_back(new_parameter); cnode_inputs.push_back(new_parameter);
if (GetGraphIdByNode(anf) == kInvalidGraphId) { if (GetGraphIdByNode(anf) == kInvalidGraphId) {
...@@ -818,6 +818,25 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &para ...@@ -818,6 +818,25 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &para
} }
} }
namespace {
// Decides whether the host data of `tensor` has to be copied to the device address of `parameter`
// (LoadInputData uses the result to gate its SyncHostToDevice call).
//
// parameter: graph input node whose output address 0 is compared against the tensor's address.
// tensor:    host tensor supplied for that input.
//
// Returns true when a host-to-device copy is required:
//   - pynative-infer mode: the tensor has no device address, or it differs from the parameter's;
//   - otherwise: the tensor is dirty, or its device address differs from the parameter's. In the
//     latter case data_sync() is called on the tensor first.
// Raises (via MS_EXCEPTION_IF_NULL) when `parameter`, `tensor` or the MsContext instance is null.
bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor) {
  // The caller no longer null-checks the input node before calling in, so guard both arguments here.
  MS_EXCEPTION_IF_NULL(parameter);
  MS_EXCEPTION_IF_NULL(tensor);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  // NOTE(review): this is evaluated before the caller's AnfAlgo::OutputAddrExist() check; confirm that
  // GetMutableOutputAddr tolerates a node whose output address has not been assigned yet.
  auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0);
  if (ms_context->enable_pynative_infer()) {
    return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
  }
  if (tensor->is_dirty()) {
    return true;
  }
  if (tensor->device_address() != device_address) {
    // Presumably refreshes the host copy from the tensor's previous device address before it is
    // re-uploaded to the new one; verify against tensor::Tensor::data_sync.
    (void)tensor->data_sync();
    return true;
  }
  return false;
}
}  // namespace
// run graph steps // run graph steps
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph, void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const { const std::vector<tensor::TensorPtr> &inputs_const) const {
...@@ -827,7 +846,11 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap ...@@ -827,7 +846,11 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
if (kernel_graph->input_ctrl_tensors()) { if (kernel_graph->input_ctrl_tensors()) {
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs); input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
} }
auto input_nodes = kernel_graph->inputs(); std::vector<AnfNodePtr> input_nodes;
for (const auto &input_node : kernel_graph->inputs()) {
auto params = AnfAlgo::GetAllOutput(input_node);
std::copy(params.begin(), params.end(), std::back_inserter(input_nodes));
}
if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) { if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) {
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size() MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
<< ", input_ctrl_size:" << input_ctrl_size; << ", input_ctrl_size:" << input_ctrl_size;
...@@ -838,33 +861,17 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap ...@@ -838,33 +861,17 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
auto tensor = inputs[i]; auto tensor = inputs[i];
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
auto input_node = input_nodes[i]; auto input_node = input_nodes[i];
MS_EXCEPTION_IF_NULL(input_node); if (TensorNeedSync(input_node, tensor) && input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) { auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
auto pk_node = input_node->cast<ParameterPtr>(); if (ms_context->execution_mode() == kPynativeMode ||
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
bool need_sync = false; tensor->set_device_address(device_address);
if (ms_context->enable_pynative_infer()) {
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
need_sync = true;
}
} else {
if (tensor->is_dirty()) {
need_sync = true;
} else if (tensor->device_address() != device_address) {
(void)tensor->data_sync();
need_sync = true;
}
} }
if (need_sync) { MS_EXCEPTION_IF_NULL(device_address);
if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) { if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0),
tensor->set_device_address(device_address); LongToSize(tensor->data().nbytes()), tensor->data_type(),
} tensor->data_c())) {
MS_EXCEPTION_IF_NULL(device_address); MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
} }
} }
tensor->set_dirty(false); tensor->set_dirty(false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册