diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc index 2d44bf8f8ff05cf9f2542449f428aca9d1873b28..83a44029a798a4e891eb7f238ac198de0d5e257c 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc @@ -101,9 +101,9 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP auto origin_type = AnfAlgo::GetOutputDeviceDataType(origin_pair.first, origin_pair.second); auto cur_format = AnfAlgo::GetOutputFormat(cnode, output_index); auto cur_type = AnfAlgo::GetOutputDeviceDataType(cnode, output_index); - auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, 0); + auto cur_shape = AnfAlgo::GetOutputInferShape(cnode, output_index); // insert trans - if (origin_format != cur_format) { + if (origin_format != cur_format && cur_shape.size() > 1) { auto kernel_select = std::make_shared(); final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format, kTransDataOpName, false);