/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef INC_GRAPH_UTILS_GRAPH_UTILS_H_ #define INC_GRAPH_UTILS_GRAPH_UTILS_H_ #include #include #include #include #include #include "graph/anchor.h" #include "graph/node.h" #include "graph/compute_graph.h" #include "graph/utils/anchor_utils.h" #include "graph/graph.h" #include "graph/model.h" #define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ do { \ DataType ret; \ attr.GetValue(ret); \ } while (0) #define PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \ do { \ if (value_type == VT_ENUM) { \ REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ stream << ret; \ } \ } while (0) #define PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \ do { \ if (value_type == VT_ENUM) { \ REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ stream << "["; \ for (int i = 0; i < ret.size(); i++) { \ stream << ret[i]; \ if (i + 1 != ret.size()) stream << ", "; \ } \ stream << "]"; \ } \ } while (0) #define PRINT_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \ else PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) #define PRINT_LIST_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \ else PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) #define PRINT_SHAPE(i_o, n, idx, stream) \ do { \ auto op = n->GetOpDesc(); \ GeTensorDesc td = i_o == "input" ? op->GetInputDesc(idx) : op->GetOutputDesc(idx); \ auto shape = td.GetShape().GetDims(); \ stream << "["; \ for (int i = 0; i < shape.size(); i++) { \ stream << shape[i]; \ if (i + 1 < shape.size()) stream << ", "; \ } \ stream << "]"; \ } while (0) #define PRINT_ATTR_FUNC(stream) \ [&](GeAttrValue attr) { \ auto type = attr.GetValueType(); \ PRINT_ATTR_VALUE_IF(type, GeAttrValue::ValueType::VT_STRING, GeAttrValue::STR, attr, stream) \ PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_FLOAT, GeAttrValue::FLOAT, attr, stream) \ PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_BOOL, GeAttrValue::BOOL, attr, stream) \ PRINT_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_INT, GeAttrValue::INT, attr, stream) \ PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_STRING, GeAttrValue::LIST_STR, attr, stream) \ PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_FLOAT, GeAttrValue::LIST_FLOAT, attr, stream) \ PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_BOOL, GeAttrValue::LIST_BOOL, attr, stream) \ PRINT_LIST_ATTR_VALUE_ELIF(type, GeAttrValue::ValueType::VT_LIST_INT, GeAttrValue::LIST_INT, attr, stream) \ else if (type == GeAttrValue::ValueType::VT_TENSOR_DESC) stream << "TENSOR_DESC"; \ else if (type == GeAttrValue::ValueType::VT_TENSOR) stream << "TENSOR"; \ else if (type == GeAttrValue::ValueType::VT_BYTES) stream << "BYTES"; \ else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR_DESC) stream << "LIST_TENSOR_DESC"; \ else if (type == GeAttrValue::ValueType::VT_LIST_TENSOR) stream << "LIST_TENSOR"; \ else if (type == GeAttrValue::ValueType::VT_LIST_BYTES) stream << "LIST_BYTES"; \ }; namespace ge { class GraphUtils { public: static ComputeGraphPtr GetComputeGraph(const Graph &graph); static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector &inputs); static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); static graphStatus AddEdge(const OutDataAnchorPtr &src, const Format &src_format, const InDataAnchorPtr &dst, const Format &dst_format); static graphStatus AddEdge(const AnchorPtr &src, const AnchorPtr &dst); static graphStatus AddEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); static graphStatus AddEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); // check whether src is link to dst and then remove static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); static graphStatus RemoveEdge(const AnchorPtr &src, const AnchorPtr &dst); static graphStatus RemoveEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst); static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst); static graphStatus ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, const InDataAnchorPtr &new_dst); static graphStatus ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst, const InControlAnchorPtr &new_dst); static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, const NodePtr &new_node); static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node); static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, const std::vector &vec_op_desc); static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); static void RecordOriginalNames(std::vector original_nodes, const ge::NodePtr &node); static void RecordOriginalNames(std::vector names_tmp, const ge::NodePtr &node); static bool CheckIsTrainGraph(const ge::ComputeGraphPtr &compute_graph); static bool MatchDumpStr(const std::string &suffix); static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false); static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); static bool CheckGlobalStepNode(const ge::NodePtr &node); static void BreakConnect(const std::map &all_nodes_infos); static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph); static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message); static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *real_path); static graphStatus AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node); /// /// Isolating `node`, relinking data links from the in-anchor peer nodes to /// the out-anchor peer nodes according to `io_map`, relinking control links /// to ensure that input nodes of `node` are before out nodes /// /// Link the `io_map[i]` input anchor peer node to `i` output anchor peer /// nodes, then unlink all links connecting with `node`. If `io_map[i]` < 0, /// unlink all links from `i` output anchor without any relinking. /// /// @param node /// @param io_map /// @return /// static graphStatus IsolateNode(const NodePtr &node, const std::initializer_list &io_map); static graphStatus IsolateNode(const NodePtr &node, const std::vector &io_map); /// /// Isolate `node` which must be one input one output, equivalent to /// `IsolateNode(node, {0})` /// @param node /// @return /// static graphStatus IsolateNodeOneIO(const NodePtr &node); /// /// The data anchors replacing behavior is the same with /// `ReplaceNodeDataAnchors`. In addition, replace all `old_node` control /// anchors with `new_node`'s. /// @param new_node /// @param old_node /// @param inputs_map /// @param outputs_map /// @return /// static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, std::initializer_list inputs_map, std::initializer_list outputs_map); static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node, const std::vector &inputs_map, const std::vector &outputs_map); /// /// Replace `old_node` data anchors with `new_node`'s according to `inputs_map` and `outputs_map`. /// Replace the `i` in/out data anchor on `old_node` with /// `inputs_map[i]`/`outputs_map[i]` data anchor on `new_node`. /// If `inputs_map[i]`/`outputs_map[i]` < 0 or the index not contained in /// `inputs_map[i]`/`outputs_map[i]`, the `i` data anchor will remain /// on `old_node`. /// @param new_node /// @param old_node /// @param inputs_map /// @param outputs_map /// @return /// static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, std::initializer_list inputs_map, std::initializer_list outputs_map); static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node, const std::vector &inputs_map, const std::vector &outputs_map); /// /// Copy all in-control edges from `src_node` to `dst_node` /// @param src_node /// @param dst_node /// @return /// static graphStatus CopyInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); static graphStatus MoveInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); /// /// Copy all out-control edges from `src_node` to `dst_node` /// @param src_node /// @param dst_node /// @return success: GRAPH_SUCESS /// static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, NodePtr &dst_node); /// /// Move all out-control edges from `src_node` to `dst_node` /// @param src_node /// @param dst_node /// @return success: GRAPH_SUCESS /// static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); }; } // namespace ge #endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_