# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pickle as pkl
import paddle
import paddle.fluid as fluid
import pgl
import time
from pgl.utils import mp_reader
from pgl.utils.logger import log
import train
import time


def node_batch_iter(nodes, node_label, batch_size):
    """Yield mini-batches of (nodes, labels) in a random order.

    Args:
        nodes: 1-D indexable array of node ids.
        node_label: array of labels, aligned with ``nodes``.
        batch_size: maximum number of nodes per batch (the last batch
            may be smaller).

    Yields:
        Tuple of (batch_nodes, batch_labels), fancy-indexed with the
        same shuffled permutation so labels stay aligned with nodes.
    """
    perm = np.arange(len(nodes))
    np.random.shuffle(perm)
    start = 0
    while start < len(nodes):
        index = perm[start:start + batch_size]
        start += batch_size
        yield nodes[index], node_label[index]


def traverse(item):
    """Recursively yield every leaf element of a nested structure.

    Lists and numpy arrays are descended into; any other value is
    treated as a leaf and yielded as-is.

    Args:
        item: a leaf value, or an arbitrarily nested list/ndarray.

    Yields:
        Leaf elements in depth-first, left-to-right order.
    """
    if isinstance(item, (list, np.ndarray)):
        for sub_item in iter(item):
            for leaf in traverse(sub_item):
                yield leaf
    else:
        yield item


def flat_node_and_edge(nodes, eids):
    """Flatten nested node / edge-id collections and deduplicate them.

    Args:
        nodes: nested list/ndarray structure of node ids.
        eids: nested list/ndarray structure of edge ids.

    Returns:
        Tuple of (nodes, eids) as flat, duplicate-free lists.
        Note: ordering is whatever ``set`` iteration produces, i.e.
        not guaranteed to be stable.
    """
    nodes = list(set(traverse(nodes)))
    eids = list(set(traverse(eids)))
    return nodes, eids


def worker(batch_info, graph, graph_wrapper, samples):
    """Build a generator that turns mini-batches into GNN feed dicts.

    Args:
        batch_info: list of (batch_train_samples, batch_train_labels)
            pairs, e.g. produced by ``node_batch_iter``.
        graph: pgl graph used to sample predecessor neighborhoods.
        graph_wrapper: wrapper whose ``to_feed`` converts a subgraph
            into a feed dict.
        samples: list of maximum in-degrees, one entry per sampling hop.

    Returns:
        A zero-argument generator function yielding one feed dict per
        mini-batch (keys from ``graph_wrapper.to_feed`` plus
        "node_label" and "node_index").
    """

    def work():
        """Yield a feed dict for each mini-batch in ``batch_info``."""
        for batch_train_samples, batch_train_labels in batch_info:
            start_nodes = batch_train_samples
            nodes = start_nodes
            eids = []
            for max_deg in samples:
                pred, pred_eid = graph.sample_predecessor(
                    start_nodes, max_degree=max_deg, return_eids=True)
                last_nodes = nodes
                nodes = [nodes, pred]
                eids = [eids, pred_eid]
                nodes, eids = flat_node_and_edge(nodes, eids)
                # Only expand the frontier from nodes not seen before;
                # stop early once no new nodes were discovered.
                start_nodes = list(set(nodes) - set(last_nodes))
                if len(start_nodes) == 0:
                    break

            subgraph = graph.subgraph(nodes=nodes, eid=eids)
            sub_node_index = subgraph.reindex_from_parrent_nodes(
                batch_train_samples)
            feed_dict = graph_wrapper.to_feed(subgraph)
            # Labels are shaped (batch, 1) int64 for the loss layer.
            feed_dict["node_label"] = np.expand_dims(
                np.array(
                    batch_train_labels, dtype="int64"), -1)
            feed_dict["node_index"] = sub_node_index
            yield feed_dict

    return work


def multiprocess_graph_reader(graph,
                              graph_wrapper,
                              samples,
                              node_index,
                              batch_size,
                              node_label,
                              num_workers=4):
    """Parallel feed-dict reader backed by ``num_workers`` processes.

    Pre-computes all mini-batches, splits them into contiguous chunks
    (one chunk per worker), samples subgraphs in the worker processes,
    and merges the resulting feed dicts through a buffered reader.

    Args:
        graph: pgl graph to sample from.
        graph_wrapper: wrapper used to convert subgraphs to feed dicts.
        samples: list of max in-degrees, one per sampling hop.
        node_index: array of training node ids.
        batch_size: nodes per mini-batch.
        node_label: labels aligned with ``node_index``.
        num_workers: number of sampling processes.

    Returns:
        A buffered reader (callable) yielding feed dicts.
    """

    def parse_to_subgraph(rd):
        """Pass-through wrapper kept as a hook for post-processing
        worker output on the consumer side."""

        def work():
            """Yield feed dicts from the multiprocess reader."""
            for feed_dict in rd():
                yield feed_dict

        return work

    def reader():
        """Split batches across workers and merge their outputs."""
        batch_info = list(
            node_batch_iter(
                node_index, node_label, batch_size=batch_size))
        # Ceiling-ish division so every batch lands in some chunk.
        block_size = int(len(batch_info) / num_workers + 1)
        reader_pool = []
        for i in range(num_workers):
            reader_pool.append(
                worker(batch_info[block_size * i:block_size * (i + 1)], graph,
                       graph_wrapper, samples))
        multi_process_sample = mp_reader.multiprocess_reader(
            reader_pool, use_pipe=True, queue_size=1000)
        r = parse_to_subgraph(multi_process_sample)
        return paddle.reader.buffered(r, 1000)

    return reader()


def graph_reader(graph, graph_wrapper, samples, node_index, batch_size,
                 node_label):
    """Single-process feed-dict reader (same output as the worker path).

    Args:
        graph: pgl graph to sample predecessor neighborhoods from.
        graph_wrapper: wrapper whose ``to_feed`` converts a subgraph
            into a feed dict.
        samples: list of maximum in-degrees, one entry per sampling hop.
        node_index: array of training node ids.
        batch_size: nodes per mini-batch.
        node_label: labels aligned with ``node_index``.

    Returns:
        A buffered reader (callable) yielding feed dicts with the
        ``graph_wrapper`` keys plus "node_label" and "node_index".
    """

    def reader():
        """Sample a multi-hop subgraph per batch and yield feed dicts."""
        for batch_train_samples, batch_train_labels in node_batch_iter(
                node_index, node_label, batch_size=batch_size):
            start_nodes = batch_train_samples
            nodes = start_nodes
            eids = []
            for max_deg in samples:
                pred, pred_eid = graph.sample_predecessor(
                    start_nodes, max_degree=max_deg, return_eids=True)
                last_nodes = nodes
                nodes = [nodes, pred]
                eids = [eids, pred_eid]
                nodes, eids = flat_node_and_edge(nodes, eids)
                # Only expand the frontier from newly discovered nodes;
                # stop early once the neighborhood is exhausted.
                start_nodes = list(set(nodes) - set(last_nodes))
                if len(start_nodes) == 0:
                    break

            subgraph = graph.subgraph(nodes=nodes, eid=eids)
            feed_dict = graph_wrapper.to_feed(subgraph)
            sub_node_index = subgraph.reindex_from_parrent_nodes(
                batch_train_samples)

            # Labels are shaped (batch, 1) int64 for the loss layer.
            feed_dict["node_label"] = np.expand_dims(
                np.array(
                    batch_train_labels, dtype="int64"), -1)
            feed_dict["node_index"] = np.array(sub_node_index, dtype="int32")
            yield feed_dict

    return paddle.reader.buffered(reader, 1000)