/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/permute_pass.h"
#include <queue>
#include <unordered_set>
#include <vector>
#include "common/debug/log.h"
#include "common/types.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "inc/kernel.h"
#include "inc/kernel_factory.h"
#include "framework/omg/omg_inner_types.h"

using domi::DOMI_TENSOR_ND;
using domi::DOMI_TENSOR_NHWC;
using domi::GetContext;
using domi::SUCCESS;

namespace ge {
namespace {
/// Number of entries a permute order must have to be considered by this pass.
constexpr size_t kPermuteOrderSize = 4;

/// Returns true when `order` holds exactly the four values {d0, d1, d2, d3}.
bool IsOrder(const std::vector<int64_t> &order, int64_t d0, int64_t d1, int64_t d2, int64_t d3) {
  return order.size() == kPermuteOrderSize && order[0] == d0 && order[1] == d1 && order[2] == d2 && order[3] == d3;
}

/// Returns the first out-data node of `node`, or nullptr when it has none.
NodePtr FirstOutDataNode(const NodePtr &node) {
  if (node->GetOutDataNodesSize() == 0) {
    return nullptr;
  }
  return node->GetOutDataNodes().at(0);
}

/// Walks breadth-first from `permute` towards its data producers and copies the first
/// ATTR_NAME_FORMAT value found onto `permute` as the "permute_src_format" attribute.
///
/// Producers that carry no ATTR_NAME_FORMAT are explored further. A producer whose format was
/// found but could not be stored is logged and not explored further. If nothing is found the
/// attribute is simply left unset.
void RecordPermuteSrcFormat(const NodePtr &permute) {
  std::queue<NodePtr> pending;
  // Tracks every node ever enqueued so that reconvergent producers are visited once and a
  // data-edge cycle cannot keep the search running forever.
  std::unordered_set<NodePtr> enqueued;
  pending.push(permute);
  enqueued.insert(permute);

  while (!pending.empty()) {
    // front() + pop() examine and remove the same element; reading back() here would examine
    // the newest node while pop() discards the oldest one.
    const NodePtr current = pending.front();
    pending.pop();

    for (const auto &in_node : current->GetInDataNodes()) {
      int64_t cur_format = 0;
      if (!AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_FORMAT, cur_format)) {
        if (enqueued.insert(in_node).second) {
          pending.push(in_node);
        }
        continue;
      }
      if (AttrUtils::SetInt(permute->GetOpDesc(), "permute_src_format", cur_format)) {
        return;
      }
      GELOGW("set permute_src_format failed");
    }
  }
}

/// Returns true when the first out-data node of `permute` is one of:
///   - CONVOLUTION whose ATTR_NAME_FORMAT attribute equals DOMI_TENSOR_NHWC,
///   - DEPCONVOLUTION or DECONVOLUTION,
///   - PAD whose own first out-data node is CONVOLUTION.
/// Only the first out-data node is inspected; a missing node or op desc yields false.
bool ConsumerAllowsPermuteRemoval(const NodePtr &permute) {
  const NodePtr consumer = FirstOutDataNode(permute);
  if (consumer == nullptr || consumer->GetOpDesc() == nullptr) {
    return false;
  }
  OpDescPtr consumer_desc = consumer->GetOpDesc();
  const auto consumer_type = consumer_desc->GetType();

  int64_t conv_format = 0;
  if (CONVOLUTION == consumer_type && AttrUtils::GetInt(consumer_desc, ATTR_NAME_FORMAT, conv_format) &&
      conv_format == DOMI_TENSOR_NHWC) {
    return true;
  }
  if (consumer_type == DEPCONVOLUTION || consumer_type == DECONVOLUTION) {
    return true;
  }
  if (consumer_type != PAD) {
    return false;
  }

  const NodePtr after_pad = FirstOutDataNode(consumer);
  return after_pad != nullptr && after_pad->GetOpDesc() != nullptr &&
         after_pad->GetOpDesc()->GetType() == CONVOLUTION;
}
}  // namespace

