提交 236d6c6d 编写于 作者: K kswang

fix cpu reshape bug

上级 f10e2974
......@@ -161,8 +161,12 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const AnfNodePtr &input_node, siz
}
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
MS_EXCEPTION_IF_NULL(tensor);
address->ptr_ = tensor->data_c(true);
address->ref_count_ = INIT_NODE_REF;
if (address->ref_count_ > 0 && address->ptr_ != nullptr) {
tensor->set_device_address(address);
} else {
address->ptr_ = tensor->data_c(true);
address->ref_count_ = INIT_NODE_REF;
}
tensor->set_dirty(false);
return tensor;
} else if (input_node->isa<Parameter>() || input_node->isa<ValueNode>()) {
......@@ -211,6 +215,7 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
}
tensor->set_dirty(true);
}
address->ref_count_ = INIT_NODE_REF;
tensor->set_device_address(address);
}
......@@ -220,7 +225,7 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph,
// new output and bind ptr
auto output_nodes = kernel_graph->outputs();
for (const auto &item : output_nodes) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true);
auto out = CreatTensorForOutput(item_with_index.first, item_with_index.second, input_map);
outputs->push_back(std::move(out));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册