# graph_reader.py
"""Graph Dataset
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import

import os
import pgl
import sys

import numpy as np

from pgl.utils.logger import log
from dataset.base_dataset import BaseDataGenerator
from pgl.sample import alias_sample
from pgl.sample import pinsage_sample
from pgl.sample import graphsage_sample 
from pgl.sample import edge_hash


class GraphGenerator(BaseDataGenerator):
    """Batch generator that turns (src, dst[, neg]) examples into model feeds.

    Each batch is expanded into sampled subgraphs with ``graphsage_sample``
    and converted into a feed dict through the supplied graph wrappers.

    Args:
        graph_wrappers: One wrapper per layer; each must provide
            ``to_feed(subgraph)`` returning a dict of feed entries.
        data: Examples to iterate over, stored as ``self.line_examples``.
        batch_size: Number of examples per batch (forwarded to the base class).
        samples: Sampling configuration forwarded to ``graphsage_sample``.
        num_workers: Worker count forwarded to the base class.
        feed_name_list: Feed names, in the order used when ``use_pyreader``
            is true.
        use_pyreader: If true, ``generator`` yields lists ordered by
            ``feed_name_list`` instead of feed dicts.
        phase: Run phase; incomplete batches are skipped when it is "train".
        graph_data_path: Directory holding the memory-mapped graph together
            with ``alias.npy``, ``events.npy`` and ``term_ids.npy``.
        shuffle: Forwarded to the base class.
        buf_size: Forwarded to the base class.
        neg_type: "batch_neg" reuses the batch destinations as negatives;
            any other value draws negatives with ``alias_sample``.
    """

    def __init__(self, graph_wrappers, data, batch_size, samples,
                 num_workers, feed_name_list, use_pyreader,
                 phase, graph_data_path, shuffle=True, buf_size=1000,
                 neg_type="batch_neg"):
        super(GraphGenerator, self).__init__(
            buf_size=buf_size,
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=shuffle)
        # For iteration
        self.line_examples = data

        self.graph_wrappers = graph_wrappers
        self.samples = samples
        self.feed_name_list = feed_name_list
        self.use_pyreader = use_pyreader
        self.phase = phase
        self.load_graph(graph_data_path)
        self.num_layers = len(graph_wrappers)
        self.neg_type = neg_type

    def load_graph(self, graph_data_path):
        """Open the graph and its side arrays from ``graph_data_path``.

        The three ``.npy`` files are opened with ``mmap_mode="r"``, i.e. as
        read-only memory maps rather than being loaded eagerly.
        """
        self.graph = pgl.graph.MemmapGraph(graph_data_path)
        self.alias = np.load(
            os.path.join(graph_data_path, "alias.npy"), mmap_mode="r")
        self.events = np.load(
            os.path.join(graph_data_path, "events.npy"), mmap_mode="r")
        self.term_ids = np.load(
            os.path.join(graph_data_path, "term_ids.npy"), mmap_mode="r")

    def batch_fn(self, batch_ex):
        """Build the feed dict for one batch of examples.

        Args:
            batch_ex: Sequence of ``(src, dst)`` or ``(src, dst, neg)``
                examples, where ``neg`` is an array-like of explicit
                negative node ids.

        Returns:
            A feed dict, or ``None`` when ``phase`` is "train" and the batch
            does not contain exactly ``self.batch_size`` examples.
        """
        batch_src = []
        batch_dst = []
        batch_neg = []
        for example in batch_ex:
            batch_src.append(example[0])
            batch_dst.append(example[1])
            if len(example) == 3:  # example carries its own negatives
                batch_neg.append(example[2])

        # Incomplete batches are only dropped during training.
        if len(batch_src) != self.batch_size and self.phase == "train":
            return None  # Skip

        batch_src = np.array(batch_src, dtype="int64")
        batch_dst = np.array(batch_dst, dtype="int64")

        if self.neg_type == "batch_neg":
            # In-batch negatives: explicit negatives are not used here.
            batch_neg = batch_dst
        else:
            if batch_neg:
                batch_neg = np.unique(np.concatenate(batch_neg))
            else:
                # An empty Python list would be treated as float64 by
                # np.concatenate and promote the sampled ids to floats.
                batch_neg = np.empty((0,), dtype="int64")
            # TODO user define shape of neg_sample
            neg_shape = batch_dst.shape
            sampled_batch_neg = alias_sample(neg_shape, self.alias,
                                             self.events)
            batch_neg = np.concatenate([batch_neg, sampled_batch_neg], 0)

        # TODO user define ignore edges or not. No edges are ignored in any
        # phase for now; a training-time candidate is the set of
        # (src, dst) and (dst, src) pairs of the batch.
        ignore_edges = set()

        nodes = np.unique(np.concatenate([batch_src, batch_dst, batch_neg], 0))
        subgraphs = graphsage_sample(
            self.graph, nodes, self.samples, ignore_edges=ignore_edges)

        # Attach parent-graph ids and their term ids to the first subgraph.
        first = subgraphs[0]
        first.node_feat["index"] = first.reindex_to_parrent_nodes(
            first.nodes).astype(np.int64)
        first.node_feat["term_ids"] = self.term_ids[
            first.node_feat["index"]].astype(np.int64)

        feed_dict = {}
        for i in range(self.num_layers):
            feed_dict.update(self.graph_wrappers[i].to_feed(subgraphs[i]))

        # only reindex from first subgraph
        sub_src_idx = first.reindex_from_parrent_nodes(batch_src)
        sub_dst_idx = first.reindex_from_parrent_nodes(batch_dst)
        sub_neg_idx = first.reindex_from_parrent_nodes(batch_neg)

        feed_dict["user_index"] = np.array(sub_src_idx, dtype="int64")
        feed_dict["pos_item_index"] = np.array(sub_dst_idx, dtype="int64")
        feed_dict["neg_item_index"] = np.array(sub_neg_idx, dtype="int64")

        feed_dict["user_real_index"] = np.array(batch_src, dtype="int64")
        feed_dict["pos_item_real_index"] = np.array(batch_dst, dtype="int64")
        return feed_dict

    def __call__(self):
        """Return a fresh iterator over the batches (see ``generator``)."""
        return self.generator()

    def generator(self):
        """Yield batches from the base generator in the configured format.

        Yields a list ordered by ``feed_name_list`` when ``use_pyreader`` is
        true, otherwise the feed dict itself. Any exception raised while
        iterating is logged and ends the iteration instead of propagating.
        """
        try:
            for feed_dict in super(GraphGenerator, self).generator():
                if self.use_pyreader:
                    yield [feed_dict[name] for name in self.feed_name_list]
                else:
                    yield feed_dict

        except Exception as e:
            log.exception(e)
 

    
# Node classification reader built on top of GraphGenerator.
class NodeClassificationGenerator(GraphGenerator):
    """GraphGenerator variant whose examples are ``(node, label)`` pairs."""

    def batch_fn(self, batch_ex):
        """Build the feed dict for one batch of ``(node, label)`` examples.

        Returns ``None`` when ``phase`` is "train" and the batch does not
        contain exactly ``self.batch_size`` examples.
        """
        examples = list(batch_ex)
        node_ids = [example[0] for example in examples]
        labels = [example[1] for example in examples]

        # Incomplete batches are only dropped during training.
        if len(node_ids) != self.batch_size and self.phase == "train":
            return None  # Skip

        node_ids = np.array(node_ids, dtype="int64")
        labels = np.array(labels, dtype="int64")

        subgraphs = graphsage_sample(self.graph, node_ids, self.samples)

        # Attach parent-graph ids and their term ids to the first subgraph.
        first = subgraphs[0]
        parent_ids = first.reindex_to_parrent_nodes(first.nodes)
        first.node_feat["index"] = parent_ids.astype(np.int64)
        first.node_feat["term_ids"] = self.term_ids[
            first.node_feat["index"]].astype(np.int64)

        feed_dict = {}
        for layer in range(self.num_layers):
            layer_feed = self.graph_wrappers[layer].to_feed(subgraphs[layer])
            feed_dict.update(layer_feed)

        # only the first subgraph is used to map ids back into the subgraph
        local_ids = first.reindex_from_parrent_nodes(node_ids)

        feed_dict["node_index"] = np.array(local_ids, dtype="int64")
        feed_dict["node_real_index"] = np.array(node_ids, dtype="int64")
        feed_dict["label"] = np.array(labels, dtype="int64")
        return feed_dict