# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This package implements common layers to help build graph neural networks.
"""
import paddle.fluid as fluid
from pgl import graph_wrapper
from pgl.utils import paddle_helper
from pgl.utils import op

__all__ = ['graph_pooling', 'graph_norm']


def graph_pooling(gw, node_feat, pool_type):
    """Implementation of graph pooling.

    Pools the node features of every graph in a (batched) graph wrapper
    into one feature row per graph.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        node_feat: A tensor with shape (num_nodes, feature_size).

        pool_type: The pooling type forwarded to
            :code:`fluid.layers.sequence_pool` ("sum", "average", "sqrt",
            "max", "first" or "last").

    Return:
        A tensor with shape (num_graph, feature_size)
    """
    # NOTE(review): presumably nested_lod_reset attaches gw.graph_lod as the
    # LoD of node_feat so that each graph becomes one sequence; confirm in
    # pgl.utils.op.
    graph_feat = op.nested_lod_reset(node_feat, gw.graph_lod)
    graph_feat = fluid.layers.sequence_pool(graph_feat, pool_type)
    return graph_feat


def graph_norm(gw, feature):
    """Implementation of graph normalization.

    Reference Paper: BENCHMARKING GRAPH NEURAL NETWORKS

    Each node feature is divided by sqrt(num_nodes) of the graph it
    belongs to.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, hidden_size)

    Return:
        A tensor with shape (num_nodes, hidden_size)
    """
    # One "1" per node, created in the feature's dtype so that the final
    # division below does not mix dtypes (float32 callers are unaffected).
    nodes = fluid.layers.fill_constant(
        [gw.num_nodes, 1], dtype=feature.dtype, value=1.0)
    # Summing the ones per graph yields each graph's node count.
    norm = graph_pooling(gw, nodes, pool_type="sum")
    norm = fluid.layers.sqrt(norm)
    feature_lod = op.nested_lod_reset(feature, gw.graph_lod)
    # Repeat each graph's sqrt(num_nodes) once for every node row of that
    # graph so it lines up element-wise with feature_lod.
    norm = fluid.layers.sequence_expand_as(norm, feature_lod)
    # The divisor is built from constants only; keep it out of backprop.
    norm.stop_gradient = True
    return feature_lod / norm