Commit d96c9759 authored by Yelrose

add models

Parent 90ff1f7f
# build_model.py
# Easy Paper Reproduction for Citation Network (Cora/Pubmed/Citeseer)
import pgl
import model
from pgl import data_loader
import paddle.fluid as fluid
import numpy as np
import time
from optimization import AdamW


def build_model(dataset, config, phase, main_prog):
    gw = pgl.graph_wrapper.GraphWrapper(
        name="graph",
        node_feat=dataset.graph.node_feat_info())

    # Instantiate the model class named in the config (GCN, GAT or APPNP)
    # and run a full-graph forward pass over the word features.
    GraphModel = getattr(model, config.model_name)
    m = GraphModel(config=config, num_class=dataset.num_classes)
    logits = m.forward(gw, gw.node_feat["words"])

    # The loss is computed only on the nodes of the current split, so the
    # split indices and labels are fed in and the matching logits gathered.
    node_index = fluid.layers.data(
        "node_index",
        shape=[None, 1],
        dtype="int64",
        append_batch_size=False)
    node_label = fluid.layers.data(
        "node_label",
        shape=[None, 1],
        dtype="int64",
        append_batch_size=False)

    pred = fluid.layers.gather(logits, node_index)
    loss, pred = fluid.layers.softmax_with_cross_entropy(
        logits=pred, label=node_label, return_softmax=True)
    acc = fluid.layers.accuracy(input=pred, label=node_label, k=1)
    loss = fluid.layers.mean(loss)

    if phase == "train":
        # Plain Adam with an L2 regularizer, kept for reference:
        # adam = fluid.optimizer.Adam(
        #     learning_rate=config.learning_rate,
        #     regularization=fluid.regularizer.L2DecayRegularizer(
        #         regularization_coeff=config.weight_decay))
        # adam.minimize(loss)
        AdamW(loss=loss,
              learning_rate=config.learning_rate,
              weight_decay=config.weight_decay,
              train_program=main_prog)
    return gw, loss, acc
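The gather-then-loss pattern above is what makes the semi-supervised setup work: logits are computed for every node in the graph, but the cross-entropy is taken only over the nodes of the current split. A minimal numpy sketch of the same idea (illustrative only, not part of the repo):

import numpy as np

logits = np.random.rand(5, 3).astype("float32")  # logits for all 5 nodes, 3 classes
train_index = np.array([0, 2])                   # only these nodes are labeled
train_label = np.array([1, 0])

pred = logits[train_index]                        # plays the role of fluid.layers.gather
shifted = pred - pred.max(axis=1, keepdims=True)  # numerically stable softmax
prob = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
loss = -np.log(prob[np.arange(len(train_index)), train_label]).mean()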
# APPNP config (YAML)
model_name: APPNP
k_hop: 10
alpha: 0.1
num_layers: 1
learning_rate: 0.01
dropout: 0.5
hidden_size: 64
weight_decay: 0.0005
# GAT config (YAML)
model_name: GAT
learning_rate: 0.005
weight_decay: 0.0005
num_layers: 1
feat_drop: 0.6
attn_drop: 0.6
num_heads: 8
hidden_size: 8
# GCN config (YAML)
model_name: GCN
num_layers: 1
dropout: 0.5
hidden_size: 64
learning_rate: 0.01
weight_decay: 0.0005
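Each of the three blocks above is a standalone YAML config, one per model; the training script selects one through its --conf flag and wraps it in an EasyDict, so keys are reachable both as attributes (config.model_name) and through dict-style .get() with defaults. A small sketch of that mechanism (the inline YAML string here is just for illustration):

import yaml
from easydict import EasyDict as edict

config = edict(yaml.load("""
model_name: GCN
num_layers: 1
hidden_size: 64
""", Loader=yaml.FullLoader))

assert config.model_name == "GCN"          # attribute access
assert config.get("dropout", 0.5) == 0.5   # missing keys fall back to model defaults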
# model.py
import pgl
import paddle.fluid.layers as L
import pgl.layers.conv as conv


class GCN(object):
    """Implementation of GCN."""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)
        self.hidden_size = config.get("hidden_size", 64)
        self.dropout = config.get("dropout", 0.5)

    def forward(self, graph_wrapper, feature):
        # Hidden GCN layers with ReLU and dropout, then a linear GCN
        # output layer that maps to class logits.
        for i in range(self.num_layers):
            feature = conv.gcn(graph_wrapper,
                               feature,
                               self.hidden_size,
                               activation="relu",
                               norm=graph_wrapper.node_feat["norm"],
                               name="layer_%s" % i)
            feature = L.dropout(
                feature,
                self.dropout,
                dropout_implementation='upscale_in_train')
        feature = conv.gcn(graph_wrapper,
                           feature,
                           self.num_class,
                           activation=None,
                           norm=graph_wrapper.node_feat["norm"],
                           name="output")
        return feature
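Each gcn layer above computes one step of symmetrically normalized propagation, H' = act(D^{-1/2} A D^{-1/2} H W); the "norm" node feature prepared in the training script supplies the D^{-1/2} factor. A dense numpy sketch of a single layer (PGL itself uses sparse message passing; the toy matrices here are illustrative):

import numpy as np

A = np.array([[0, 1, 1],
              [1, 0, 0],
              [1, 0, 0]], dtype="float32")    # toy adjacency matrix
H = np.random.rand(3, 4).astype("float32")    # node features
W = np.random.rand(4, 2).astype("float32")    # layer weights

D_inv_sqrt = np.diag(np.power(A.sum(axis=1), -0.5))
H_next = np.maximum(D_inv_sqrt @ A @ D_inv_sqrt @ H @ W, 0.0)  # ReLU activation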
class GAT(object):
    """Implementation of GAT."""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)
        self.num_heads = config.get("num_heads", 8)
        self.hidden_size = config.get("hidden_size", 8)
        self.feat_dropout = config.get("feat_drop", 0.6)
        self.attn_dropout = config.get("attn_drop", 0.6)

    def forward(self, graph_wrapper, feature):
        # Multi-head attention layers with ELU, then a single-head
        # linear attention layer for the class logits.
        for i in range(self.num_layers):
            feature = conv.gat(graph_wrapper,
                               feature,
                               self.hidden_size,
                               activation="elu",
                               name="gat_layer_%s" % i,
                               num_heads=self.num_heads,
                               feat_drop=self.feat_dropout,
                               attn_drop=self.attn_dropout)
        feature = conv.gat(graph_wrapper,
                           feature,
                           self.num_class,
                           num_heads=1,
                           activation=None,
                           feat_drop=self.feat_dropout,
                           attn_drop=self.attn_dropout,
                           name="output")
        return feature
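For reference, the attention computed inside each gat layer follows Velickovic et al. (2018): every edge gets a score from the transformed endpoint features, scores are softmax-normalized over each node's neighborhood, and the heads' outputs are concatenated in hidden layers (and reduced to a single head at the output):

\alpha_{ij} = \operatorname{softmax}_{j \in \mathcal{N}(i)}\!\left(
    \mathrm{LeakyReLU}\!\left(a^{\top}\,[W h_i \,\Vert\, W h_j]\right)\right),
\qquad
h_i' = \sigma\!\Big(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}\, W h_j\Big)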
class APPNP(object):
    """Implementation of APPNP."""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)
        self.hidden_size = config.get("hidden_size", 64)
        self.dropout = config.get("dropout", 0.5)
        self.alpha = config.get("alpha", 0.1)
        self.k_hop = config.get("k_hop", 10)

    def forward(self, graph_wrapper, feature):
        # An MLP produces the per-node predictions; propagation over the
        # graph happens only in the final appnp step.
        for i in range(self.num_layers):
            feature = L.dropout(
                feature,
                self.dropout,
                dropout_implementation='upscale_in_train')
            feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
        feature = L.dropout(
            feature,
            self.dropout,
            dropout_implementation='upscale_in_train')
        feature = L.fc(feature, self.num_class, act=None, name="output")
        feature = conv.appnp(graph_wrapper,
                             feature=feature,
                             norm=graph_wrapper.node_feat["norm"],
                             alpha=self.alpha,
                             k_hop=self.k_hop)
        return feature
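conv.appnp decouples prediction from propagation: the fc stack above produces per-node logits H, and the propagation step then runs a personalized-PageRank power iteration, Z^(0) = H, Z^(t+1) = (1 - alpha) * A_hat * Z^(t) + alpha * H, for k_hop steps. A dense numpy sketch of that propagation (illustrative; PGL's version works on the sparse graph via the precomputed norm):

import numpy as np

def appnp_dense(A, H, alpha=0.1, k_hop=10):
    """Personalized-PageRank propagation on a dense toy adjacency matrix."""
    D_inv_sqrt = np.diag(np.power(A.sum(axis=1), -0.5))
    A_hat = D_inv_sqrt @ A @ D_inv_sqrt   # symmetrically normalized adjacency
    Z = H
    for _ in range(k_hop):
        Z = (1.0 - alpha) * (A_hat @ Z) + alpha * H
    return Z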
# optimization.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization and learning rate scheduling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
def AdamW(loss, learning_rate, train_program, weight_decay):
    """Adam with decoupled weight decay (AdamW)."""
    optimizer = fluid.optimizer.Adam(learning_rate=learning_rate)

    def exclude_from_weight_decay(name):
        """LayerNorm weights and bias parameters are not decayed."""
        if name.find("layer_norm") > -1:
            return True
        bias_suffix = ["_bias", "_b", ".b_0"]
        for suffix in bias_suffix:
            if name.endswith(suffix):
                return True
        return False

    # Snapshot every parameter before the Adam update so that the decay
    # term below uses the pre-update values.
    param_list = dict()
    for param in train_program.global_block().all_parameters():
        param_list[param.name] = param * 1.0
        param_list[param.name].stop_gradient = True

    _, param_grads = optimizer.minimize(loss)

    if weight_decay > 0:
        for param, grad in param_grads:
            if exclude_from_weight_decay(param.name):
                continue
            with param.block.program._optimized_guard(
                    [param, grad]), fluid.framework.name_scope("weight_decay"):
                updated_param = param - param_list[
                    param.name] * weight_decay * learning_rate
                fluid.layers.assign(output=param, input=updated_param)
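The function above implements decoupled weight decay in the style of AdamW (Loshchilov & Hutter): instead of folding an L2 term into the gradient, it snapshots each parameter, lets plain Adam take its step, and then subtracts learning_rate * weight_decay * snapshot. The same update order in a minimal numpy sketch (sgd_step stands in for the Adam step and is not part of the file above):

import numpy as np

def sgd_step(theta, grad, lr):
    """Stand-in for the Adam update; any gradient step works here."""
    return theta - lr * grad

theta = np.array([1.0, -2.0])
grad = np.array([0.1, 0.3])
lr, wd = 0.01, 0.0005

theta_snapshot = theta.copy()       # mirrors param_list[param.name]
theta = sgd_step(theta, grad, lr)   # mirrors optimizer.minimize(loss)
theta -= lr * wd * theta_snapshot   # mirrors the weight_decay block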
# train.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pgl
import model
from pgl import data_loader
from pgl.utils.logger import log
import paddle.fluid as fluid
import numpy as np
import time
import argparse
from build_model import build_model
import yaml
from easydict import EasyDict as edict


def load(name):
    if name == 'cora':
        dataset = data_loader.CoraDataset()
    elif name == "pubmed":
        dataset = data_loader.CitationDataset("pubmed", symmetry_edges=False)
    elif name == "citeseer":
        dataset = data_loader.CitationDataset("citeseer", symmetry_edges=False)
    else:
        raise ValueError(name + " dataset doesn't exist")
    return dataset
def main(args, config):
    dataset = load(args.dataset)

    # Precompute the GCN normalization term: norm[i] = indegree(i)^(-1/2),
    # with isolated nodes left at zero.
    indegree = dataset.graph.indegree()
    norm = np.zeros_like(indegree, dtype="float32")
    norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5)
    dataset.graph.node_feat["norm"] = np.expand_dims(norm, -1)

    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    train_program = fluid.default_main_program()
    startup_program = fluid.default_startup_program()
    with fluid.program_guard(train_program, startup_program):
        with fluid.unique_name.guard():
            gw, loss, acc = build_model(dataset,
                                        config=config,
                                        phase="train",
                                        main_prog=train_program)

    test_program = fluid.Program()
    with fluid.program_guard(test_program, startup_program):
        with fluid.unique_name.guard():
            _gw, v_loss, v_acc = build_model(dataset,
                                             config=config,
                                             phase="test",
                                             main_prog=test_program)

    test_program = test_program.clone(for_test=True)

    exe = fluid.Executor(place)
    exe.run(startup_program)

    train_index = dataset.train_index
    train_label = np.expand_dims(dataset.y[train_index], -1)
    train_index = np.expand_dims(train_index, -1)
    log.info("Number of Train %s" % len(train_index))

    val_index = dataset.val_index
    val_label = np.expand_dims(dataset.y[val_index], -1)
    val_index = np.expand_dims(val_index, -1)

    test_index = dataset.test_index
    test_label = np.expand_dims(dataset.y[test_index], -1)
    test_index = np.expand_dims(test_index, -1)

    dur = []
    cal_val_acc = []
    cal_test_acc = []
    for epoch in range(300):
        # Skip the first few epochs when measuring time per epoch.
        if epoch >= 3:
            t0 = time.time()
        feed_dict = gw.to_feed(dataset.graph)
        feed_dict["node_index"] = np.array(train_index, dtype="int64")
        feed_dict["node_label"] = np.array(train_label, dtype="int64")
        train_loss, train_acc = exe.run(train_program,
                                        feed=feed_dict,
                                        fetch_list=[loss, acc],
                                        return_numpy=True)
        if epoch >= 3:
            time_per_epoch = 1.0 * (time.time() - t0)
            dur.append(time_per_epoch)

        feed_dict = gw.to_feed(dataset.graph)
        feed_dict["node_index"] = np.array(val_index, dtype="int64")
        feed_dict["node_label"] = np.array(val_label, dtype="int64")
        val_loss, val_acc = exe.run(test_program,
                                    feed=feed_dict,
                                    fetch_list=[v_loss, v_acc],
                                    return_numpy=True)
        val_loss = val_loss[0]
        val_acc = val_acc[0]
        cal_val_acc.append(val_acc)

        feed_dict["node_index"] = np.array(test_index, dtype="int64")
        feed_dict["node_label"] = np.array(test_label, dtype="int64")
        test_loss, test_acc = exe.run(test_program,
                                      feed=feed_dict,
                                      fetch_list=[v_loss, v_acc],
                                      return_numpy=True)
        test_loss = test_loss[0]
        test_acc = test_acc[0]
        cal_test_acc.append(test_acc)

        if epoch % 10 == 0:
            log.info("Epoch %d " % epoch +
                     "Train Loss: %f " % train_loss +
                     "Train Acc: %f " % train_acc +
                     "Val Loss: %f " % val_loss +
                     "Val Acc: %f " % val_acc +
                     "Test Loss: %f " % test_loss +
                     "Test Acc: %f " % test_acc)

    # Report the test accuracy at the epoch with the best validation accuracy.
    cal_val_acc = np.array(cal_val_acc)
    log.info("Model: %s Best Test Accuracy: %f" %
             (config.model_name, cal_test_acc[np.argmax(cal_val_acc)]))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Benchmarking Citation Network')
    parser.add_argument(
        "--dataset", type=str, default="cora",
        help="dataset (cora, pubmed, citeseer)")
    parser.add_argument("--use_cuda", action='store_true', help="use_cuda")
    parser.add_argument("--conf", type=str, help="config file for models")
    args = parser.parse_args()
    config = edict(yaml.load(open(args.conf), Loader=yaml.FullLoader))
    log.info(args)
    main(args, config)
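Assuming the script above is saved as train.py and the YAML configs live in a config/ directory (both paths are assumptions; adjust to the actual repo layout), a typical run looks like:

python train.py --dataset cora --conf config/gcn.yaml --use_cuda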