# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data generator producing sampled sub-graph batches for ogbn-arxiv."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from dataloader.base_dataloader import BaseDataGenerator
from utils.to_undirected import to_undirected
import ssl

# SECURITY NOTE(review): this globally disables HTTPS certificate
# verification for every default-context HTTPS connection made by this
# process. It is kept for backward compatibility; confirm it is really
# needed before shipping, and prefer installing proper CA certificates.
ssl._create_default_https_context = ssl._create_unverified_context

from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
#from pgl.sample import graph_saint_random_walk_sample
from ogb.nodeproppred import Evaluator
import tqdm
from collections import namedtuple
import pgl
import numpy as np
import copy

"""
    dict_keys(['edge_index', 'edge_feat', 'node_feat', 'node_year', 'num_nodes'])
    edge_index shape: (2, 1166243)
    edge_index type: [[104447  15858 107156 ...  45118  45118  45118]
     [ 13091  47283  69161 ... 162473 162537  72717]]
    edge_feat: None
    node_feat shape: (169343, 128)
    node_year shape: (169343, 1)
    num_nodes: 169343
    label shape: (169343, 1)
"""


def traverse(item):
    """Yield every leaf element of an arbitrarily nested list / ndarray.

    Args:
        item: A scalar, or a (possibly nested) ``list`` / ``np.ndarray``.

    Yields:
        Each element that is neither a ``list`` nor an ``np.ndarray``,
        in depth-first order.
    """
    if isinstance(item, (list, np.ndarray)):
        for child in item:
            for leaf in traverse(child):
                yield leaf
    else:
        yield item


def flat_node_and_edge(nodes):
    """Flatten a nested node container into a list of unique elements.

    Args:
        nodes: A (possibly nested) ``list`` / ``np.ndarray`` of node ids.

    Returns:
        list: The distinct leaf elements; ordering is that of a ``set``
        and therefore not guaranteed.
    """
    return list(set(traverse(nodes)))


def k_hop_sampler(graph, samples, batch_nodes):
    """Sample a multi-hop neighbourhood sub-graph around ``batch_nodes``.

    For every entry of ``samples`` one hop is expanded: predecessors of
    the current frontier are sampled and the corresponding
    ``(src, dst)`` edges are recorded. Expansion stops early when a hop
    discovers no new node.

    Args:
        graph: Graph object providing ``sample_predecessor`` and
            ``subgraph`` (presumably a ``pgl`` graph -- confirm against
            the caller).
        samples: Iterable of per-hop ``max_degree`` values.
        batch_nodes: Seed node ids of the batch.

    Returns:
        tuple: ``(subgraph, sub_node_index)`` where ``sub_node_index`` is
        the result of ``subgraph.reindex_from_parrent_nodes(batch_nodes)``.
    """
    # Work on a copy so the caller's container is never aliased.
    start_nodes = copy.deepcopy(batch_nodes)
    nodes = start_nodes
    edges = []
    for max_deg in samples:
        pred_nodes = graph.sample_predecessor(start_nodes, max_degree=max_deg)

        # pred_nodes is consumed in lockstep with start_nodes: one
        # collection of sampled sources per destination node.
        for dst_node, src_nodes in zip(start_nodes, pred_nodes):
            for src_node in src_nodes:
                edges.append((src_node, dst_node))

        last_nodes = nodes
        nodes = flat_node_and_edge([nodes, pred_nodes])

        # Only the nodes discovered in this hop form the next frontier.
        start_nodes = list(set(nodes) - set(last_nodes))
        if not start_nodes:
            break

    subgraph = graph.subgraph(
        nodes=nodes, edges=edges, with_node_feat=True, with_edge_feat=True)
    # NOTE: "parrent" is the spelling used by the graph API being called.
    sub_node_index = subgraph.reindex_from_parrent_nodes(batch_nodes)

    return subgraph, sub_node_index


#def graph_saint_randomwalk_sampler(graph, batch_nodes, max_depth=3):
#    subgraph = graph_saint_random_walk_sample(graph, batch_nodes, max_depth)
#    sub_node_index = subgraph.reindex_from_parrent_nodes(batch_nodes)
#    return subgraph, sub_node_index


