/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef INC_GRAPH_NODE_H_ #define INC_GRAPH_NODE_H_ #include #include #include #include #include #include #include "graph/ge_attr_value.h" #include "graph/op_desc.h" #include "graph/range_vistor.h" #include "utils/attr_utils.h" namespace ge { class ComputeGraph; using ComputeGraphPtr = std::shared_ptr; class Node; using NodePtr = std::shared_ptr; using ConstNodePtr = std::shared_ptr; using NodeRef = std::weak_ptr; class Anchor; using AnchorPtr = std::shared_ptr; class InDataAnchor; using InDataAnchorPtr = std::shared_ptr; class OutDataAnchor; using OutDataAnchorPtr = std::shared_ptr; class ControlAnchor; using ControlAnchorPtr = std::shared_ptr; class InControlAnchor; using InControlAnchorPtr = std::shared_ptr; class OutControlAnchor; using OutControlAnchorPtr = std::shared_ptr; using OpDescPtr = std::shared_ptr; using ConstNode = const Node; typedef std::vector> kFusionDataFlowVec_t; // Node is a component of ComputeGraph class Node : public std::enable_shared_from_this { friend class ComputeGraph; friend class ModelSerializeImp; public: template using Vistor = RangeVistor>; ~Node(); Node(const Node &) = delete; Node &operator=(const Node &) = delete; bool operator==(const Node &r_node) const; protected: Node() = default; Node(const OpDescPtr &op, const ComputeGraphPtr &ownerGraph); public: graphStatus Init(); std::string GetName() const; std::string GetType() const; ComputeGraphPtr GetOwnerComputeGraph() const; graphStatus SetOwnerComputeGraph(const ComputeGraphPtr &graph); Vistor GetAllInDataAnchors() const; Vistor GetAllOutDataAnchors() const; uint32_t GetAllInDataAnchorsSize() const; uint32_t GetAllOutDataAnchorsSize() const; Vistor GetAllOutAnchors() const; Vistor GetAllInAnchors() const; InDataAnchorPtr GetInDataAnchor(int idx) const; OutDataAnchorPtr GetOutDataAnchor(int idx) const; InControlAnchorPtr GetInControlAnchor() const; OutControlAnchorPtr GetOutControlAnchor() const; Vistor GetInNodes() const; Vistor GetOutNodes() const; AnchorPtr GetInAnchor(int idx) const; AnchorPtr GetOutAnchor(int idx) const; bool IsAllInNodesSeen(std::unordered_set &nodes_seen) const; // All inData nodes Vistor GetInDataNodes() const; // All inControl nodes Vistor GetInControlNodes() const; // GetInAllNodes = InDataNodes + InControlNodes Vistor GetInAllNodes() const; // All outData nodes Vistor GetOutDataNodes() const; uint32_t GetOutDataNodesSize() const; // All outControl nodes Vistor GetOutControlNodes() const; // GetOutAllNodes = OutDataNodes + InControlNodes Vistor GetOutAllNodes() const; // Get all indata nodes and its outanchor Vistor> GetInDataNodesAndAnchors() const; // Get all outdata nodes and its inanchor Vistor> GetOutDataNodesAndAnchors() const; graphStatus InferShapeAndType() const; graphStatus Verify() const; graphStatus InferOriginFormat() const; OpDescPtr GetOpDesc() const; graphStatus UpdateOpDesc(const OpDescPtr &op); graphStatus AddLinkFrom(const NodePtr &input_node); graphStatus AddLinkFrom(const uint32_t &index, NodePtr input_node); graphStatus AddLinkFrom(const string &name, NodePtr input_node); graphStatus AddLinkFromForParse(const NodePtr &input_node); void AddSendEventId(uint32_t event_id) { send_event_id_list_.push_back(event_id); } void AddRecvEventId(uint32_t event_id) { recv_event_id_list_.push_back(event_id); } const std::vector &GetSendEventIdList() const { return send_event_id_list_; } const std::vector &GetRecvEventIdList() const { return recv_event_id_list_; } void GetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) { fusion_input_list = fusion_input_dataflow_list_; } void GetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) { fusion_output_list = fusion_output_dataflow_list_; } void SetFusionInputFlowList(kFusionDataFlowVec_t &fusion_input_list) { fusion_input_dataflow_list_ = fusion_input_list; } void SetFusionOutputFlowList(kFusionDataFlowVec_t &fusion_output_list) { fusion_output_dataflow_list_ = fusion_output_list; } void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } NodePtr GetOrigNode(void) { return orig_node_; } private: bool NodeMembersAreEqual(const Node &r_node) const; bool NodeAttrsAreEqual(const Node &r_node) const; bool NodeInConnectsAreEqual(const Node &r_node) const; bool NodeOutConnectsAreEqual(const Node &r_node) const; bool NodeAnchorIsEqual(const AnchorPtr &l_anchor, const AnchorPtr &r_anchor, size_t i) const; OpDescPtr op_; std::weak_ptr owner_graph_; vector in_data_anchors_; vector out_data_anchors_; InControlAnchorPtr in_control_anchor_; OutControlAnchorPtr out_control_anchor_; map attrs_; bool has_init_{false}; bool anchor_status_updated_{false}; std::vector send_event_id_list_; std::vector recv_event_id_list_; kFusionDataFlowVec_t fusion_input_dataflow_list_; kFusionDataFlowVec_t fusion_output_dataflow_list_; NodePtr orig_node_; friend class NodeUtils; friend class OnnxUtils; }; } // namespace ge #endif // INC_GRAPH_NODE_H_