diff --git a/ogb_examples/nodeproppred/unimp/main_arxiv_large.py b/ogb_examples/nodeproppred/unimp/main_arxiv_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..940694ab0d6d498a535c0f66391f4f1f387e5a0a
--- /dev/null
+++ b/ogb_examples/nodeproppred/unimp/main_arxiv_large.py
@@ -0,0 +1,196 @@
+import math
+import torch
+import paddle
+import pgl
+import numpy as np
+import paddle.fluid as F
+import paddle.fluid.layers as L
+from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
+from ogb.nodeproppred import Evaluator
+from utils import to_undirected, add_self_loop, linear_warmup_decay
+from model_large import Arxiv_baseline_model, Arxiv_label_embedding_model
+from optimization import optimization
+import argparse
+from tqdm import tqdm
+evaluator = Evaluator(name='ogbn-arxiv')
+
+
+def get_config():
+    parser = argparse.ArgumentParser()
+
+    ## model_base_arg
+    model_group=parser.add_argument_group('model_base_arg')
+    model_group.add_argument('--num_layers', default=3, type=int)
+    model_group.add_argument('--hidden_size', default=80, type=int)
+    model_group.add_argument('--num_heads', default=5, type=int)
+    model_group.add_argument('--dropout', default=0.3, type=float)
+    model_group.add_argument('--attn_dropout', default=0.1, type=float)
+
+    ## embed_arg
+    embed_group=parser.add_argument_group('embed_arg')
+    embed_group.add_argument('--use_label_e', action='store_true')
+    embed_group.add_argument('--label_rate', default=0.65, type=float)
+
+    ## train_arg
+    train_group=parser.add_argument_group('train_arg')
+    train_group.add_argument('--runs', default=10, type=int)
+    train_group.add_argument('--epochs', default=2000, type=int)
+    train_group.add_argument('--lr', default=0.001, type=float)
+    train_group.add_argument('--place', default=-1, type=int)
+    train_group.add_argument('--log_file', default='result_arxiv.txt', type=str)
+    return parser.parse_args()
+
+
+def optimizer_func(lr=0.01):
+    return F.optimizer.AdamOptimizer(learning_rate=lr, regularization=F.regularizer.L2Decay(
+        regularization_coeff=0.0005))
+
+def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
+    feed_dict=model.gw.to_feed(graph)
+    if parser.use_label_e:
+        feed_dict['label']=y_true
+        feed_dict['label_idx']=split_idx['train']
+        feed_dict['attn_drop']=-1
+
+    avg_cost_np = test_exe.run(
+        program=program,
+        feed=feed_dict,
+        fetch_list=[model.out_feat])
+
+    y_pred=avg_cost_np[0].argmax(axis=-1)
+    y_pred=np.expand_dims(y_pred, 1)
+
+    train_acc = evaluator.eval({
+        'y_true': y_true[split_idx['train']],
+        'y_pred': y_pred[split_idx['train']],
+    })['acc']
+    val_acc = evaluator.eval({
+        'y_true': y_true[split_idx['valid']],
+        'y_pred': y_pred[split_idx['valid']],
+    })['acc']
+    test_acc = evaluator.eval({
+        'y_true': y_true[split_idx['test']],
+        'y_pred': y_pred[split_idx['test']],
+    })['acc']
+
+    return train_acc, val_acc, test_acc
+
+def train_loop(parser, start_program, main_program, test_program,
+               model, graph, label, split_idx, exe, run_id, wf=None):
+
+    exe.run(start_program)
+    max_acc=0
+    max_step=0
+    max_val_acc=0
+    max_cor_acc=0
+    max_cor_step=0
+
+    for epoch_id in tqdm(range(parser.epochs)):
+
+        if parser.use_label_e:
+            feed_dict=model.gw.to_feed(graph)
+            train_idx_temp = split_idx['train']
+            np.random.shuffle(train_idx_temp)
+            label_idx=train_idx_temp[ :int(parser.label_rate*len(train_idx_temp))]
+            unlabel_idx=train_idx_temp[int(parser.label_rate*len(train_idx_temp)): ]
+            feed_dict['label']=label
+            feed_dict['label_idx']= label_idx
+            feed_dict['train_idx']= unlabel_idx
+            feed_dict['attn_drop']=parser.attn_dropout
+        else:
+            feed_dict=model.gw.to_feed(graph)
+            feed_dict['label']=label
+            feed_dict['train_idx']= split_idx['train']
+
+        loss = exe.run(main_program,
+                       feed=feed_dict,
+                       fetch_list=[model.avg_cost])
+        loss = loss[0]
+
+        result = eval_test(parser, test_program, model, exe, graph, label, split_idx)
+        train_acc, valid_acc, test_acc = result
+
+        max_val_acc=max(valid_acc, max_val_acc)
+        if max_val_acc==valid_acc:
+            max_cor_acc=test_acc
+            max_cor_step=epoch_id
+
+
+        if max_acc==result[2]:
+            max_step=epoch_id
+        result_t=(f'Run: {run_id:02d}, '
+                  f'Epoch: {epoch_id:02d}, '
+                  f'Loss: {loss[0]:.4f}, '
+                  f'Train: {100 * train_acc:.2f}%, '
+                  f'Valid: {100 * valid_acc:.2f}%, '
+                  f'Test: {100 * test_acc:.2f}% \n'
+                  f'max_val: {100 * max_val_acc:.2f}%, '
+                  f'max_val_Test: {100 * max_cor_acc:.2f}%, '
+                  f'max_val_step: {max_cor_step}\n'
+                  )
+        if (epoch_id+1)%100==0:
+            print(result_t)
+            wf.write(result_t)
+            wf.write('\n')
+            wf.flush()
+    return max_cor_acc
+
+
+if __name__ == '__main__':
+    parser = get_config()
+    print('===========args==============')
+    print(parser)
+    print('=============================')
+
+    startup_prog = F.default_startup_program()
+    train_prog = F.default_main_program()
+
+
+    place=F.CPUPlace() if parser.place < 0 else F.CUDAPlace(parser.place)
+
+    dataset = PglNodePropPredDataset(name="ogbn-arxiv")
+    split_idx=dataset.get_idx_split()
+
+    graph, label = dataset[0]
+    print(label.shape)
+
+    graph=to_undirected(graph)
+    graph=add_self_loop(graph)
+
+    with F.unique_name.guard():
+        with F.program_guard(train_prog, startup_prog):
+            gw = pgl.graph_wrapper.GraphWrapper(
+                name="arxiv", node_feat=graph.node_feat_info(), place=place)
+
+            if parser.use_label_e:
+                model=Arxiv_label_embedding_model(gw, parser.hidden_size, parser.num_heads,
+                                                  parser.dropout, parser.num_layers)
+            else:
+                model=Arxiv_baseline_model(gw, parser.hidden_size, parser.num_heads,
+                                           parser.dropout, parser.num_layers)
+
+            test_prog=train_prog.clone(for_test=True)
+            model.train_program()
+
+            adam_optimizer = optimizer_func(parser.lr)
+            adam_optimizer = F.optimizer.RecomputeOptimizer(adam_optimizer)
+            adam_optimizer._set_checkpoints(model.checkpoints)  # only these activations are stored; the rest are recomputed in backward to save memory
+            adam_optimizer.minimize(model.avg_cost)
+
+
+    exe = F.Executor(place)
+
+    wf = open(parser.log_file, 'w', encoding='utf-8')
+    total_test_acc=0.0
+    for run_i in range(parser.runs):
+        total_test_acc+=train_loop(parser, startup_prog, train_prog, test_prog, model,
+                                   graph, label, split_idx, exe, run_i, wf)
+    wf.write(f'average: {100 * (total_test_acc/parser.runs):.2f}%')
+    wf.close()
+
+# Ran 10 times
+# Val Accs: [74.64, 74.74, 74.71, 74.83, 74.82, 74.77, 74.75, 74.86, 74.6, 74.76]
+# Test Accs: [73.79, 73.82, 74.0, 73.85, 74.02, 73.67, 73.65, 73.87, 73.66, 73.6]
+# Average val accuracy: 74.74799999999999 ± 0.0775628777186617
+# Average test accuracy: 73.793 ± 0.13957435294494433
+# params: 1162515
\ No newline at end of file
diff --git a/ogb_examples/nodeproppred/unimp/model_large.py b/ogb_examples/nodeproppred/unimp/model_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5e030b51d9203bab476dd601d3a6ded69791c99
--- /dev/null
+++ b/ogb_examples/nodeproppred/unimp/model_large.py
@@ -0,0 +1,147 @@
+'''build label embedding model
+'''
+import math
+import pgl
+import paddle.fluid as F
+import paddle.fluid.layers as L
+from pgl.utils import paddle_helper
+from module.transformer_gat_pgl import transformer_gat_pgl
+from module.model_unimp_large import graph_transformer, linear, attn_appnp
+
+class Arxiv_baseline_model():
+    def __init__(self, gw, hidden_size, num_heads, dropout, num_layers):
+        '''UniMP graph-transformer baseline model for ogbn-arxiv (no label embedding).
+        '''
+        self.gw=gw
+        self.hidden_size=hidden_size
+        self.num_heads= num_heads
+        self.dropout= dropout
+        self.num_layers=num_layers
+        self.out_size=40
+        self.embed_size=128
+        self.checkpoints=[]
+        self.build_model()
+
+    def embed_input(self, feature):
+
+        lay_norm_attr = F.ParamAttr(initializer=F.initializer.ConstantInitializer(value=1))
+        lay_norm_bias = F.ParamAttr(initializer=F.initializer.ConstantInitializer(value=0))
+        feature = L.layer_norm(feature, name='layer_norm_feature_input',
+                               param_attr=lay_norm_attr,
+                               bias_attr=lay_norm_bias)
+
+        return feature
+
+
+    def build_model(self):
+
+        feature_batch = self.embed_input(self.gw.node_feat['feat'])
+        feature_batch = L.dropout(feature_batch, dropout_prob=self.dropout,
+                                  dropout_implementation='upscale_in_train')
+        for i in range(self.num_layers - 1):
+            feature_batch, _, _ = graph_transformer(str(i), self.gw, feature_batch,
+                                                    hidden_size=self.hidden_size,
+                                                    num_heads=self.num_heads,
+                                                    concat=True, skip_feat=True,
+                                                    layer_norm=True, relu=True, gate=True)
+            if self.dropout > 0:
+                feature_batch = L.dropout(feature_batch, dropout_prob=self.dropout,
+                                          dropout_implementation='upscale_in_train')
+            self.checkpoints.append(feature_batch)
+
+        feature_batch, _, _ = graph_transformer(str(self.num_layers - 1), self.gw, feature_batch,
+                                                hidden_size=self.out_size,
+                                                num_heads=self.num_heads,
+                                                concat=False, skip_feat=True,
+                                                layer_norm=False, relu=False, gate=True)
+        self.checkpoints.append(feature_batch)
+        self.out_feat = feature_batch
+
+    def train_program(self,):
+        label = F.data(name="label", shape=[None, 1], dtype="int64")
+        train_idx = F.data(name='train_idx', shape=[None], dtype="int64")
+        prediction = L.gather(self.out_feat, train_idx, overwrite=False)
+        label = L.gather(label, train_idx, overwrite=False)
+        cost = L.softmax_with_cross_entropy(logits=prediction, label=label)
+        avg_cost = L.mean(cost)
+        self.avg_cost = avg_cost
+
+class Arxiv_label_embedding_model():
+    def __init__(self, gw, hidden_size, num_heads, dropout, num_layers):
+        '''UniMP model with masked label embedding for ogbn-arxiv.
+        '''
+        self.gw = gw
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.num_layers = num_layers
+        self.out_size = 40
+        self.embed_size = 128
+        self.checkpoints = []
+        self.build_model()
+
+    def label_embed_input(self, feature):
+        label = F.data(name="label", shape=[None, 1], dtype="int64")
+        label_idx = F.data(name='label_idx', shape=[None], dtype="int64")
+        label = L.reshape(label, shape=[-1])
+        label = L.gather(label, label_idx, overwrite=False)
+
+        lay_norm_attr = F.ParamAttr(initializer=F.initializer.ConstantInitializer(value=1))
+        lay_norm_bias = F.ParamAttr(initializer=F.initializer.ConstantInitializer(value=0))
+        feature = L.layer_norm(feature, name='layer_norm_feature_input1',
+                               param_attr=lay_norm_attr,
+                               bias_attr=lay_norm_bias)
+
+
+        embed_attr = F.ParamAttr(initializer=F.initializer.NormalInitializer(loc=0.0, scale=1.0))
+        embed = F.embedding(input=label, size=(self.out_size, self.embed_size), param_attr=embed_attr)
+        lay_norm_attr = F.ParamAttr(initializer=F.initializer.ConstantInitializer(value=1))
+        lay_norm_bias = F.ParamAttr(initializer=F.initializer.ConstantInitializer(value=0))
+        embed = L.layer_norm(embed, name='layer_norm_feature_input2',
+                             param_attr=lay_norm_attr,
+                             bias_attr=lay_norm_bias)
+        embed = L.relu(embed)
+
+        feature_label = L.gather(feature, label_idx, overwrite=False)
+        feature_label = feature_label + embed
+        feature = L.scatter(feature, label_idx, feature_label, overwrite=True)
+
+        return feature
+
+    def build_model(self):
+        label_feature = self.label_embed_input(self.gw.node_feat['feat'])
+        feature_batch = L.dropout(label_feature, dropout_prob=self.dropout,
+                                  dropout_implementation='upscale_in_train')
+
+        for i in range(self.num_layers - 1):
+            feature_batch, _, cks = graph_transformer(str(i), self.gw, feature_batch,
+                                                      hidden_size=self.hidden_size,
+                                                      num_heads=self.num_heads,
+                                                      attn_drop=True,
+                                                      concat=True, skip_feat=True,
+                                                      layer_norm=True, relu=True, gate=True)
+            if self.dropout > 0:
+                feature_batch = L.dropout(feature_batch, dropout_prob=self.dropout,
+                                          dropout_implementation='upscale_in_train')
+            self.checkpoints = self.checkpoints + cks
+
+        feature_batch, attn, cks = graph_transformer(str(self.num_layers - 1), self.gw, feature_batch,
+                                                     hidden_size=self.out_size,
+                                                     num_heads=self.num_heads+1,
+                                                     concat=False, skip_feat=True,
+                                                     layer_norm=False, relu=False, gate=True)
+        self.checkpoints.append(feature_batch)
+        feature_batch = attn_appnp(self.gw, feature_batch, attn, alpha=0.2, k_hop=10)
+
+        self.checkpoints.append(feature_batch)
+        self.out_feat = feature_batch
+
+    def train_program(self,):
+        label = F.data(name="label", shape=[None, 1], dtype="int64")
+        train_idx = F.data(name='train_idx', shape=[None], dtype="int64")
+        prediction = L.gather(self.out_feat, train_idx, overwrite=False)
+        label = L.gather(label, train_idx, overwrite=False)
+        cost = L.softmax_with_cross_entropy(logits=prediction, label=label)
+        avg_cost = L.mean(cost)
+        self.avg_cost = avg_cost
+
diff --git a/ogb_examples/nodeproppred/unimp/module/model_unimp_large.py b/ogb_examples/nodeproppred/unimp/module/model_unimp_large.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5f338f24b17dc6504bde0e477ae067d0d042805
--- /dev/null
+++ b/ogb_examples/nodeproppred/unimp/module/model_unimp_large.py
@@ -0,0 +1,245 @@
+import pgl
+import paddle.fluid as F
+import paddle.fluid.layers as L
+from pgl.utils import paddle_helper
+from pgl import message_passing
+import math
+
+def graph_transformer(name, gw,
+                      feature,
+                      hidden_size,
+                      num_heads=4,
+                      attn_drop=False,
+                      edge_feature=None,
+                      concat=True,
+                      skip_feat=True,
+                      gate=False,
+                      layer_norm=True,
+                      relu=True,
+                      is_test=False):
+    """Implementation of the graph Transformer layer from UniMP
+
+    This is an implementation of the paper Unified Message Passing Model for Semi-Supervised Classification
+    (https://arxiv.org/abs/2009.03509).
+
+    Args:
+        name: Graph Transformer layer name (used as a prefix for parameter names).
+
+        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
+
+        feature: A tensor with shape (num_nodes, feature_size).
+
+        hidden_size: The hidden size for graph transformer.
+
+        num_heads: The head number in graph transformer.
+
+        attn_drop: Whether to apply attention dropout (the drop threshold is fed through the 'attn_drop' input).
+
+        edge_feature: A tensor with shape (num_edges, feature_size).
+
+        concat: If True, concatenate the multi-head outputs to (num_nodes, hidden_size * num_heads); otherwise average them to (num_nodes, hidden_size).
+
+        skip_feat: Whether to use a skip connection.
+
+        gate: Whether to combine skip_feat and the attention output with a learned gate.
+
+        layer_norm: Whether to apply layer normalization to the output.
+
+        relu: Whether to apply a relu activation to the output.
+
+        is_test: Whether in test phase.
+
+    Return:
+        A tuple (out_feat, attention, checkpoints), where out_feat has shape (num_nodes, hidden_size * num_heads) if concat is True, else (num_nodes, hidden_size).
+    """
+    def send_attention(src_feat, dst_feat, edge_feat):
+        if edge_feat is None or not edge_feat:
+            output = src_feat["k_h"] * dst_feat["q_h"]
+            output = L.reduce_sum(output, -1)
+            output = output / (hidden_size ** 0.5)
+#             alpha = paddle_helper.sequence_softmax(output)
+            return {"alpha": output, "v": src_feat["v_h"]}  # alpha: (num_edges, num_heads), v: (num_edges, num_heads, hidden_size)
+        else:
+            edge_feat = edge_feat["edge"]
+            edge_feat = L.reshape(edge_feat, [-1, num_heads, hidden_size])
+            output = (src_feat["k_h"] + edge_feat) * dst_feat["q_h"]
+            output = L.reduce_sum(output, -1)
+            output = output / (hidden_size ** 0.5)
+#             alpha = paddle_helper.sequence_softmax(output)
+            return {"alpha": output, "v": (src_feat["v_h"] + edge_feat)}  # alpha: (num_edges, num_heads), v: (num_edges, num_heads, hidden_size)
+
+    class Reduce_attention():
+        def __init__(self,):
+            self.alpha = None
+        def __call__(self, msg):
+            alpha = msg["alpha"]  # lod-tensor (batch_size, num_heads)
+            if attn_drop:
+                old_h = alpha
+                dropout = F.data(name='attn_drop', shape=[1], dtype="int64")
+                u = L.uniform_random(shape=L.cast(L.shape(alpha)[:1], 'int64'), min=0., max=1.)
+                keeped = L.cast(u > dropout, dtype="float32")
+                self_attn_mask = L.scale(x=keeped, scale=10000.0, bias=-1.0, bias_after_scale=False)
+                n_head_self_attn_mask = L.stack(x=[self_attn_mask] * num_heads, axis=1)
+                n_head_self_attn_mask.stop_gradient = True
+                alpha = n_head_self_attn_mask + alpha
+                alpha = L.lod_reset(alpha, old_h)
+
+            h = msg["v"]
+            alpha = paddle_helper.sequence_softmax(alpha)
+
+            self.alpha = alpha
+            old_h = h
+            h = h * alpha
+            h = L.lod_reset(h, old_h)
+            h = L.sequence_pool(h, "sum")
+
+            if concat:
+                h = L.reshape(h, [-1, num_heads * hidden_size])
+            else:
+                h = L.reduce_mean(h, dim=1)
+            return h
+    reduce_attention = Reduce_attention()
+
+    q = linear(feature, hidden_size * num_heads, name=name + '_q_weight', init_type='gcn')
+    k = linear(feature, hidden_size * num_heads, name=name + '_k_weight', init_type='gcn')
+    v = linear(feature, hidden_size * num_heads, name=name + '_v_weight', init_type='gcn')
+
+
+    reshape_q = L.reshape(q, [-1, num_heads, hidden_size])
+    reshape_k = L.reshape(k, [-1, num_heads, hidden_size])
+    reshape_v = L.reshape(v, [-1, num_heads, hidden_size])
+
+    msg = gw.send(
+        send_attention,
+        nfeat_list=[("q_h", reshape_q), ("k_h", reshape_k),
+                    ("v_h", reshape_v)],
+        efeat_list=edge_feature)
+    out_feat = gw.recv(msg, reduce_attention)
+    checkpoints=[out_feat]
+
+    if skip_feat:
+        if concat:
+
+            out_feat, cks = appnp(gw, out_feat, k_hop=1)
+#             out_feat, cks = appnp(gw, out_feat, k_hop=3)
+            checkpoints.append(out_feat)
+
+# The UniMP-xxlarge will come soon.
+#             out_feat, cks = appnp(gw, out_feat, k_hop=6)
+#             out_feat, cks = appnp(gw, out_feat, k_hop=9)
+#             checkpoints = checkpoints + cks
+
+
+            skip_feature = linear(feature, hidden_size * num_heads, name=name + '_skip_weight', init_type='lin')
+        else:
+
+            skip_feature = linear(feature, hidden_size, name=name + '_skip_weight', init_type='lin')
+
+        if gate:
+            temp_output = L.concat([skip_feature, out_feat, out_feat - skip_feature], axis=-1)
+            gate_f = L.sigmoid(linear(temp_output, 1, name=name + '_gate_weight', init_type='lin'))
+            out_feat = skip_feature * gate_f + out_feat * (1 - gate_f)
+        else:
+            out_feat = skip_feature + out_feat
+
+    if layer_norm:
+        lay_norm_attr = F.ParamAttr(initializer=F.initializer.ConstantInitializer(value=1))
+        lay_norm_bias = F.ParamAttr(initializer=F.initializer.ConstantInitializer(value=0))
+        out_feat = L.layer_norm(out_feat, name=name + '_layer_norm',
+                                param_attr=lay_norm_attr,
+                                bias_attr=lay_norm_bias)
+    if relu:
+        out_feat = L.relu(out_feat)
+
+    return out_feat, reduce_attention.alpha, checkpoints
+
+
+def appnp(gw, feature, alpha=0.2, k_hop=10):
+    """Implementation of APPNP from "Predict then Propagate: Graph Neural Networks
+    meet Personalized PageRank" (ICLR 2019).
+    Args:
+        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
+        feature: A tensor with shape (num_nodes, feature_size).
+        alpha: Teleport (restart) probability of the residual connection to the input feature.
+        k_hop: Number of propagation steps.
+    Return:
+        A tuple (feature, checkpoints), where feature has shape (num_nodes, feature_size).
+    """
+
+    def send_src_copy(src_feat, dst_feat, edge_feat):
+        feature = src_feat["h"]
+        return feature
+
+    def get_norm(indegree):
+        float_degree = L.cast(indegree, dtype="float32")
+        float_degree = L.clamp(float_degree, min=1.0)
+        norm = L.pow(float_degree, factor=-0.5)
+        return norm
+
+    cks = []
+    h0 = feature
+    ngw = gw
+    norm = get_norm(ngw.indegree())
+
+    for i in range(k_hop):
+
+        feature = feature * norm
+        msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
+        feature = gw.recv(msg, "sum")
+        feature = feature * norm
+        feature = feature * (1 - alpha) + h0 * alpha
+
+        if (i+1) % 3 == 0:
+            cks.append(feature)
+    return feature, cks
+
+def attn_appnp(gw, feature, attn, alpha=0.2, k_hop=10):
+    """Attention-based APPNP that reuses the last layer's attention as the transition matrix, making the model output deeper.
+    Args:
+        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
+        attn: Attention weights from the last graph_transformer layer, used as the transition matrix for APPNP.
+        feature: A tensor with shape (num_nodes, feature_size).
+        alpha: Teleport (restart) probability of the residual connection to the input feature.
+        k_hop: Number of propagation steps.
+    Return:
+        A tensor with shape (num_nodes, feature_size).
+    """
+    def send_src_copy(src_feat, dst_feat, edge_feat):
+        feature = src_feat["h"]
+        return feature
+
+    h0 = feature
+    attn = L.reduce_mean(attn, 1)
+    for i in range(k_hop):
+        msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
+        msg = msg * attn
+        feature = gw.recv(msg, "sum")
+        feature = feature * (1 - alpha) + h0 * alpha
+    return feature
+
+def linear(input, hidden_size, name, with_bias=True, init_type='gcn'):
+    """fluid.layers.fc with selectable init_type: 'gcn' uses Xavier init, anything else uses a Kaiming-uniform scheme (a=sqrt(5)) for the weight and uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) for the bias.
+    """
+
+    if init_type == 'gcn':
+        fc_w_attr = F.ParamAttr(initializer=F.initializer.XavierInitializer())
+        fc_bias_attr = F.ParamAttr(initializer=F.initializer.ConstantInitializer(0.0))
+    else:
+        fan_in = input.shape[-1]
+        bias_bound = 1.0 / math.sqrt(fan_in)
+        fc_bias_attr = F.ParamAttr(initializer=F.initializer.UniformInitializer(low=-bias_bound, high=bias_bound))
+
+        negative_slope = math.sqrt(5)
+        gain = math.sqrt(2.0 / (1 + negative_slope ** 2))
+        std = gain / math.sqrt(fan_in)
+        weight_bound = math.sqrt(3.0) * std
+        fc_w_attr = F.ParamAttr(initializer=F.initializer.UniformInitializer(low=-weight_bound, high=weight_bound))
+
+    if not with_bias:
+        fc_bias_attr = False
+
+    output = L.fc(input,
+                  hidden_size,
+                  param_attr=fc_w_attr,
+                  name=name,
+                  bias_attr=fc_bias_attr)
+    return output
\ No newline at end of file
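
The masked-label trick driven by --label_rate in train_loop can be summarized outside Paddle. The following standalone numpy sketch is illustrative only (split_label_and_train_idx is a hypothetical helper, not part of the patch): each epoch shuffles split_idx['train'], exposes the ground-truth labels of the first label_rate fraction to the model through 'label_idx', and computes the cross-entropy loss only on the remaining 'train_idx' nodes.

import numpy as np

def split_label_and_train_idx(train_idx, label_rate=0.65, rng=None):
    # Shuffle the training nodes, expose the labels of the first `label_rate`
    # fraction as model input ('label_idx'), keep the rest for the loss ('train_idx').
    rng = rng if rng is not None else np.random.default_rng(0)
    idx = np.array(train_idx, copy=True)
    rng.shuffle(idx)
    cut = int(label_rate * len(idx))
    return idx[:cut], idx[cut:]

label_idx, unlabel_idx = split_label_and_train_idx(np.arange(10))
print(label_idx, unlabel_idx)

At evaluation time, eval_test instead feeds all of split_idx['train'] as label_idx, so the model sees every training label when predicting validation and test nodes.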
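The propagation used by appnp() and attn_appnp() is the APPNP update h <- (1 - alpha) * A_hat @ h + alpha * h0, repeated k_hop times. A dense numpy sketch of the same rule is given below for illustration only (appnp_dense is a hypothetical name; the patch itself propagates with gw.send/gw.recv on the PGL graph wrapper, and attn_appnp swaps the normalized adjacency A_hat for the averaged attention of the last layer).

import numpy as np

def appnp_dense(adj, feature, alpha=0.2, k_hop=10):
    # Symmetric normalization D^-1/2 A D^-1/2, with degrees clamped at 1
    # (self-loops are added by add_self_loop in main_arxiv_large.py).
    deg = np.clip(adj.sum(axis=1), 1.0, None)
    d_inv_sqrt = deg ** -0.5
    a_hat = adj * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    h0 = feature
    h = feature
    for _ in range(k_hop):
        # Personalized-PageRank style update: propagate, then mix back the input.
        h = (1.0 - alpha) * (a_hat @ h) + alpha * h0
    return h

adj = np.array([[1., 1., 1.],
                [1., 1., 0.],
                [1., 0., 1.]], dtype=np.float32)
print(appnp_dense(adj, np.eye(3, dtype=np.float32)).round(3))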