conv.py 16.4 KB
Newer Older
Y
yelrose 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This package implements common layers to help building
graph neural networks.
"""
Y
Yelrose 已提交
17
import pgl
Y
yelrose 已提交
18
import paddle.fluid as fluid
Y
Yelrose 已提交
19
import paddle.fluid.layers as L
Y
yelrose 已提交
20
from pgl.utils import paddle_helper
F
fengshikun01 已提交
21
from pgl import message_passing
Y
Yelrose 已提交
22
import numpy as np
Y
yelrose 已提交
23

__all__ = [
    'gcn',
    'gat',
    'gin',
    'gaan',
    'gen_conv',
    'appnp',
    'gcnii',
]


def gcn(gw, feature, hidden_size, activation, name, norm=None):
    """Implementation of graph convolutional neural networks (GCN)

    This is an implementation of the paper SEMI-SUPERVISED CLASSIFICATION
    WITH GRAPH CONVOLUTIONAL NETWORKS (https://arxiv.org/pdf/1609.02907.pdf).

    The linear projection (no bias) is applied on whichever side of the
    neighbor sum-aggregation keeps message passing on the narrower width:
    if the input is wider than ``hidden_size`` the features are projected
    *before* sending, otherwise the aggregated result is projected
    *after* receiving. Either way the projection weight is named ``name``.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, feature_size).

        hidden_size: The hidden size for gcn.

        activation: The activation for the output (passed as ``act`` to the
            final bias add).

        name: Gcn layer names. Used for the projection weight and, with the
            ``_bias`` suffix, for the bias parameter.

        norm: If :code:`norm` is not None, features are multiplied by it both
              before sending and after receiving. Must be a float32 tensor;
              documented upstream as shape (num_nodes,).
              NOTE(review): row-wise broadcasting against (num_nodes, dim)
              may require shape (num_nodes, 1) — confirm with callers.

    Return:
        A tensor with shape (num_nodes, hidden_size)
    """

    def send_src_copy(src_feat, dst_feat, edge_feat):
        # Each edge carries its source node's feature unchanged.
        return src_feat["h"]

    # Project first only when that shrinks the width being message-passed.
    project_first = feature.shape[-1] > hidden_size
    if project_first:
        feature = L.fc(feature,
                       size=hidden_size,
                       bias_attr=False,
                       param_attr=fluid.ParamAttr(name=name))

    if norm is not None:
        feature = feature * norm

    msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
    output = gw.recv(msg, "sum")

    if not project_first:
        output = L.fc(output,
                      size=hidden_size,
                      bias_attr=False,
                      param_attr=fluid.ParamAttr(name=name))

    if norm is not None:
        output = output * norm

    bias = L.create_parameter(
        shape=[hidden_size],
        dtype='float32',
        is_bias=True,
        name=name + '_bias')
    output = L.elementwise_add(output, bias, act=activation)
    return output


def gat(gw,
        feature,
        hidden_size,
        activation,
        name,
        num_heads=8,
        feat_drop=0.6,
        attn_drop=0.6,
        is_test=False):
    """Implementation of graph attention networks (GAT)

    This is an implementation of the paper GRAPH ATTENTION NETWORKS
    (https://arxiv.org/abs/1710.10903).

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, feature_size).

        hidden_size: The hidden size of each attention head.

        activation: The activation for the output.

        name: Gat layer names; prefix of every parameter created here.

        num_heads: The head number in gat.

        feat_drop: Dropout rate for the input feature (skipped when <= 1e-15).

        attn_drop: Dropout rate for the attention weights (skipped when <= 1e-15).

        is_test: Whether in test phase; forwarded to dropout.

    Return:
        A tensor with shape (num_nodes, hidden_size * num_heads)
    """

    def _dropout(tensor, rate):
        # Shared "upscale_in_train" dropout; tiny rates mean "no dropout".
        if rate > 1e-15:
            tensor = L.dropout(
                tensor,
                dropout_prob=rate,
                is_test=is_test,
                dropout_implementation="upscale_in_train")
        return tensor

    def send_attention(src_feat, dst_feat, edge_feat):
        # Per-edge, per-head unnormalized score: LeakyReLU of the source's
        # left-attention term plus the destination's right-attention term.
        score = L.leaky_relu(
            src_feat["left_a"] + dst_feat["right_a"], alpha=0.2)
        return {"alpha": score, "h": src_feat["h"]}

    def reduce_attention(msg):
        # msg holds per-edge sequences grouped by destination node.
        edge_h = msg["h"]
        weights = paddle_helper.sequence_softmax(msg["alpha"])
        per_head = L.reshape(edge_h, [-1, num_heads, hidden_size])
        weights = L.reshape(weights, [-1, num_heads, 1])
        weights = _dropout(weights, attn_drop)
        weighted = L.reshape(per_head * weights,
                             [-1, num_heads * hidden_size])
        # Reattach the incoming sequence layout so sequence_pool sums the
        # weighted messages per destination node.
        weighted = L.lod_reset(weighted, edge_h)
        return L.sequence_pool(weighted, "sum")

    feature = _dropout(feature, feat_drop)

    ft = L.fc(feature,
              hidden_size * num_heads,
              bias_attr=False,
              param_attr=fluid.ParamAttr(name=name + '_weight'))
    attn_src = L.create_parameter(
        shape=[num_heads, hidden_size],
        dtype='float32',
        name=name + '_gat_l_A')
    attn_dst = L.create_parameter(
        shape=[num_heads, hidden_size],
        dtype='float32',
        name=name + '_gat_r_A')

    # Per-head dot products of the projected feature with each attention
    # vector: (num_nodes, num_heads).
    heads_ft = L.reshape(ft, [-1, num_heads, hidden_size])
    score_src = L.reduce_sum(heads_ft * attn_src, -1)
    score_dst = L.reduce_sum(heads_ft * attn_dst, -1)

    msg = gw.send(
        send_attention,
        nfeat_list=[("h", ft), ("left_a", score_src),
                    ("right_a", score_dst)])
    output = gw.recv(msg, reduce_attention)

    bias = L.create_parameter(
        shape=[hidden_size * num_heads],
        dtype='float32',
        is_bias=True,
        name=name + '_bias')
    # NOTE(review): stop_gradient freezes the bias at its initial value, so
    # it is never trained — confirm this is intentional.
    bias.stop_gradient = True
    return L.elementwise_add(output, bias, act=activation)


def gin(gw,
        feature,
        hidden_size,
        activation,
        name,
        init_eps=0.0,
        train_eps=False):
    r"""Implementation of Graph Isomorphism Network (GIN) layer.

    This is an implementation of the paper How Powerful are Graph Neural Networks?
    (https://arxiv.org/pdf/1810.00826.pdf).

    In the paper's implementation, all MLPs have 2 layers and batch
    normalization is applied on every hidden layer. This layer also uses a
    2-layer MLP, but normalizes its hidden layer with layer normalization.

    The layer computes ``MLP((1 + eps) * h_v + sum_{u in N(v)} h_u)``.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, feature_size).

        hidden_size: The hidden size for gin.

        activation: Name of a ``paddle.fluid.layers`` activation function (looked
            up with ``getattr``), applied after the hidden layer and on the
            output, or None.

        name: GIN layer names; used in every parameter name created here.

        init_eps: float, optional
            Initial :math:`\epsilon` value, default is 0.

        train_eps: bool, optional
            if True, :math:`\epsilon` will be a learnable parameter.

    Return:
        A tensor with shape (num_nodes, hidden_size).
    """

    def send_src_copy(src_feat, dst_feat, edge_feat):
        return src_feat["h"]

    epsilon = L.create_parameter(
        shape=[1, 1],
        dtype="float32",
        attr=fluid.ParamAttr(name="%s_eps" % name),
        default_initializer=fluid.initializer.ConstantInitializer(
            value=init_eps))

    if not train_eps:
        # Keep epsilon fixed at init_eps.
        epsilon.stop_gradient = True

    msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
    output = gw.recv(msg, "sum") + feature * (epsilon + 1.0)

    # MLP layer 1 (linear) + layer norm + optional activation.
    output = L.fc(output,
                  size=hidden_size,
                  act=None,
                  param_attr=fluid.ParamAttr(name="%s_w_0" % name),
                  bias_attr=fluid.ParamAttr(name="%s_b_0" % name))

    output = L.layer_norm(
        output,
        begin_norm_axis=1,
        param_attr=fluid.ParamAttr(
            name="norm_scale_%s" % (name),
            initializer=fluid.initializer.Constant(1.0)),
        bias_attr=fluid.ParamAttr(
            name="norm_bias_%s" % (name),
            initializer=fluid.initializer.Constant(0.0)), )

    if activation is not None:
        output = getattr(L, activation)(output)

    # MLP layer 2, with the same activation on its output.
    output = L.fc(output,
                  size=hidden_size,
                  act=activation,
                  param_attr=fluid.ParamAttr(name="%s_w_1" % name),
                  bias_attr=fluid.ParamAttr(name="%s_b_1" % name))

    return output


def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o, heads, name):
    """Implementation of GaAN (Gated Attention Networks).

    See "GaAN: Gated Attention Networks for Learning on Large and
    Spatiotemporal Graphs" (https://arxiv.org/abs/1803.07294).

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, feature_size).

        hidden_size_a: Per-head size of the query/key projections used for
            the attention scores.

        hidden_size_v: Per-head size of the value projection.

        hidden_size_m: Size of the gate projection that is max-pooled over
            neighbors.

        hidden_size_o: Output size.

        heads: Number of attention heads.

        name: GaAN layer names; prefix of the named projection parameters.

    Return:
        A tensor with shape (num_nodes, hidden_size_o).
    """

    def send_func(src_feat, dst_feat, edge_feat):
        # Attention score on every edge: each dst node queries the src
        # nodes of all its incoming edges.
        # E * (M * D1) -> E * M * D1
        feat_query = L.reshape(dst_feat['feat_query'], [-1, heads, hidden_size_a])
        feat_key = L.reshape(src_feat['feat_key'], [-1, heads, hidden_size_a])
        # E * M
        alpha = L.reduce_sum(feat_key * feat_query, dim=-1)

        return {'dst_node_feat': dst_feat['node_feat'],
                'src_node_feat': src_feat['node_feat'],
                'feat_value': src_feat['feat_value'],
                'alpha': alpha,
                'feat_gate': src_feat['feat_gate']}

    def recv_func(message):
        # Feature of each edge's destination node.
        dst_feat = message['dst_node_feat']
        # Feature of each edge's source node.
        src_feat = message['src_node_feat']
        # Each center node's own feature.
        x = L.sequence_pool(dst_feat, 'average')
        # Mean of each center node's neighbor features.
        z = L.sequence_pool(src_feat, 'average')

        # Gate: one sigmoid value per head from [x, max(gate proj), mean(neighbors)].
        feat_gate = message['feat_gate']
        g_max = L.sequence_pool(feat_gate, 'max')
        g = L.concat([x, g_max, z], axis=1)
        # NOTE(review): unlike the other projections here, this fc has no
        # explicit param_attr name, so its weight gets an auto-generated name.
        g = L.fc(g, heads, bias_attr=False, act="sigmoid")

        # softmax
        alpha = message['alpha']
        alpha = paddle_helper.sequence_softmax(alpha)  # E * M

        feat_value = message['feat_value']  # E * (M * D2)
        old = feat_value
        feat_value = L.reshape(feat_value, [-1, heads, hidden_size_v])  # E * M * D2
        feat_value = L.elementwise_mul(feat_value, alpha, axis=0)
        feat_value = L.reshape(feat_value, [-1, heads * hidden_size_v])  # E * (M * D2)
        # Restore the per-destination sequence layout before pooling.
        feat_value = L.lod_reset(feat_value, old)

        feat_value = L.sequence_pool(feat_value, 'sum')  # N * (M * D2)

        feat_value = L.reshape(feat_value, [-1, heads, hidden_size_v])  # N * M * D2

        output = L.elementwise_mul(feat_value, g, axis=0)
        output = L.reshape(output, [-1, heads * hidden_size_v])  # N * (M * D2)

        output = L.concat([x, output], axis=1)

        return output

    # feature: N * D

    # What each node sends out: its projected feature vectors.
    # N * (D1 * M)
    feat_key = L.fc(feature, hidden_size_a * heads, bias_attr=False,
                    param_attr=fluid.ParamAttr(name=name + '_project_key'))
    # N * (D2 * M)
    feat_value = L.fc(feature, hidden_size_v * heads, bias_attr=False,
                      param_attr=fluid.ParamAttr(name=name + '_project_value'))
    # N * (D1 * M)
    feat_query = L.fc(feature, hidden_size_a * heads, bias_attr=False,
                      param_attr=fluid.ParamAttr(name=name + '_project_query'))
    # N * Dm
    feat_gate = L.fc(feature, hidden_size_m, bias_attr=False,
                     param_attr=fluid.ParamAttr(name=name + '_project_gate'))

    # Send phase.
    message = gw.send(
        send_func,
        nfeat_list=[('node_feat', feature), ('feat_key', feat_key), ('feat_value', feat_value),
                    ('feat_query', feat_query), ('feat_gate', feat_gate)],
        efeat_list=None,
    )

    # Aggregate neighbor features.
    output = gw.recv(message, recv_func)
    output = L.fc(output, hidden_size_o, bias_attr=False,
                  param_attr=fluid.ParamAttr(name=name + '_project_output'))
    output = L.leaky_relu(output, alpha=0.1)
    # NOTE(review): no is_test is passed here, so inference behavior relies
    # on how the caller builds the test program — confirm.
    output = L.dropout(output, dropout_prob=0.1)

    return output


def gen_conv(gw,
             feature,
             name,
             beta=None):
    """Implementation of GENeralized Graph Convolution (GENConv), see the paper
    "DeeperGCN: All You Need to Train Deeper GCNs" in
    https://arxiv.org/pdf/2006.07739.pdf

    Pipeline: copy-send + ``message_passing.softmax_agg`` aggregation,
    ``message_passing.msg_norm``, residual add of the input, then a two-layer
    bias-free MLP (ReLU after the first layer) that keeps the input width.

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, feature_size).

        beta: [0, +infinity], "dynamic" or None; forwarded to
            ``message_passing.softmax_agg``. With "dynamic", a learnable
            float32 scalar initialized to 1.0 is created as ``name + '_beta'``.

        name: deeper gcn layer names; prefix of every parameter created here.

    Return:
        A tensor with shape (num_nodes, feature_size)
    """
    width = feature.shape[-1]

    if beta == "dynamic":
        beta = L.create_parameter(
            shape=[1],
            dtype='float32',
            default_initializer=fluid.initializer.ConstantInitializer(value=1.0),
            name=name + '_beta')

    # Message passing with softmax aggregation.
    msg = gw.send(message_passing.copy_send, nfeat_list=[("h", feature)])
    aggregated = gw.recv(msg, message_passing.softmax_agg(beta))

    # Message normalization followed by a residual connection.
    hidden = feature + message_passing.msg_norm(feature, aggregated, name)

    hidden = L.fc(hidden,
                  width,
                  bias_attr=False,
                  act="relu",
                  param_attr=fluid.ParamAttr(name=name + '_weight1'))
    return L.fc(hidden,
                width,
                bias_attr=False,
                param_attr=fluid.ParamAttr(name=name + '_weight2'))


def get_norm(indegree):
    """Get Laplacian Normalization

    Returns ``max(indegree, 1) ** -0.5`` as a float32 tensor. Degrees are
    clamped to at least 1, so nodes with zero in-degree get a factor of 1
    instead of an infinite one.
    """
    degree = L.clamp(L.cast(indegree, dtype="float32"), min=1.0)
    return L.pow(degree, factor=-0.5)


def appnp(gw, feature, edge_dropout=0, alpha=0.2, k_hop=10):
    """Implementation of APPNP of "Predict then Propagate: Graph Neural Networks
    meet Personalized PageRank"  (ICLR 2019).

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, feature_size).

        edge_dropout: Edge dropout rate. When greater than 1e-5, a new
            edge-dropped graph is sampled from ``gw`` at every propagation
            step, and both normalization and message passing use it.

        alpha: Teleport probability; each step computes
            ``(1 - alpha) * propagated + alpha * feature_0``.

        k_hop: K Steps for Propagation

    Return:
        A tensor with shape (num_nodes, feature_size)
    """

    def send_src_copy(src_feat, dst_feat, edge_feat):
        return src_feat["h"]

    h0 = feature
    ngw = gw
    norm = get_norm(ngw.indegree())

    for _ in range(k_hop):
        if edge_dropout > 1e-5:
            ngw = pgl.sample.edge_drop(gw, edge_dropout)
            norm = get_norm(ngw.indegree())

        feature = feature * norm

        # Propagate over ngw (not gw): otherwise dropped edges would still
        # carry messages and disagree with the degrees used for `norm`.
        msg = ngw.send(send_src_copy, nfeat_list=[("h", feature)])
        feature = ngw.recv(msg, "sum")

        feature = feature * norm

        feature = feature * (1 - alpha) + h0 * alpha
    return feature


def gcnii(gw,
          feature,
          name,
          activation=None,
          alpha=0.5,
          lambda_l=0.5,
          k_hop=1,
          dropout=0.5,
          is_test=False):
    """Implementation of GCNII of "Simple and Deep Graph Convolutional Networks"

    paper: https://arxiv.org/pdf/2007.02133.pdf

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)

        feature: A tensor with shape (num_nodes, feature_size).

        name: GCNII layer names; layer ``i`` builds its transform with an fc
            named ``name + "_%s_w1" % i``.

        activation: Name of a ``paddle.fluid.layers`` function applied after
            every layer, or None.

        alpha: Initial-residual weight of the input feature in every layer.

        lambda_l: The hyperparameter of lambda in the paper; layer ``i`` uses
            ``beta_i = log(lambda_l / (i + 1) + 1)``.

        k_hop: Number of layers for gcnii.

        dropout: Feature dropout rate, applied at the start of every layer.

        is_test: train / test phase.

    Return:
        A tensor with shape (num_nodes, feature_size)
    """

    def send_src_copy(src_feat, dst_feat, edge_feat):
        return src_feat["h"]

    h0 = feature
    norm = get_norm(gw.indegree())
    hidden_size = feature.shape[-1]

    for layer in range(k_hop):
        # Identity-mapping weight; shrinks as the layer index grows
        # (for lambda_l > 0).
        beta = np.log(1.0 * lambda_l / (layer + 1) + 1)

        feature = L.dropout(
            feature,
            dropout_prob=dropout,
            is_test=is_test,
            dropout_implementation='upscale_in_train')

        # Sum-propagation scaled by the degree norm on both sides.
        feature = feature * norm
        msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
        feature = gw.recv(msg, "sum") * norm

        # Initial residual (APPNP-style mix with the layer-0 input).
        support = feature * (1 - alpha) + h0 * alpha

        # Identity mapping: beta * (support W) + (1 - beta) * support.
        projected = L.fc(support,
                         hidden_size,
                         act=None,
                         bias_attr=False,
                         name=name + "_%s_w1" % layer)
        feature = projected * beta + support * (1 - beta)

        if activation is not None:
            feature = getattr(L, activation)(feature)
    return feature