diff --git a/pgl/layers/graph_pool.py b/pgl/layers/graph_pool.py
index a88468f7b6ef12131c4554db9339c35472a66f11..a1b6249c21d9022a30e8f51f2a6950a187ee3b14 100644
--- a/pgl/layers/graph_pool.py
+++ b/pgl/layers/graph_pool.py
@@ -19,7 +19,7 @@ from pgl import graph_wrapper
 from pgl.utils import paddle_helper
 from pgl.utils import op
 
-__all__ = ['graph_pooling']
+__all__ = ['graph_pooling', 'graph_norm']
 
 
 def graph_pooling(gw, node_feat, pool_type):
@@ -40,3 +40,29 @@ def graph_pooling(gw, node_feat, pool_type):
     graph_feat = op.nested_lod_reset(node_feat, gw.graph_lod)
     graph_feat = fluid.layers.sequence_pool(graph_feat, pool_type)
     return graph_feat
+
+
+def graph_norm(gw, feature):
+    """Implementation of graph normalization.
+
+    Reference Paper: "Benchmarking Graph Neural Networks" (Dwivedi et al., 2020)
+
+    Each node's feature is divided by sqrt(num_nodes), where num_nodes is
+    the number of nodes in the graph that the node belongs to.
+
+    Args:
+        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
+
+        feature: A tensor with shape (num_nodes, feature_size).
+
+    Return:
+        A tensor with shape (num_nodes, feature_size), the normalized node features.
+    """
+    nodes = fluid.layers.fill_constant(
+        [gw.num_nodes, 1], dtype="float32", value=1.0)
+    norm = graph_pooling(gw, nodes, pool_type="sum")
+    norm = fluid.layers.sqrt(norm)
+    feature_lod = op.nested_lod_reset(feature, gw.graph_lod)
+    norm = fluid.layers.sequence_expand_as(norm, feature_lod)
+    norm.stop_gradient = True
+    return feature_lod / norm
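
Note (not part of the diff): a minimal NumPy sketch of what the new graph_norm computes, useful for checking the semantics without building a Paddle program. It assumes graphs in a batch are delimited by LoD-style offsets, which is what gw.graph_lod carries; the shapes and offsets below are made-up illustration values.

    import numpy as np

    # Toy batch: 5 nodes split into two graphs via LoD-style offsets.
    node_feat = np.ones((5, 4), dtype="float32")  # (num_nodes, feature_size)
    graph_lod = [0, 3, 5]                         # graph 0: nodes 0..2, graph 1: nodes 3..4

    out = node_feat.copy()
    for i in range(len(graph_lod) - 1):
        start, end = graph_lod[i], graph_lod[i + 1]
        out[start:end] /= np.sqrt(end - start)    # divide by sqrt(num_nodes) per graph

    print(out[0])  # graph 0 rows scaled by 1/sqrt(3), ~0.577
    print(out[3])  # graph 1 rows scaled by 1/sqrt(2), ~0.707

The diff realizes the same per-graph scaling inside the fluid graph: sum-pooling a vector of ones yields each graph's node count, whose square root is broadcast back to node granularity with sequence_expand_as.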