# dataloader.py (committed by DesmonDay)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import collections.abc

import numpy as np
import paddle
import pgl
from pgl.graph import Graph, MultiGraph
from pgl.utils.logger import log

def batch_iter(data, batch_size):
    """Yield randomly ordered mini-batches drawn from ``data``.

    An index array ``0 .. len(data) - 1`` is shuffled in place with
    ``np.random.shuffle`` and then walked in consecutive slices of
    ``batch_size`` indices. Each slice (a NumPy integer array) is used to
    index ``data``, so ``data`` must accept that kind of index. The last
    batch may hold fewer than ``batch_size`` indices.

    Args:
        data: Object supporting ``len()`` and indexing by an index array.
        batch_size: Number of indices per batch; must be at least 1.

    Yields:
        The result of ``data[index_array]`` for each slice.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # A non-positive step would never advance through the permutation and
    # the generator would never terminate, so reject it up front.
    if batch_size < 1:
        raise ValueError(
            "batch_size must be a positive integer, got %r" % (batch_size, ))

    size = len(data)
    perm = np.arange(size)
    np.random.shuffle(perm)
    for start in range(0, size, batch_size):
        yield data[perm[start:start + batch_size]]


def scan_batch_iter(data, batch_size):
    """Yield consecutive batches of examples in ``data.scan()`` order.

    Examples are collected into a list; as soon as the list holds
    ``batch_size`` items it is yielded and a new list is started. Any
    examples left over when the scan ends are yielded as a final, shorter
    batch.

    Args:
        data: Object exposing a ``scan()`` method that returns an iterable
            of examples.
        batch_size: Number of examples per full batch.

    Yields:
        Lists of examples, each of length ``batch_size`` except possibly
        the last one.
    """
    batch = []
    for example in data.scan():
        batch.append(example)
        # The size check must run for every example; checking only after
        # the loop would lump the whole scan into a single batch.
        if len(batch) == batch_size:
            yield batch
            batch = []

    # Flush the trailing partial batch, if any.
    if batch:
        yield batch


def label_to_onehot(labels):
    """Encode labels as two-column one-hot rows.

    A label equal to 0 becomes ``[1, 0]``; every other label becomes
    ``[0, 1]``. The rows are returned stacked in a NumPy array, in the
    same order as ``labels``.
    """
    rows = [[1, 0] if label == 0 else [0, 1] for label in labels]
    return np.array(rows)


class GraphDataloader(object):
    """Iterate over a dataset of ``(graph, label)`` examples in batches.

    Iterating the loader produces one feed dictionary per batch: the
    graphs of the batch are merged into a single ``MultiGraph``, converted
    with ``graph_wrapper.to_feed`` and extended with label and graph-id
    arrays (see ``batch_fn``). Batches are prefetched through
    ``paddle.reader.buffered``.

    Args:
        dataset: Indexable, iterable collection with ``len()``; each
            example is indexed as ``example[0]`` (graph) and
            ``example[1]`` (label).
        graph_wrapper: Object whose ``to_feed(graph)`` returns a mutable
            feed dictionary.
        batch_size: Number of examples per batch.
        seed: Stored on the instance; not used by this class itself.
        buf_size: Buffer size handed to ``paddle.reader.buffered``.
        shuffle: If true, batches are drawn in random order; otherwise
            the dataset is scanned sequentially.
    """

    def __init__(self,
                 dataset,
                 graph_wrapper,
                 batch_size,
                 seed=0,
                 buf_size=1000,
                 shuffle=True):

        self.shuffle = shuffle
        self.seed = seed
        self.batch_size = batch_size
        self.dataset = dataset
        self.buf_size = buf_size
        self.graph_wrapper = graph_wrapper

    def batch_fn(self, batch_examples):
        """Turn a list of ``(graph, label)`` examples into a feed dict.

        Returns the dictionary produced by ``graph_wrapper.to_feed`` with
        three extra entries: ``"labels_1dim"`` (labels as an array),
        ``"labels"`` (two-column one-hot labels) and ``"graph_id"`` (for
        every node of the merged graph, the index of the graph it came
        from, as int32).
        """
        graphs = [example[0] for example in batch_examples]
        labels = np.array([example[1] for example in batch_examples])
        join_graph = MultiGraph(graphs)

        # Per-node normalisation factor indegree ** -0.5; nodes without
        # incoming edges keep 0 so no division by zero occurs.
        indegree = join_graph.indegree()
        has_in_edges = indegree > 0
        norm = np.zeros_like(indegree, dtype="float32")
        norm[has_in_edges] = np.power(indegree[has_in_edges], -0.5)
        join_graph.node_feat["norm"] = np.expand_dims(norm, -1)

        feed_dict = self.graph_wrapper.to_feed(join_graph)
        feed_dict["labels_1dim"] = labels
        feed_dict["labels"] = label_to_onehot(labels)

        # Consecutive differences of graph_lod are treated as the node
        # count of each graph; repeat each graph index that many times.
        nodes_per_graph = np.diff(join_graph.graph_lod)
        feed_dict["graph_id"] = np.repeat(
            np.arange(len(nodes_per_graph), dtype="int32"), nodes_per_graph)

        return feed_dict

    def batch_iter(self):
        """Yield lists of raw examples, shuffled or in scan order."""
        # Both module-level helpers take the loader itself as the data
        # source: one uses __len__/__getitem__, the other uses scan().
        make_batches = batch_iter if self.shuffle else scan_batch_iter
        for batch in make_batches(self, self.batch_size):
            yield batch

    def __len__(self):
        """Return the number of examples in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return one example, or a list of examples for an iterable index.

        Any iterable ``idx`` (e.g. a list or a NumPy index array) is
        expanded element by element; anything else is forwarded to the
        dataset unchanged.
        """
        # collections.Iterable was removed in Python 3.10; the ABC lives
        # in collections.abc.
        if isinstance(idx, collections.abc.Iterable):
            return [self.dataset[bidx] for bidx in idx]
        return self.dataset[idx]

    def __iter__(self):
        """Yield feed dictionaries, one per batch, through a buffer."""

        def func_run():
            for batch_examples in self.batch_iter():
                yield self.batch_fn(batch_examples)

        buffered_reader = paddle.reader.buffered(func_run, self.buf_size)

        for batch in buffered_reader():
            yield batch

    def scan(self):
        """Yield the dataset's examples one by one in dataset order."""
        for example in self.dataset:
            yield example