提交 69f4c45b 编写于 作者: C chenfei

get real parameters if graph input is a virtual cnode

上级 1f1a07e6
......@@ -885,11 +885,6 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
for (auto &child_graph : graph->child_graph_order()) {
CreateMultiBranchOutput(NOT_NULL(child_graph), memo);
}
// If graph has no output, the graph is the true graph of while and will call condition graph, no need insert assign
// from condition to true graph
if (graph->get_output_null()) {
return;
}
std::map<AnfNodePtr, AnfNodePtr> need_replace_list;
auto node_list = GetCNodes(TopoSort(graph->get_return()));
for (auto &node : node_list) {
......@@ -909,6 +904,11 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
auto child_graphs = AnfAlgo::GetCallNodeKernelGraph(node);
for (auto &child_graph : child_graphs) {
MS_EXCEPTION_IF_NULL(child_graph);
// If graph has no output, the graph is the true graph of while and will call condition graph, no need insert
// assign from condition to true graph
if (memo->find(child_graph) != memo->end()) {
continue;
}
if (child_graph->get_output_null()) {
continue;
}
......@@ -927,6 +927,7 @@ void AscendSession::CreateMultiBranchOutput(NotNull<KernelGraphPtr> graph, NotNu
}
}
}
memo->erase(graph.get());
}
void AscendSession::IrFusionPass(const NotNull<KernelGraphPtr> graph, NotNull<std::set<KernelGraphPtr> *> memo) {
......
......@@ -475,7 +475,7 @@ CNodePtr SessionBasic::CreateNewCNode(const CNodePtr &cnode, bool valid_input, K
cnode_inputs.emplace_back(new_value_node);
}
continue;
} else if (anf->isa<Parameter>() && AnfAlgo::GetOutputTensorNum(anf) == 1) {
} else if (anf->isa<Parameter>()) {
auto new_parameter = CreateNewParameterFromParameter(anf, valid_input, graph);
cnode_inputs.push_back(new_parameter);
if (GetGraphIdByNode(anf) == kInvalidGraphId) {
......@@ -818,6 +818,25 @@ void SessionBasic::AddParameterToGraphInputs(const std::vector<AnfNodePtr> &para
}
}
namespace {
bool TensorNeedSync(const AnfNodePtr &parameter, const tensor::TensorPtr &tensor) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
auto device_address = AnfAlgo::GetMutableOutputAddr(parameter, 0);
if (ms_context->enable_pynative_infer()) {
return tensor->device_address().get() == nullptr || tensor->device_address() != device_address;
}
if (tensor->is_dirty()) {
return true;
}
if (tensor->device_address() != device_address) {
(void)tensor->data_sync();
return true;
}
return false;
}
} // namespace
// run graph steps
void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const {
......@@ -827,7 +846,11 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
if (kernel_graph->input_ctrl_tensors()) {
input_ctrl_size = LoadCtrlInputTensor(kernel_graph, &inputs);
}
auto input_nodes = kernel_graph->inputs();
std::vector<AnfNodePtr> input_nodes;
for (const auto &input_node : kernel_graph->inputs()) {
auto params = AnfAlgo::GetAllOutput(input_node);
std::copy(params.begin(), params.end(), std::back_inserter(input_nodes));
}
if ((inputs.size() + input_ctrl_size) - 2 != input_nodes.size()) {
MS_LOG(EXCEPTION) << "Tensor input:" << inputs.size() << " is not equal graph inputs:" << input_nodes.size()
<< ", input_ctrl_size:" << input_ctrl_size;
......@@ -838,33 +861,17 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_grap
auto tensor = inputs[i];
MS_EXCEPTION_IF_NULL(tensor);
auto input_node = input_nodes[i];
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
auto pk_node = input_node->cast<ParameterPtr>();
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
bool need_sync = false;
if (ms_context->enable_pynative_infer()) {
if (tensor->device_address().get() == nullptr || tensor->device_address() != device_address) {
need_sync = true;
}
} else {
if (tensor->is_dirty()) {
need_sync = true;
} else if (tensor->device_address() != device_address) {
(void)tensor->data_sync();
need_sync = true;
}
if (TensorNeedSync(input_node, tensor) && input_node->isa<Parameter>() && AnfAlgo::OutputAddrExist(input_node, 0)) {
auto device_address = AnfAlgo::GetMutableOutputAddr(input_node, 0);
if (ms_context->execution_mode() == kPynativeMode ||
AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
tensor->set_device_address(device_address);
}
if (need_sync) {
if (ms_context->execution_mode() == kPynativeMode || AnfAlgo::IsParameterWeight(pk_node)) {
tensor->set_device_address(device_address);
}
MS_EXCEPTION_IF_NULL(device_address);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
MS_EXCEPTION_IF_NULL(device_address);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(input_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
}
tensor->set_dirty(false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册