class ArxivDataGenerator(BaseDataGenerator):
    """Batch generator over one split of the ``ogbn-arxiv`` dataset.

    The constructor loads the dataset, converts the graph with
    ``to_undirected`` and stores one ``Example(node, label)`` per node of
    the selected split in ``self.line_examples``.
    """

    def __init__(self,
                 graph_wrapper=None,
                 buf_size=1000,
                 batch_size=128,
                 num_workers=1,
                 samples=None,
                 shuffle=True,
                 phase="train"):
        """Load the dataset and prepare the examples of ``phase``.

        Args:
            graph_wrapper: Object whose ``to_feed(subgraph)`` builds the
                feed dict in ``batch_fn``.
            buf_size: Forwarded to ``BaseDataGenerator``.
            batch_size: Forwarded to ``BaseDataGenerator``.
            num_workers: Forwarded to ``BaseDataGenerator``.
            samples: Per-hop ``max_degree`` values for ``k_hop_sampler``.
                ``None`` (the default) means ``[30, 30]``.
            shuffle: Forwarded to ``BaseDataGenerator``.
            phase: One of ``"train"``, ``"valid"`` or ``"test"``.

        Raises:
            ValueError: If ``phase`` is not a supported split name.
        """
        # Fail fast, before the (potentially slow) dataset load, instead
        # of hitting an UnboundLocalError further down.
        if phase not in ("train", "valid", "test"):
            raise ValueError(
                "phase must be one of 'train', 'valid' or 'test', got %r" %
                (phase, ))

        super(ArxivDataGenerator, self).__init__(
            buf_size=buf_size,
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=shuffle)

        # A fresh list per instance: a list default in the signature
        # would be shared (and mutable) across all instances.
        self.samples = [30, 30] if samples is None else samples
        self.d_name = "ogbn-arxiv"
        self.graph_wrapper = graph_wrapper
        self.phase = phase

        dataset = PglNodePropPredDataset(name=self.d_name)
        splitted_idx = dataset.get_idx_split()
        graph, label = dataset[0]
        graph = to_undirected(graph)
        self.graph = graph
        self.num_nodes = graph.num_nodes

        # The split dictionary is keyed by the phase name itself.
        nodes_idx = splitted_idx[self.phase]
        labels = label[nodes_idx]
        self.nodes_idx = nodes_idx
        self.labels = labels

        #self.static_gw_based_line_example(nodes_idx, labels)
        self.sample_based_line_example(nodes_idx, labels)

    def sample_based_line_example(self, nodes_idx, labels):
        """Store one ``Example(node, label)`` per node in ``line_examples``.

        Args:
            nodes_idx: Iterable of node ids of the current split.
            labels: Iterable of labels aligned with ``nodes_idx``.
        """
        Example = namedtuple('Example', ["node", "label"])
        self.line_examples = [
            Example(node=node, label=label)
            for node, label in zip(nodes_idx, labels)
        ]
        print("Phase", self.phase)
        print("Len Examples", len(self.line_examples))

    def batch_fn2(self, batch_ex):
        """Build a feed dict from the first example of ``batch_ex`` only.

        Args:
            batch_ex: Sequence of examples; each is either a mapping with
                ``"node"`` / ``"label"`` keys or an object exposing
                ``node`` / ``label`` attributes (such as ``Example``).

        Returns:
            dict: ``{"batch_nodes": ..., "labels": ...}`` as int64 arrays.
        """
        first = batch_ex[0]
        # Namedtuples cannot be indexed by field name, so fall back to
        # attribute access for anything that is not a mapping.
        if isinstance(first, dict):
            node, label = first["node"], first["label"]
        else:
            node, label = first.node, first.label

        feed_dict = {}
        feed_dict["batch_nodes"] = np.array(node, dtype="int64")
        feed_dict["labels"] = np.array(label, dtype="int64")
        return feed_dict

    def batch_fn(self, batch_ex):
        """Sample a sub-graph for a batch of examples and build its feed.

        Args:
            batch_ex: Iterable of ``Example(node, label)`` items.

        Returns:
            dict: ``graph_wrapper.to_feed(subgraph)`` extended with
            ``"batch_nodes"`` (the re-indexed seed nodes) and
            ``"labels"`` (int64 array of the batch labels).
        """
        batch_nodes = [ex.node for ex in batch_ex]
        batch_labels = [ex.label for ex in batch_ex]

        # Shallow copy so this call does not rebind attributes on the
        # shared wrapper instance.
        _graph_wrapper = copy.copy(self.graph_wrapper)
        #if self.phase == "train":
        #    subgraph, sub_node_index = graph_saint_randomwalk_sampler(self.graph, batch_nodes)
        #else:
        subgraph, sub_node_index = k_hop_sampler(self.graph, self.samples,
                                                 batch_nodes)

        feed_dict = _graph_wrapper.to_feed(subgraph)
        feed_dict["batch_nodes"] = sub_node_index
        feed_dict["labels"] = np.array(batch_labels, dtype="int64")
        return feed_dict