graph_op.py 3.2 KB
Newer Older
W
Webbley 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module implements common graph operations (pooling, normalization
and gathering) that help build graph neural networks.
"""
17 18
import paddle.fluid as F 
import paddle.fluid.layers as L
W
Webbley 已提交
19 20 21 22
from pgl import graph_wrapper
from pgl.utils import paddle_helper
from pgl.utils import op

23
# Public API exported by `from ... import *`.
__all__ = ['graph_pooling', 'graph_norm', 'graph_gather']
W
Webbley 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41


def graph_pooling(gw, node_feat, pool_type):
    """Implementation of graph pooling

    Pool the node features of every graph in the batch into a single
    graph-level feature vector.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        node_feat: A tensor with shape (num_nodes, feature_size).

        pool_type: The type of pooling. It is passed straight through to
            :code:`paddle.fluid.layers.sequence_pool`, so the supported values
            are "sum", "average", "sqrt", "max", "first" and "last".

    Return:
        A tensor with shape (num_graph, feature_size)
    """
    # Reset the LoD of the node features to the graph-level LoD so that the
    # nodes of each graph form one sequence for sequence_pool below.
    graph_feat = op.nested_lod_reset(node_feat, gw.graph_lod)
    graph_feat = L.sequence_pool(graph_feat, pool_type)
    return graph_feat
Y
Yelrose 已提交
44 45 46 47 48 49 50 51 52 53 54 55


def graph_norm(gw, feature):
    """Implementation of graph normalization

    Reference Paper: BENCHMARKING GRAPH NEURAL NETWORKS

    The feature of each node is divided by sqrt(num_nodes) of the graph the
    node belongs to.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, hidden_size)

    Return:
        A tensor with shape (num_nodes, hidden_size)
    """
    # Sum-pooling a column of ones per graph yields each graph's node count.
    nodes = L.fill_constant(
        [gw.num_nodes, 1], dtype="float32", value=1.0)
    norm = graph_pooling(gw, nodes, pool_type="sum")
    norm = L.sqrt(norm)
    # Give the features the graph-level LoD, then repeat each graph's norm
    # once per node so it lines up with the node features.
    feature_lod = op.nested_lod_reset(feature, gw.graph_lod)
    norm = L.sequence_expand_as(norm, feature_lod)
    # The normalizer depends only on graph sizes, so no gradient flows into it.
    norm.stop_gradient = True
    return feature_lod / norm
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88


def graph_gather(gw, feature, index):
    """Implementation of graph gather

    Gather the node features at the given graph-local indices for each graph.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, hidden_size).

        index (int32): A tensor with K-rank where the first dim denotes the graph.
                       Shape (num_graph, ) or (num_graph, k1, k2, k3, ..., kn).
                       Each entry is a node index local to its own graph.
                       WARNING: We don't support negative index.

    Return:
        A tensor with shape (num_graph, k1, k2, k3, ..., kn, hidden_size)
    """
    shape = L.shape(index)
    output_dim = int(feature.shape[-1])
    # gw.graph_lod[:-1] holds one start offset per graph (shape (num_graph, )).
    # Adding it shifts graph-local indices to global node ids. The offsets
    # must broadcast along axis 0 (the graph dim). The default axis=-1 used
    # by the `+` operator aligns them with the LAST dim of a K-rank index
    # instead. Cast first so mixed int32/int64 inputs still work, as the
    # `+` operator's implicit cast allowed.
    offsets = L.cast(gw.graph_lod[:-1], dtype=index.dtype)
    index = L.elementwise_add(index, offsets, axis=0)
    index = L.reshape(index, [-1])
    feature = L.gather(feature, index, overwrite=False)
    # Restore the (runtime) index shape and append the feature dim.
    new_shape = [shape[i] for i in range(shape.shape[0])]
    new_shape.append(output_dim)
    feature = L.reshape(feature, new_shape)
    return feature