/// Removes PERMUTE nodes that this pass considers redundant.
///
/// The pass only acts when GetContext().type is domi::FMK_TYPE_T and GetContext().format is not
/// DOMI_TENSOR_ND. A PERMUTE node is selected for removal when either
///   - its recorded "permute_src_format" is DOMI_TENSOR_NHWC and its order is [0,3,1,2], or
///   - its order is [0,2,3,1] and ConsumerAllowsPermuteRemoval() holds for it.
/// PERMUTE nodes for which no source format could be recorded are skipped.
///
/// Before a selected node is removed, every one of its out-data nodes is tagged with
/// ATTR_NAME_PRED_PERMUTE_DELETED = true so the Builder can tell the predecessor was deleted.
///
/// @param graph graph to process; checked with GE_CHECK_NOTNULL.
/// @return SUCCESS when all selected nodes were removed; INTERNAL_ERROR when the order attribute
///         cannot be read or the deletion tag cannot be set; otherwise whatever status the
///         GE_CHECK_NOTNULL / GE_RETURN_WITH_LOG_IF_ERROR macros return on failure.
Status PermutePass::Run(ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);
  std::vector<NodePtr> isolate_nodes;

  for (NodePtr &node : graph->GetAllNodes()) {
    OpDescPtr op_desc_ptr = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc_ptr);
    GE_IF_BOOL_EXEC(op_desc_ptr->GetType() != PERMUTE || GetContext().type != domi::FMK_TYPE_T, continue);
    GE_IF_BOOL_EXEC(GetContext().format == DOMI_TENSOR_ND, continue);

    // Get input origin format. Only this node's attribute is read below, so the search is run
    // for this node alone instead of for every PERMUTE in the graph on each iteration.
    RecordPermuteSrcFormat(node);
    int64_t permute_src_format = 0;
    GE_IF_BOOL_EXEC(!AttrUtils::GetInt(op_desc_ptr, "permute_src_format", permute_src_format), continue);

    // Get dim_index_
    std::vector<int64_t> index_list;
    GE_CHK_BOOL_RET_STATUS(AttrUtils::GetListInt(op_desc_ptr, PERMUTE_ATTR_ORDER, index_list), INTERNAL_ERROR,
                           "get index list failed");

    // Input format 5D means NHWC in 4D way. If the recorded source format is NHWC and the
    // permute parameter list is [0,3,1,2], this permute can be optimised.
    // NOTE(review): the previous comment named NCHW as the source format, but the condition
    // has always tested DOMI_TENSOR_NHWC - confirm which one is intended.
    GE_IF_BOOL_EXEC(permute_src_format == DOMI_TENSOR_NHWC && IsOrder(index_list, 0, 3, 1, 2),
                    isolate_nodes.push_back(node);
                    continue);

    GE_IF_BOOL_EXEC(IsOrder(index_list, 0, 2, 3, 1) && ConsumerAllowsPermuteRemoval(node),
                    isolate_nodes.push_back(node));
  }

  for (auto &node : isolate_nodes) {
    // Adding an attribute indicates that the predecessor Permute has been deleted for the Builder to process.
    for (auto &outNode : node->GetOutDataNodes()) {
      OpDescPtr op_desc_ptr = outNode->GetOpDesc();
      GE_CHECK_NOTNULL(op_desc_ptr);
      if (!AttrUtils::SetBool(op_desc_ptr, ATTR_NAME_PRED_PERMUTE_DELETED, true)) {
        GELOGE(INTERNAL_ERROR, "set ATTR_NAME_PRED_PERMUTE_DELETED failed");
        return INTERNAL_ERROR;
      }
    }
    GE_RETURN_WITH_LOG_IF_ERROR(graph->RemoveNode(node), "[%s]:remove permute node failed",
                                node->GetOpDesc()->GetName().c_str());
  }
  return SUCCESS;
}
}  // namespace ge