train.py 5.8 KB
Newer Older
Y
Yelrose 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pgl
import model# import LabelGraphGCN
from pgl import data_loader
from pgl.utils.logger import log
import paddle.fluid as fluid
import numpy as np
import time
import argparse
from build_model import build_model
import yaml
from easydict import EasyDict as edict
Y
Yelrose 已提交
25
import tqdm
Y
Yelrose 已提交
26

Y
Yelrose 已提交
27 28
def normalize(feat):
    """Scale each row of ``feat`` by its sum along the last axis.

    The divisor is clamped to a minimum of 1, so rows whose sum is below 1
    (including all-zero rows) are returned unscaled instead of being divided
    by a tiny or zero value.

    Args:
        feat: Array-like of numeric features; the last axis is summed.

    Returns:
        An array with the same shape as ``feat``.
    """
    row_totals = np.sum(feat, axis=-1, keepdims=True)
    divisor = np.maximum(row_totals, 1)
    return feat / divisor
Y
Yelrose 已提交
29 30


Y
Yelrose 已提交
31
def load(name, normalized_feature=True):
    """Load a citation dataset and attach the node features used for training.

    Args:
        name: Dataset identifier; one of "cora", "pubmed" or "citeseer".
        normalized_feature: When true (the default), the "words" node feature
            is passed through ``normalize`` before the dataset is returned.
            When false the feature is left exactly as loaded.

    Returns:
        The loaded dataset. Its ``graph.node_feat["norm"]`` entry is set to
        ``max(indegree, 1) ** -0.5`` with a trailing axis of size 1 added.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    if name == 'cora':
        dataset = data_loader.CoraDataset()
    elif name == "pubmed":
        dataset = data_loader.CitationDataset("pubmed", symmetry_edges=True)
    elif name == "citeseer":
        dataset = data_loader.CitationDataset("citeseer", symmetry_edges=True)
    else:
        raise ValueError(name + " dataset doesn't exists")

    # Clamp the in-degree to at least 1 so that isolated nodes do not produce
    # an infinite value when raised to the power -0.5.
    indegree = dataset.graph.indegree()
    norm = np.maximum(indegree.astype("float32"), 1)
    norm = np.power(norm, -0.5)
    dataset.graph.node_feat["norm"] = np.expand_dims(norm, -1)

    # Previously the flag was accepted but ignored and the feature was always
    # normalised; honour the caller's choice.
    if normalized_feature:
        dataset.graph.node_feat["words"] = normalize(
            dataset.graph.node_feat["words"])
    return dataset

Y
Yelrose 已提交
48

Y
Yelrose 已提交
49 50
def main(args, config):
    """Train and evaluate the configured model over several runs.

    For every run the startup program is executed again (presumably to
    re-initialise the parameters so that runs are independent -- confirm
    against the Paddle version in use), the model is trained for
    ``args.epoch`` epochs, and after each epoch it is scored on the
    validation and test splits. The test accuracy reported for a run is the
    one from the epoch with the lowest validation loss. The mean and standard
    deviation of that value over all runs are logged and printed at the end.

    Args:
        args: Parsed command-line arguments. Reads ``dataset``,
            ``feature_pre_normalize``, ``use_cuda``, ``runs`` and ``epoch``.
        config: Model configuration forwarded to ``build_model``; its
            ``model_name`` attribute is used in the per-run log line.
    """
    dataset = load(args.dataset, args.feature_pre_normalize)
    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    train_program = fluid.default_main_program()
    startup_program = fluid.default_startup_program()
    with fluid.program_guard(train_program, startup_program):
        with fluid.unique_name.guard():
            gw, loss, acc = build_model(dataset,
                                        config=config,
                                        phase="train",
                                        main_prog=train_program)

    # The test graph is built in its own program but under the same startup
    # program and a fresh unique-name scope as the training graph.
    test_program = fluid.Program()
    with fluid.program_guard(test_program, startup_program):
        with fluid.unique_name.guard():
            _gw, v_loss, v_acc = build_model(dataset,
                                             config=config,
                                             phase="test",
                                             main_prog=test_program)

    test_program = test_program.clone(for_test=True)

    exe = fluid.Executor(place)

    def _index_and_label(index):
        """Return int64 (index, label) arrays, each with a trailing axis of 1."""
        label = np.expand_dims(dataset.y[index], -1)
        index = np.expand_dims(index, -1)
        return (np.array(index, dtype="int64"),
                np.array(label, dtype="int64"))

    # These arrays do not change between epochs, so build them once here
    # instead of re-creating them inside the training loop.
    train_index, train_label = _index_and_label(dataset.train_index)
    val_index, val_label = _index_and_label(dataset.val_index)
    test_index, test_label = _index_and_label(dataset.test_index)

    # Feed data
    feed_dict = gw.to_feed(dataset.graph)

    best_test = []

    for run in range(args.runs):
        exe.run(startup_program)
        cal_test_acc = []
        cal_val_loss = []
        for epoch in tqdm.tqdm(range(args.epoch)):
            # One optimisation step on the training split.
            feed_dict["node_index"] = train_index
            feed_dict["node_label"] = train_label
            exe.run(train_program,
                    feed=feed_dict,
                    fetch_list=[loss, acc],
                    return_numpy=True)

            # Validation loss drives the choice of the reported epoch.
            feed_dict["node_index"] = val_index
            feed_dict["node_label"] = val_label
            val_loss, _val_acc = exe.run(test_program,
                                         feed=feed_dict,
                                         fetch_list=[v_loss, v_acc],
                                         return_numpy=True)
            cal_val_loss.append(val_loss[0])

            feed_dict["node_index"] = test_index
            feed_dict["node_label"] = test_label
            _test_loss, test_acc = exe.run(test_program,
                                           feed=feed_dict,
                                           fetch_list=[v_loss, v_acc],
                                           return_numpy=True)
            cal_test_acc.append(test_acc[0])

        # Report the test accuracy from the epoch with the lowest validation
        # loss.
        best_epoch = np.argmin(cal_val_loss)
        run_best_acc = cal_test_acc[best_epoch]
        log.info("Runs %s: Model: %s Best Test Accuracy: %f" % (
            run, config.model_name, run_best_acc))

        best_test.append(run_best_acc)

    summary = "Dataset: %s Best Test Accuracy: %f ( stddev: %f )" % (
        args.dataset, np.mean(best_test), np.std(best_test))
    log.info(summary)
    print(summary)
    
Y
Yelrose 已提交
137
    
Y
Yelrose 已提交
138 139 140 141 142 143 144 145


def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value -- including "False" -- becomes True. This parser
    interprets the usual textual spellings instead.

    Args:
        value: A bool (returned unchanged) or a string such as "true",
            "False", "1", "no". The empty string maps to False, matching
            what ``bool("")`` returned before.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got %r" % (value,))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Benchmarking Citation Network')
    parser.add_argument(
        "--dataset",
        type=str,
        default="cora",
        help="dataset (cora, pubmed, citeseer)")
    parser.add_argument("--use_cuda", action='store_true', help="use_cuda")
    # Without a config file there is nothing to build, so fail at argument
    # parsing rather than later with a TypeError from open(None).
    parser.add_argument(
        "--conf", type=str, required=True, help="config file for models")
    parser.add_argument("--epoch", type=int, default=200, help="Epoch")
    parser.add_argument("--runs", type=int, default=10, help="runs")
    parser.add_argument(
        "--feature_pre_normalize",
        type=str2bool,
        default=True,
        help="pre_normalize feature")
    args = parser.parse_args()
    # NOTE(review): FullLoader is not meant for untrusted input; the config
    # path comes from the command line, so it is kept as-is here.
    with open(args.conf) as conf_file:
        config = edict(yaml.load(conf_file, Loader=yaml.FullLoader))
    log.info(args)
    main(args, config)