# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class Node:
    """A graph node: an id plus a dict of arbitrary attributes.

    Attributes are read and written with subscript syntax
    (``node["color"]``) and tested with ``in``.
    """

    def __init__(self, id, **attrs):
        # Each node must have a unique id
        self._id = id
        # Attributes for Node (copied so the caller's kwargs are not shared)
        self._attrs = {}
        self._attrs.update(attrs)

    @property
    def id(self):
        """The id this node was created with."""
        return self._id

    @property
    def attrs(self):
        """The live attribute dict of this node (mutations are visible)."""
        return self._attrs

    def __getitem__(self, attr_name):
        """Return the attribute value; raises KeyError if it is not set."""
        return self._attrs[attr_name]

    def __setitem__(self, attr_name, attr_value):
        """Set (or overwrite) one attribute."""
        self._attrs[attr_name] = attr_value

    def __contains__(self, attr_name):
        """Return True if the attribute is set.

        Unhashable names cannot be dict keys, so they are reported as
        absent instead of propagating the TypeError.
        """
        try:
            return attr_name in self._attrs
        except TypeError:
            return False

    def __str__(self):
        return "(id: {}, attrs: {})".format(self.id, self.attrs)


class Edge:
    """A directed edge from ``src_id`` to ``tgt_id`` with attributes.

    The edge stores node ids only, not Node objects. Attributes are read
    and written with subscript syntax and tested with ``in``.
    """

    def __init__(self, src_id, tgt_id, **attrs):
        # The id of source node in an Edge
        self._src_id = src_id
        # The id of target node in an Edge
        self._tgt_id = tgt_id
        # Attributes for Edge (copied so the caller's kwargs are not shared)
        self._attrs = {}
        self._attrs.update(attrs)

    @property
    def src_id(self):
        """The id of the source node."""
        return self._src_id

    @property
    def tgt_id(self):
        """The id of the target node."""
        return self._tgt_id

    @property
    def attrs(self):
        """The live attribute dict of this edge (mutations are visible)."""
        return self._attrs

    def __getitem__(self, attr_name):
        """Return the attribute value; raises KeyError if it is not set."""
        return self._attrs[attr_name]

    def __setitem__(self, attr_name, attr_value):
        """Set (or overwrite) one attribute."""
        self._attrs[attr_name] = attr_value

    def __contains__(self, attr_name):
        """Return True if the attribute is set.

        Unhashable names cannot be dict keys, so they are reported as
        absent instead of propagating the TypeError.
        """
        try:
            return attr_name in self._attrs
        except TypeError:
            return False

    def __str__(self):
        return "(src_id: {}, tgt_id: {}, attrs: {})".format(
            self.src_id, self.tgt_id, self._attrs)


class Graph:
    """A directed graph stored as a node table plus nested adjacency dicts.

    ``len(graph)`` is the number of nodes, iterating a graph yields its
    Node objects, ``node_id in graph`` tests node membership and
    ``graph[node_id]`` returns the outgoing edges of that node as a
    ``{tgt_id: Edge}`` dict.
    """

    def __init__(self, **attrs):
        # _nodes is a dict for storing the nodes of the graph.
        # The key of this dict is the node id.
        self._nodes = {}
        # _adjs is a dict of dict for storing the adjacency of the graph.
        # The key of the outer dict is the node id of the source node and
        # the key of the inner dict is the node id of the target node.
        self._adjs = {}
        # Attributes for Graph
        self._attrs = {}
        self._attrs.update(attrs)

    @property
    def nodes(self):
        """The live ``{node_id: Node}`` dict."""
        return self._nodes

    @property
    def attrs(self):
        """The live attribute dict of the graph itself."""
        return self._attrs

    @property
    def adjs(self):
        """The live ``{src_id: {tgt_id: Edge}}`` adjacency dict."""
        return self._adjs

    def add_node(self, node_id, **attrs):
        """Add a node, or merge ``attrs`` into it if it already exists.

        Raises:
            ValueError: if ``node_id`` is None.
        """
        if node_id is None:
            raise ValueError("None cannot be a node")
        if node_id not in self._nodes:
            self._nodes[node_id] = Node(node_id, **attrs)
            # Every node gets an adjacency entry, even with no out-edges.
            self._adjs[node_id] = {}
        else:
            self._nodes[node_id].attrs.update(attrs)

    def add_edge(self, src_id, tgt_id, **attrs):
        """Add a directed edge, creating either endpoint if it is missing.

        An existing edge between the same pair of nodes is replaced by a
        new Edge carrying only ``attrs``.

        Raises:
            ValueError: if ``src_id`` or ``tgt_id`` is None.
        """
        # Validate both endpoints before touching the graph so that a
        # failing call never leaves a half-added node behind.
        if src_id is None:
            raise ValueError("None cannot be a node")
        if tgt_id is None:
            raise ValueError("None cannot be a node")
        # add nodes; add_node without attrs leaves an existing node as is
        self.add_node(src_id)
        self.add_node(tgt_id)
        # add the edge
        self._adjs[src_id][tgt_id] = Edge(src_id, tgt_id, **attrs)

    def __len__(self):
        return len(self._nodes)

    def __iter__(self):
        return iter(self._nodes.values())

    def __getitem__(self, node_id):
        # Return the adjacency of a node
        return self._adjs[node_id]

    def __contains__(self, node_id):
        # Check whether a node is in the graph; an unhashable id cannot
        # be a key, so it is reported as absent.
        try:
            return node_id in self._nodes
        except TypeError:
            return False

    def __str__(self):
        lines = ["**************Nodes**************"]
        for node in self.nodes.values():
            lines.append("{}".format(node))
        lines.append("**************Edges**************")
        for src_id, targets in self.adjs.items():
            lines.append("--------------{}--------------".format(src_id))
            for edge in targets.values():
                lines.append("{}".format(edge))
        # Every line, including the last one, ends with a newline.
        return "\n".join(lines) + "\n"