diff --git a/examples/citation_benchmark/build_model.py b/examples/citation_benchmark/build_model.py
index 9813f54a82af5f62287dd6f79388ff88a56a2957..3b92abc9b5bf5573544d388fcad78b6608fde99c 100644
--- a/examples/citation_benchmark/build_model.py
+++ b/examples/citation_benchmark/build_model.py
@@ -13,7 +13,7 @@ def build_model(dataset, config, phase, main_prog):
     GraphModel = getattr(model, config.model_name)
     m = GraphModel(config=config, num_class=dataset.num_classes)
-    logits = m.forward(gw, gw.node_feat["words"])
+    logits = m.forward(gw, gw.node_feat["words"], phase)
 
     node_index = fluid.layers.data(
         "node_index",
@@ -33,11 +33,6 @@ def build_model(dataset, config, phase, main_prog):
     loss = fluid.layers.mean(loss)
 
     if phase == "train":
-        #adam = fluid.optimizer.Adam(
-        #    learning_rate=config.learning_rate,
-        #    regularization=fluid.regularizer.L2DecayRegularizer(
-        #        regularization_coeff=config.weight_decay))
-        #adam.minimize(loss)
         AdamW(loss=loss,
               learning_rate=config.learning_rate,
               weight_decay=config.weight_decay,
diff --git a/examples/citation_benchmark/config/appnp.yaml b/examples/citation_benchmark/config/appnp.yaml
index a9fc393b5c34312d20a285ca5b5decbaebfdaf17..c4637be717cf78d5428a3b6c246b9d8e16b5a96a 100644
--- a/examples/citation_benchmark/config/appnp.yaml
+++ b/examples/citation_benchmark/config/appnp.yaml
@@ -6,3 +6,4 @@ learning_rate: 0.01
 dropout: 0.5
 hidden_size: 64
 weight_decay: 0.0005
+edge_dropout: 0.00
diff --git a/examples/citation_benchmark/config/gat.yaml b/examples/citation_benchmark/config/gat.yaml
index 4be1d7ccf8dfc1b456d8392a9e12b30ab0ea9616..4c08e7fb1c1bf99c2e13140eecf85957935c9ad8 100644
--- a/examples/citation_benchmark/config/gat.yaml
+++ b/examples/citation_benchmark/config/gat.yaml
@@ -6,3 +6,4 @@ feat_drop: 0.6
 attn_drop: 0.6
 num_heads: 8
 hidden_size: 8
+edge_dropout: 0.1
diff --git a/examples/citation_benchmark/config/gcn.yaml b/examples/citation_benchmark/config/gcn.yaml
index 533c33c0d4a80d076660eeecb7969b3a133ed147..beb9129068e12ad077156d4f030f1c06f1cdbb01 100644
--- a/examples/citation_benchmark/config/gcn.yaml
+++ b/examples/citation_benchmark/config/gcn.yaml
@@ -1,6 +1,7 @@
 model_name: GCN
 num_layers: 1
 dropout: 0.5
-hidden_size: 64
+hidden_size: 16
 learning_rate: 0.01
 weight_decay: 0.0005
+edge_dropout: 0.0
diff --git a/examples/citation_benchmark/config/sgc.yaml b/examples/citation_benchmark/config/sgc.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5db1a64238145f5e882a0f889a1a84f4317bdd74
--- /dev/null
+++ b/examples/citation_benchmark/config/sgc.yaml
@@ -0,0 +1,4 @@
+model_name: SGC
+num_layers: 2
+learning_rate: 0.2
+weight_decay: 0.000005
diff --git a/examples/citation_benchmark/model.py b/examples/citation_benchmark/model.py
index 9fa6adcf97d93b8d9c59de1e7a0cde89c9d67168..e325f375da5005cf8d0fdf26226f7463b720cc7c 100644
--- a/examples/citation_benchmark/model.py
+++ b/examples/citation_benchmark/model.py
@@ -2,6 +2,12 @@
 import pgl
 import paddle.fluid.layers as L
 import pgl.layers.conv as conv
 
+def get_norm(indegree):
+    norm = L.pow(L.cast(indegree, dtype="float32") + 1e-6, factor=-0.5)
+    norm = norm * L.cast(indegree > 0, dtype="float32")
+    return norm
+
+
 class GCN(object):
     """Implement of GCN
     """
@@ -10,26 +16,48 @@ class GCN(object):
         self.num_layers = config.get("num_layers", 1)
         self.hidden_size = config.get("hidden_size", 64)
         self.dropout = config.get("dropout", 0.5)
+        self.edge_dropout = config.get("edge_dropout", 0.0)
 
-    def forward(self, graph_wrapper, feature):
+    def forward(self, graph_wrapper, feature, phase):
+
         for i in range(self.num_layers):
-            feature = pgl.layers.gcn(graph_wrapper,
+
+            if phase == "train":
+                ngw = pgl.sample.edge_drop(graph_wrapper, self.edge_dropout)
+                norm = get_norm(ngw.indegree())
+            else:
+                ngw = graph_wrapper
+                norm = graph_wrapper.node_feat["norm"]
+
+            feature = L.dropout(
+                feature,
+                self.dropout,
+                dropout_implementation='upscale_in_train')
+
+            feature = pgl.layers.gcn(ngw,
                 feature,
                 self.hidden_size,
                 activation="relu",
-                norm=graph_wrapper.node_feat["norm"],
+                norm=norm,
                 name="layer_%s" % i)
-            feature = L.dropout(
+        feature = L.dropout(
                 feature,
                 self.dropout,
                 dropout_implementation='upscale_in_train')
 
-        feature = conv.gcn(graph_wrapper,
+        if phase == "train":
+            ngw = pgl.sample.edge_drop(graph_wrapper, self.edge_dropout)
+            norm = get_norm(ngw.indegree())
+        else:
+            ngw = graph_wrapper
+            norm = graph_wrapper.node_feat["norm"]
+
+        feature = conv.gcn(ngw,
                      feature,
                      self.num_class,
                      activation=None,
-                     norm=graph_wrapper.node_feat["norm"],
+                     norm=norm,
                      name="output")
 
         return feature
@@ -43,10 +71,18 @@ class GAT(object):
         self.hidden_size = config.get("hidden_size", 8)
         self.feat_dropout = config.get("feat_drop", 0.6)
         self.attn_dropout = config.get("attn_drop", 0.6)
+        self.edge_dropout = config.get("edge_dropout", 0.0)
+
+    def forward(self, graph_wrapper, feature, phase):
+        if phase == "train":
+            edge_dropout = self.edge_dropout
+        else:
+            edge_dropout = 0
 
-    def forward(self, graph_wrapper, feature):
         for i in range(self.num_layers):
-            feature = conv.gat(graph_wrapper,
+            ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
+
+            feature = conv.gat(ngw,
                 feature,
                 self.hidden_size,
                 activation="elu",
@@ -55,7 +91,8 @@ class GAT(object):
                 feat_drop=self.feat_dropout,
                 attn_drop=self.attn_dropout)
 
-        feature = conv.gat(graph_wrapper,
+        ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
+        feature = conv.gat(ngw,
             feature,
             self.num_class,
             num_heads=1,
@@ -75,8 +112,14 @@ class APPNP(object):
         self.dropout = config.get("dropout", 0.5)
         self.alpha = config.get("alpha", 0.1)
         self.k_hop = config.get("k_hop", 10)
+        self.edge_dropout = config.get("edge_dropout", 0.0)
+
+    def forward(self, graph_wrapper, feature, phase):
+        if phase == "train":
+            edge_dropout = self.edge_dropout
+        else:
+            edge_dropout = 0
 
-    def forward(self, graph_wrapper, feature):
         for i in range(self.num_layers):
             feature = L.dropout(
                 feature,
@@ -93,8 +136,24 @@ class APPNP(object):
 
         feature = conv.appnp(graph_wrapper,
             feature=feature,
-            norm=graph_wrapper.node_feat["norm"],
+            edge_dropout=edge_dropout,
             alpha=self.alpha,
             k_hop=self.k_hop)
         return feature
+
+
+class SGC(object):
+    """Implement of SGC"""
+    def __init__(self, config, num_class):
+        self.num_class = num_class
+        self.num_layers = config.get("num_layers", 1)
+
+    def forward(self, graph_wrapper, feature, phase):
+        feature = conv.appnp(graph_wrapper,
+            feature=feature,
+            edge_dropout=0,
+            alpha=0,
+            k_hop=self.num_layers)
+        feature.stop_gradient=True
+        feature = L.fc(feature, self.num_class, act=None, name="output")
+        return feature
+
diff --git a/examples/citation_benchmark/train.py b/examples/citation_benchmark/train.py
index 32e78475754f238c5876e64267b4fe47aac2f0c1..5c4e0a7e4cf7d6ad58511eb22a65c2eae5292d24 100644
--- a/examples/citation_benchmark/train.py
+++ b/examples/citation_benchmark/train.py
@@ -63,6 +63,7 @@ def main(args, config):
             config=config,
             phase="test",
             main_prog=test_program)
+    test_program = test_program.clone(for_test=True)
 
     exe = fluid.Executor(place)
 
@@ -86,7 +87,7 @@ def main(args, config):
     cal_val_acc = []
     cal_test_acc = []
 
-    for epoch in range(300):
+    for epoch in range(args.epoch):
         if epoch >= 3:
             t0 = time.time()
         feed_dict = gw.to_feed(dataset.graph)
@@ -123,11 +124,10 @@ def main(args, config):
             test_loss = test_loss[0]
             test_acc = test_acc[0]
             cal_test_acc.append(test_acc)
-        if epoch % 10 == 0:
-            log.info("Epoch %d " % epoch +
+
+        log.info("Epoch %d " % epoch +
                  "Train Loss: %f " % train_loss + "Train Acc: %f " % train_acc
-                 + "Val Loss: %f " % val_loss + "Val Acc: %f " % val_acc
-                 +" Test Loss: %f " % test_loss + " Test Acc: %f " % test_acc)
+                 + "Val Loss: %f " % val_loss + "Val Acc: %f " % val_acc)
 
     cal_val_acc = np.array(cal_val_acc)
     log.info("Model: %s Best Test Accuracy: %f" % (config.model_name,
@@ -140,6 +140,7 @@ if __name__ == '__main__':
         "--dataset", type=str, default="cora", help="dataset (cora, pubmed)")
     parser.add_argument("--use_cuda", action='store_true', help="use_cuda")
    parser.add_argument("--conf", type=str, help="config file for models")
+    parser.add_argument("--epoch", type=int, default=200, help="Epoch")
     args = parser.parse_args()
     config = edict(yaml.load(open(args.conf), Loader=yaml.FullLoader))
     log.info(args)
diff --git a/pgl/__init__.py b/pgl/__init__.py
index 7543265d30492f8d1fe7a898f948166ae89001ea..e364540e8192d6dbebab72e1af552f00a4919c72 100644
--- a/pgl/__init__.py
+++ b/pgl/__init__.py
@@ -22,3 +22,4 @@ from pgl import heter_graph
 from pgl import heter_graph_wrapper
 from pgl import contrib
 from pgl import message_passing
+from pgl import sample
diff --git a/pgl/graph_wrapper.py b/pgl/graph_wrapper.py
index 91dda8f78796aedd493b37e85a92ad9ecb1c6664..12bc05186e82c00b2fbecd8d0f9dfd1b5a9a0a68 100644
--- a/pgl/graph_wrapper.py
+++ b/pgl/graph_wrapper.py
@@ -19,6 +19,7 @@ for PaddlePaddle.
 import warnings
 import numpy as np
 import paddle.fluid as fluid
+import paddle.fluid.layers as L
 
 from pgl.utils import op
 from pgl.utils import paddle_helper
@@ -47,10 +48,10 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
     try:
         out_dim = msg.shape[-1]
-        init_output = fluid.layers.fill_constant(
+        init_output = L.fill_constant(
             shape=[num_nodes, out_dim], value=0, dtype=msg.dtype)
         init_output.stop_gradient = False
-        empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=msg.dtype)
+        empty_msg_flag = L.cast(num_edges > 0, dtype=msg.dtype)
         msg = msg * empty_msg_flag
         output = paddle_helper.scatter_add(init_output, dst, msg)
         return output
@@ -59,7 +60,7 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
             "scatter_add is not supported with paddle version <= 1.5")
 
         def sum_func(message):
-            return fluid.layers.sequence_pool(message, "sum")
+            return L.sequence_pool(message, "sum")
         reduce_function = sum_func
 
@@ -67,13 +68,13 @@
     output = reduce_function(bucketed_msg)
     output_dim = output.shape[-1]
-    empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=output.dtype)
+    empty_msg_flag = L.cast(num_edges > 0, dtype=output.dtype)
     output = output * empty_msg_flag
-    init_output = fluid.layers.fill_constant(
+    init_output = L.fill_constant(
         shape=[num_nodes, output_dim], value=0, dtype=output.dtype)
     init_output.stop_gradient = True
-    final_output = fluid.layers.scatter(init_output, uniq_dst, output)
+    final_output = L.scatter(init_output, uniq_dst, output)
     return final_output
@@ -104,6 +105,7 @@
         self._node_ids = None
         self._graph_lod = None
         self._num_graph = None
+        self._num_edges = None
         self._data_name_prefix = ""
 
     def __repr__(self):
@@ -470,7 +472,7 @@ class StaticGraphWrapper(BaseGraphWrapper):
 
 class GraphWrapper(BaseGraphWrapper):
     """Implement a graph wrapper that creates a graph data holders
-    that attributes and features in the graph are :code:`fluid.layers.data`.
+    that attributes and features in the graph are :code:`L.data`.
     And we provide interface :code:`to_feed` to help converting :code:`Graph`
     data into :code:`feed_dict`.
 
@@ -546,65 +548,65 @@ class GraphWrapper(BaseGraphWrapper):
     def __create_graph_attr_holders(self):
         """Create data holders for graph attributes.
         """
-        self._num_edges = fluid.layers.data(
+        self._num_edges = L.data(
             self._data_name_prefix + '/num_edges',
             shape=[1],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
-        self._num_graph = fluid.layers.data(
+        self._num_graph = L.data(
             self._data_name_prefix + '/num_graph',
             shape=[1],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
-        self._edges_src = fluid.layers.data(
+        self._edges_src = L.data(
             self._data_name_prefix + '/edges_src',
             shape=[None],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
-        self._edges_dst = fluid.layers.data(
+        self._edges_dst = L.data(
             self._data_name_prefix + '/edges_dst',
             shape=[None],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
-        self._num_nodes = fluid.layers.data(
+        self._num_nodes = L.data(
             self._data_name_prefix + '/num_nodes',
             shape=[1],
             append_batch_size=False,
             dtype='int64',
             stop_gradient=True)
-        self._edge_uniq_dst = fluid.layers.data(
+        self._edge_uniq_dst = L.data(
             self._data_name_prefix + "/uniq_dst",
             shape=[None],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
-        self._graph_lod = fluid.layers.data(
+        self._graph_lod = L.data(
             self._data_name_prefix + "/graph_lod",
             shape=[None],
             append_batch_size=False,
             dtype="int32",
             stop_gradient=True)
-        self._edge_uniq_dst_count = fluid.layers.data(
+        self._edge_uniq_dst_count = L.data(
             self._data_name_prefix + "/uniq_dst_count",
             shape=[None],
             append_batch_size=False,
             dtype="int32",
             stop_gradient=True)
-        self._node_ids = fluid.layers.data(
+        self._node_ids = L.data(
             self._data_name_prefix + "/node_ids",
             shape=[None],
             append_batch_size=False,
             dtype="int64",
             stop_gradient=True)
-        self._indegree = fluid.layers.data(
+        self._indegree = L.data(
             self._data_name_prefix + "/indegree",
             shape=[None],
             append_batch_size=False,
@@ -627,7 +629,7 @@ class GraphWrapper(BaseGraphWrapper):
                                        node_feat_dtype):
         """Create data holders for node features.
         """
-        feat_holder = fluid.layers.data(
+        feat_holder = L.data(
             self._data_name_prefix + '/node_feat/' + node_feat_name,
             shape=node_feat_shape,
             append_batch_size=False,
@@ -640,7 +642,7 @@ class GraphWrapper(BaseGraphWrapper):
                                        edge_feat_dtype):
         """Create edge holders for edge features.
         """
-        feat_holder = fluid.layers.data(
+        feat_holder = L.data(
             self._data_name_prefix + '/edge_feat/' + edge_feat_name,
             shape=edge_feat_shape,
             append_batch_size=False,
@@ -719,3 +721,56 @@ class GraphWrapper(BaseGraphWrapper):
         """Return the holder list.
""" return self._holder_list + + +def get_degree(edge, num_nodes): + init_output = L.fill_constant( + shape=[num_nodes], value=0, dtype="float32") + init_output.stop_gradient = True + final_output = L.scatter(init_output, + edge, + L.full_like(edge, 1, dtype="float32"), + overwrite=False) + return final_output + +class DropEdgeWrapper(BaseGraphWrapper): + """Implement of Edge Drop """ + def __init__(self, graph_wrapper, dropout): + super(DropEdgeWrapper, self).__init__() + + # Copy Node's information + for key, value in graph_wrapper.node_feat.items(): + self.node_feat_tensor_dict[key] = value + + self._num_nodes = graph_wrapper.num_nodes + self._graph_lod = graph_wrapper.graph_lod + self._num_graph = graph_wrapper.num_graph + self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32") + + # Dropout Edges + src, dst = graph_wrapper.edges + u = L.uniform_random(shape=L.cast(L.shape(src), 'int64'), min=0., max=1.) + + # Avoid Empty Edges + keeped = L.cast(u > dropout, dtype="float32") + self._num_edges = L.reduce_sum(L.cast(keeped, "int32")) + keeped = keeped + L.cast(self._num_edges == 0, dtype="float32") + + keeped = (keeped > 0.5) + src = paddle_helper.masked_select(src, keeped) + dst = paddle_helper.masked_select(dst, keeped) + src.stop_gradient=True + dst.stop_gradient=True + self._edges_src = src + self._edges_dst = dst + + for key, value in graph_wrapper.edge_feat.items(): + self.edge_feat_tensor_dict[key] = paddle_helper.masked_select(value, keeped) + + self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(dst, dtype="int32") + self._edge_uniq_dst.stop_gradient=True + last = L.reduce_sum(uniq_count, keep_dim=True) + uniq_count = L.cumsum(uniq_count, exclusive=True) + self._edge_uniq_dst_count = L.concat([uniq_count, last]) + self._edge_uniq_dst_count.stop_gradient=True + self._indegree = get_degree(self._edges_dst, self._num_nodes) diff --git a/pgl/layers/conv.py b/pgl/layers/conv.py index 9cf5608b3a477509cc6f387741a5e6f6b41a3377..ff8ebca42f5bbd0c746e0c68cd4de3a926adc38f 100644 --- a/pgl/layers/conv.py +++ b/pgl/layers/conv.py @@ -14,6 +14,7 @@ """This package implements common layers to help building graph neural networks. """ +import pgl import paddle.fluid as fluid from pgl.utils import paddle_helper from pgl import message_passing @@ -404,7 +405,14 @@ def gen_conv(gw, return output -def appnp(gw, feature, norm=None, alpha=0.2, k_hop=10): +def get_norm(indegree): + """Get Laplacian Normalization""" + norm = fluid.layers.pow(fluid.layers.cast(indegree, dtype="float32") + 1e-6, + factor=-0.5) + norm = norm * fluid.layers.cast(indegree > 0, dtype="float32") + return norm + +def appnp(gw, feature, edge_dropout=0, alpha=0.2, k_hop=10): """Implementation of APPNP of "Predict then Propagate: Graph Neural Networks meet Personalized PageRank" (ICLR 2019). @@ -413,8 +421,7 @@ def appnp(gw, feature, norm=None, alpha=0.2, k_hop=10): feature: A tensor with shape (num_nodes, feature_size). - norm: If :code:`norm` is not None, then the feature will be normalized. Norm must - be tensor with shape (num_nodes,) and dtype float32. + edge_dropout: Edge dropout rate. 
 
         k_hop: K Steps for Propagation
 
@@ -427,17 +434,21 @@ def appnp(gw, feature, norm=None, alpha=0.2, k_hop=10):
         return feature
 
     h0 = feature
+    ngw = gw
+    norm = get_norm(ngw.indegree())
+
     for i in range(k_hop):
-        if norm is not None:
-            feature = feature * norm
+        if edge_dropout > 1e-5:
+            ngw = pgl.sample.edge_drop(gw, edge_dropout)
+            norm = get_norm(ngw.indegree())
+
+        feature = feature * norm
 
         msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
         feature = gw.recv(msg, "sum")
 
-        if norm is not None:
-            feature = feature * norm
+        feature = feature * norm
 
         feature = feature * (1 - alpha) + h0 * alpha
     return feature
diff --git a/pgl/sample.py b/pgl/sample.py
index 81241d5dc6f8224283abebeaa35da69644e9d1a1..e1c19cb987d04f66956b874212c03d6f523d9d03 100644
--- a/pgl/sample.py
+++ b/pgl/sample.py
@@ -516,3 +516,10 @@ def graph_saint_random_walk_sample(graph,
         nodes=sample_nodes, eid=eids, with_node_feat=True, with_edge_feat=True)
     subgraph.node_feat["index"] = np.array(sample_nodes, dtype="int64")
     return subgraph
+
+
+def edge_drop(graph_wrapper, dropout_rate):
+    if dropout_rate < 1e-5:
+        return graph_wrapper
+    else:
+        return pgl.graph_wrapper.DropEdgeWrapper(graph_wrapper, dropout_rate)
diff --git a/pgl/utils/paddle_helper.py b/pgl/utils/paddle_helper.py
index 3570fac2c9da6b668108d4216cac9d415ce68dcd..2dd6ea248966a734ed2cbfefd342cd90655407f4 100644
--- a/pgl/utils/paddle_helper.py
+++ b/pgl/utils/paddle_helper.py
@@ -250,3 +250,20 @@ def scatter_max(input, index, updates):
     output = fluid.layers.scatter(input, index, updates, mode='max')
     return output
+
+
+def masked_select(input, mask):
+    """masked_select
+
+    Select values from the input tensor according to a boolean mask.
+
+    Args:
+        input: Input tensor to be selected.
+
+        mask: A boolean tensor used for selection.
+
+    Return:
+        The part of the input where the mask is True.
+    """
+    index = fluid.layers.where(mask)
+    return fluid.layers.gather(input, index)
+
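
Reviewer note (not part of the patch): the edge-drop rule that DropEdgeWrapper builds in the fluid graph program is easier to read outside of fluid. Below is a minimal NumPy sketch of the same keep-mask logic for illustration only; the function name and variables here are made up and are not part of the PGL API. Each edge survives with probability 1 - dropout, and if every edge happens to be dropped the whole edge set is kept so the graph never becomes empty.

```python
import numpy as np

def edge_drop_mask(num_edges, dropout, rng=np.random):
    """Sketch of DropEdgeWrapper's keep rule: keep edge i when u_i > dropout,
    falling back to keeping all edges if the sampled mask would be empty."""
    u = rng.uniform(0.0, 1.0, size=num_edges)
    keep = (u > dropout).astype("float32")
    if keep.sum() == 0:            # mirrors `keeped + cast(num_edges == 0)`
        keep = np.ones_like(keep)
    return keep > 0.5              # boolean mask used to select src/dst/edge feats

# Example: drop roughly 10% of 8 edges
src = np.array([0, 0, 1, 2, 2, 3, 3, 4])
dst = np.array([1, 2, 2, 3, 4, 4, 0, 0])
mask = edge_drop_mask(len(src), 0.1)
src_kept, dst_kept = src[mask], dst[mask]
```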
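Reviewer note (not part of the patch): both model.py and pgl/layers/conv.py add a get_norm helper because the precomputed node_feat["norm"] no longer matches a graph whose edges have been dropped; indegrees change, so the normalization norm_i = (indegree_i + 1e-6)^(-1/2), with isolated nodes mapped to 0, is recomputed from the dropped graph's indegree. A NumPy sketch of the same formula, for illustration only:

```python
import numpy as np

def get_norm_np(indegree):
    """NumPy mirror of get_norm: D^{-1/2}, with zero-indegree nodes set to 0."""
    indegree = indegree.astype("float32")
    norm = np.power(indegree + 1e-6, -0.5)
    return norm * (indegree > 0)

# indegree computed from the kept edges' destination nodes
indegree = np.array([2, 1, 0, 3])
print(get_norm_np(indegree))   # approximately [0.7071, 1.0, 0.0, 0.5774]
```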