# ogbl_ppa_dataloader.py (committed by yelrose)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import

from dataloader.base_dataloader import BaseDataGenerator
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

from ogb.linkproppred import LinkPropPredDataset
from ogb.linkproppred import Evaluator
import tqdm
from collections import namedtuple
import pgl
import numpy as np


class PPADataGenerator(BaseDataGenerator):
    """Batch generator over the edge splits of the ``ogbl-ppa`` dataset.

    On construction the dataset is loaded through
    ``ogb.linkproppred.LinkPropPredDataset`` and the edge split selected by
    ``phase`` is flattened into ``self.line_examples``, a list of
    ``Example(src, dst, label)`` records:

    * ``"train"``: every edge in ``split["train"]["edge"]`` with label 1.
    * ``"valid"`` / ``"test"``: the edges in ``split[phase]["edge"]``
      (label 1) followed by those in ``split[phase]["edge_neg"]`` (label 0).

    ``batch_fn`` turns a list of such records into the feed dict of one
    batch and, in the training phase, appends uniformly sampled random
    node pairs labelled 0.
    """

    # Phases for which an edge split is read in ``__init__``.
    _PHASES = ("train", "valid", "test")

    def __init__(self,
                 graph_wrapper=None,
                 buf_size=1000,
                 batch_size=128,
                 num_workers=1,
                 shuffle=True,
                 phase="train"):
        """Load ``ogbl-ppa`` and build the example list for ``phase``.

        Args:
            graph_wrapper: Stored on ``self.graph_wrapper``; not used by
                this class itself.
            buf_size: Forwarded unchanged to ``BaseDataGenerator``.
            batch_size: Forwarded unchanged to ``BaseDataGenerator``.
            num_workers: Forwarded unchanged to ``BaseDataGenerator``.
            shuffle: Forwarded unchanged to ``BaseDataGenerator``.
            phase: Which edge split to serve; one of ``"train"``,
                ``"valid"`` or ``"test"``.

        Raises:
            ValueError: If ``phase`` is not one of the supported values.
        """
        super(PPADataGenerator, self).__init__(
            buf_size=buf_size,
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=shuffle)

        # Fail fast with a clear message, before the dataset is loaded,
        # instead of hitting an unbound ``labels`` further down.
        if phase not in self._PHASES:
            raise ValueError("phase must be one of {}, got {!r}".format(
                self._PHASES, phase))

        self.d_name = "ogbl-ppa"
        self.graph_wrapper = graph_wrapper
        self.phase = phase

        dataset = LinkPropPredDataset(name=self.d_name)
        splitted_edge = dataset.get_edge_split()
        # Upper bound (exclusive) used by ``batch_fn`` when it samples
        # random node ids as negatives.
        self.num_nodes = dataset[0]["num_nodes"]

        if self.phase == "train":
            # Training split only provides positive edges; negatives are
            # sampled per batch in ``batch_fn``.
            edges = splitted_edge["train"]["edge"]
            labels = np.ones(len(edges))
        else:
            # "valid" and "test" share the same layout: positive edges
            # under "edge", negative edges under "edge_neg".
            split = splitted_edge[self.phase]
            pos_edges = split["edge"]
            neg_edges = split["edge_neg"]
            edges = np.vstack([pos_edges, neg_edges])
            labels = (np.ones(len(pos_edges)).tolist() +
                      np.zeros(len(neg_edges)).tolist())

        Example = namedtuple('Example', ['src', "dst", "label"])
        self.line_examples = [
            Example(
                src=edge[0], dst=edge[1], label=label)
            for edge, label in zip(edges, labels)
        ]
        print("Phase", self.phase)
        print("Len Examples", len(self.line_examples))

    def batch_fn(self, batch_ex):
        """Assemble the feed dict for one batch of examples.

        Args:
            batch_ex: Sequence of records exposing ``src``, ``dst`` and
                ``label`` attributes (the ``Example`` tuples built in
                ``__init__``).

        Returns:
            dict with the int64 arrays ``"batch_src"``, ``"batch_dst"`` and
            ``"labels"``. In the ``"train"`` phase each array holds the
            ``len(batch_ex)`` given examples followed by ``len(batch_ex)``
            randomly drawn node pairs with label 0; otherwise it holds the
            given examples only.
        """
        batch_src = [ex.src for ex in batch_ex]
        batch_dst = [ex.dst for ex in batch_ex]
        batch_labels = [ex.label for ex in batch_ex]

        if self.phase == "train":
            # One random (src, dst) pair per given example, labelled 0.
            # The pairs are drawn independently and are not checked
            # against the known edges.
            num_neg = len(batch_ex)
            rand_src = np.random.randint(
                low=0, high=self.num_nodes, size=num_neg)
            rand_dst = np.random.randint(
                low=0, high=self.num_nodes, size=num_neg)
            batch_src = batch_src + rand_src.tolist()
            batch_dst = batch_dst + rand_dst.tolist()
            batch_labels = batch_labels + [0] * num_neg

        return {
            "batch_src": np.array(batch_src, dtype="int64"),
            "batch_dst": np.array(batch_dst, dtype="int64"),
            "labels": np.array(batch_labels, dtype="int64"),
        }