提交 c594e3d4 编写于 作者: C chujinjin

fix load input data error when input is a tuple

上级 52a7db81
......@@ -133,7 +133,11 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const {
std::vector<tensor::TensorPtr> inputs(inputs_const);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto input_nodes = kernel_graph->inputs();
std::vector<AnfNodePtr> input_nodes;
for (const auto &input_node : kernel_graph->inputs()) {
auto params = AnfAlgo::GetAllOutput(input_node);
std::copy(params.begin(), params.end(), std::back_inserter(input_nodes));
}
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (inputs.size() != input_nodes.size()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部