preprocess.py 3.6 KB
Newer Older
W
update  
wangwenjin 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
wangwenjin 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
import ssl
# NOTE(review): this replaces the default HTTPS context factory with the
# unverified one, so certificate verification is switched off for stdlib
# HTTPS clients in the whole process (see PEP 476).  Presumably a workaround
# so the OGB dataset download succeeds on hosts with certificate problems --
# confirm it is still needed, since it weakens every other HTTPS request too.
ssl._create_default_https_context = ssl._create_unverified_context
from ogb.nodeproppred import NodePropPredDataset, Evaluator

import pgl
import numpy as np
import os
import time


def get_graph_data(d_name="ogbn-proteins", mini_data=False):
    """Load an OGB node-property dataset and wrap it as a PGL graph.

    Node features are derived from the edge features: for every node, the
    features of the edges whose source endpoint (column 0 of the edge list)
    is that node are averaged.  The result is cached as a ``.npy`` file under
    ``./dataset`` and re-read on later calls instead of being recomputed.

    Param:
        d_name: name of dataset
        mini_data: if mini_data==True, only use a small dataset (for test):
            the first 500 nodes, the edges between them, and a fixed
            400/50/50 train/valid/test split.

    Return:
        A tuple ``(g, label, train_idx, valid_idx, test_idx, evaluator)``
        where ``g`` is the ``pgl.graph.Graph`` carrying the averaged features
        under the key ``'node_feat'`` and ``evaluator`` is
        ``Evaluator(d_name)``.
    """
    # import ogb data
    dataset = NodePropPredDataset(name=d_name)

    split_idx = dataset.get_idx_split()
    train_idx = split_idx["train"]
    valid_idx = split_idx["valid"]
    test_idx = split_idx["test"]
    graph, label = dataset[0]

    # reshape: after the transpose every row is one (source, target) pair,
    # which is how the edge list is indexed below.
    graph["edge_index"] = graph["edge_index"].T

    # mini dataset
    if mini_data:
        mini_nodes = 500
        graph["num_nodes"] = mini_nodes
        edges = graph["edge_index"]
        # keep only the edges whose both endpoints survive the truncation
        keep = (edges[:, 0] < mini_nodes) & (edges[:, 1] < mini_nodes)
        graph["edge_index"] = edges[keep]
        graph["edge_feat"] = graph["edge_feat"][keep]
        label = label[:mini_nodes]
        train_idx = np.arange(0, 400)
        valid_idx = np.arange(400, 450)
        test_idx = np.arange(450, 500)

    # read/compute node feature
    # The cache name follows the dataset name so that two datasets never
    # share a cache file; for the default "ogbn-proteins" this yields exactly
    # the historical "./dataset/ogbn_proteins_node_feat[_small].npy".
    cache_name = d_name.replace("-", "_") + "_node_feat"
    if mini_data:
        cache_name += "_small"
    node_feat_path = "./dataset/{}.npy".format(cache_name)

    if os.path.exists(node_feat_path):
        print("Begin: read node feature".center(50, '='))
        new_node_feat = np.load(node_feat_path)
        print("End: read node feature".center(50, '='))
    else:
        print("Begin: compute node feature".center(50, '='))
        start = time.perf_counter()
        new_node_feat = _mean_edge_feat_by_source(
            graph["edge_index"], graph["edge_feat"], graph["num_nodes"])
        print("times: {:.2f}s".format(time.perf_counter() - start))
        print("End: compute node feature".center(50, '='))

        print(("Saving node feature in " + node_feat_path).center(50, '='))
        # np.save does not create missing parent directories.
        os.makedirs(os.path.dirname(node_feat_path), exist_ok=True)
        np.save(node_feat_path, new_node_feat)
        print("Saving finish".center(50, '='))

    print(new_node_feat)

    # create graph
    g = pgl.graph.Graph(
        num_nodes=graph["num_nodes"],
        edges=graph["edge_index"],
        node_feat={'node_feat': new_node_feat},
        edge_feat=None,
    )
    print("Create graph")
    print(g)
    return g, label, train_idx, valid_idx, test_idx, Evaluator(d_name)


def _mean_edge_feat_by_source(edge_index, edge_feat, num_nodes):
    """Average the edge features over the outgoing edges of every node.

    Param:
        edge_index: integer array with one edge per row; column 0 is the
            source node id.
        edge_feat: array whose first axis is aligned with ``edge_index``.
        num_nodes: number of rows of the result.

    Return:
        Array of shape ``(num_nodes,) + edge_feat.shape[1:]``.  Row ``i`` is
        the mean of ``edge_feat`` over the edges with source ``i``; a node
        without such edges gets NaN, the same value a mean over an empty
        selection produces.
    """
    num_nodes = int(num_nodes)
    src = np.asarray(edge_index)[:, 0]
    edge_feat = np.asarray(edge_feat)

    # Edges whose source lies outside [0, num_nodes) belong to no row.
    in_range = (src >= 0) & (src < num_nodes)
    if not in_range.all():
        src = src[in_range]
        edge_feat = edge_feat[in_range]

    feat_shape = edge_feat.shape[1:]

    # One pass over the edges: unbuffered scatter-add of every edge feature
    # into the row of its source node, instead of scanning the whole edge
    # list once per node.
    sums = np.zeros((num_nodes,) + feat_shape, dtype=np.float64)
    np.add.at(sums, src, edge_feat)

    counts = np.bincount(src, minlength=num_nodes)
    counts = counts.reshape((num_nodes,) + (1,) * len(feat_shape))

    # 0 / 0 for isolated nodes is intentional (yields NaN), so silence the
    # floating-point warning locally.
    with np.errstate(divide="ignore", invalid="ignore"):
        mean = sums / counts

    # Keep the dtype np.mean would give: floating inputs keep their dtype,
    # everything else is promoted to float64.
    if np.issubdtype(edge_feat.dtype, np.floating):
        out_dtype = edge_feat.dtype
    else:
        out_dtype = np.float64
    return mean.astype(out_dtype, copy=False)