提交 584bb25a 编写于 作者: W Webbley

Merge remote-tracking branch 'upstream/main' into main

# Easy Paper Reproduction for Citation Network (Cora/Pubmed/Citeseer)
This page tries to reproduce all the **Graph Neural Network** paper for Citation Network (Cora/Pubmed/Citeseer), which is the **Hello world** dataset (**small** and **fast**) for graph neural networks. But it's very hard to achieve very high performance.
All datasets are runned with public split of **semi-supervised** settings. And we report the averarge accuracy by running 10 times.
# Experiment Results
| Model | Cora | Pubmed | Citeseer | Remarks |
| ------------------------------------------------------------ | ------------ | ------------ | ------------ | --------------------------------------------------------- |
| [Vanilla GCN (Kipf 2017)](https://openreview.net/pdf?id=SJU4ayYgl ) | 0.807(0.010) | 0.794(0.003) | 0.710(0.007) | |
| [GAT (Veličković 2017)](https://arxiv.org/pdf/1710.10903.pdf) | 0.834(0.004) | 0.772(0.004) | 0.700(0.006) | |
| [SGC(Wu 2019)](https://arxiv.org/pdf/1902.07153.pdf) | 0.818(0.000) | 0.782(0.000) | 0.708(0.000) | |
| [APPNP (Johannes 2018)](https://arxiv.org/abs/1810.05997) | 0.846(0.003) | 0.803(0.002) | 0.719(0.003) | Almost the same with the results reported in Appendix E. |
| [GCNII (64 Layers, 1500 Epochs, Chen 2020)](https://arxiv.org/pdf/2007.02133.pdf) | 0.846(0.003) | 0.798(0.003) | 0.724(0.006) | |
How to run the experiments?
```shell
# Device choose
export CUDA_VISIBLE_DEVICES=0
# GCN
python train.py --conf config/gcn.yaml --use_cuda --dataset cora
python train.py --conf config/gcn.yaml --use_cuda --dataset pubmed
python train.py --conf config/gcn.yaml --use_cuda --dataset citeseer
# GAT
python train.py --conf config/gat.yaml --use_cuda --dataset cora
python train.py --conf config/gat.yaml --use_cuda --dataset pubmed
python train.py --conf config/gat.yaml --use_cuda --dataset citeseer
# SGC (Slow version)
python train.py --conf config/sgc.yaml --use_cuda --dataset cora
python train.py --conf config/sgc.yaml --use_cuda --dataset pubmed
python train.py --conf config/sgc.yaml --use_cuda --dataset citeseer
# APPNP
python train.py --conf config/appnp.yaml --use_cuda --dataset cora
python train.py --conf config/appnp.yaml --use_cuda --dataset pubmed
python train.py --conf config/appnp.yaml --use_cuda --dataset citeseer
# GCNII (The original code use 1500 epochs.)
python train.py --conf config/gcnii.yaml --use_cuda --dataset cora --epoch 1500
python train.py --conf config/gcnii.yaml --use_cuda --dataset pubmed --epoch 1500
python train.py --conf config/gcnii.yaml --use_cuda --dataset citeseer --epoch 1500
```
import pgl
import model
from pgl import data_loader
import paddle.fluid as fluid
import numpy as np
import time
def build_model(dataset, config, phase, main_prog):
gw = pgl.graph_wrapper.GraphWrapper(
name="graph",
node_feat=dataset.graph.node_feat_info())
GraphModel = getattr(model, config.model_name)
m = GraphModel(config=config, num_class=dataset.num_classes)
logits = m.forward(gw, gw.node_feat["words"], phase)
# Take the last
node_index = fluid.layers.data(
"node_index",
shape=[None, 1],
dtype="int64",
append_batch_size=False)
node_label = fluid.layers.data(
"node_label",
shape=[None, 1],
dtype="int64",
append_batch_size=False)
pred = fluid.layers.gather(logits, node_index)
loss, pred = fluid.layers.softmax_with_cross_entropy(
logits=pred, label=node_label, return_softmax=True)
acc = fluid.layers.accuracy(input=pred, label=node_label, k=1)
loss = fluid.layers.mean(loss)
if phase == "train":
adam = fluid.optimizer.Adam(
learning_rate=config.learning_rate,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=config.weight_decay))
adam.minimize(loss)
return gw, loss, acc
model_name: APPNP
k_hop: 10
alpha: 0.1
num_layer: 1
learning_rate: 0.01
dropout: 0.5
hidden_size: 64
weight_decay: 0.0005
edge_dropout: 0.0
model_name: GAT
learning_rate: 0.005
weight_decay: 0.0005
num_layers: 1
feat_drop: 0.6
attn_drop: 0.6
num_heads: 8
hidden_size: 8
edge_dropout: 0.0
model_name: GCN
num_layers: 1
dropout: 0.5
hidden_size: 16
learning_rate: 0.01
weight_decay: 0.0005
edge_dropout: 0.0
model_name: GCNII
k_hop: 64
alpha: 0.1
num_layer: 1
learning_rate: 0.01
dropout: 0.6
hidden_size: 64
weight_decay: 0.0005
edge_dropout: 0.0
model_name: SGC
num_layers: 2
learning_rate: 0.2
weight_decay: 0.000005
feature_pre_normalize: False
import pgl
import paddle.fluid.layers as L
import pgl.layers.conv as conv
def get_norm(indegree):
float_degree = L.cast(indegree, dtype="float32")
float_degree = L.clamp(float_degree, min=1.0)
norm = L.pow(float_degree, factor=-0.5)
return norm
class GCN(object):
"""Implement of GCN
"""
def __init__(self, config, num_class):
self.num_class = num_class
self.num_layers = config.get("num_layers", 1)
self.hidden_size = config.get("hidden_size", 64)
self.dropout = config.get("dropout", 0.5)
self.edge_dropout = config.get("edge_dropout", 0.0)
def forward(self, graph_wrapper, feature, phase):
for i in range(self.num_layers):
if phase == "train":
ngw = pgl.sample.edge_drop(graph_wrapper, self.edge_dropout)
norm = get_norm(ngw.indegree())
else:
ngw = graph_wrapper
norm = graph_wrapper.node_feat["norm"]
feature = pgl.layers.gcn(ngw,
feature,
self.hidden_size,
activation="relu",
norm=norm,
name="layer_%s" % i)
feature = L.dropout(
feature,
self.dropout,
dropout_implementation='upscale_in_train')
if phase == "train":
ngw = pgl.sample.edge_drop(graph_wrapper, self.edge_dropout)
norm = get_norm(ngw.indegree())
else:
ngw = graph_wrapper
norm = graph_wrapper.node_feat["norm"]
feature = conv.gcn(ngw,
feature,
self.num_class,
activation=None,
norm=norm,
name="output")
return feature
class GAT(object):
"""Implement of GAT"""
def __init__(self, config, num_class):
self.num_class = num_class
self.num_layers = config.get("num_layers", 1)
self.num_heads = config.get("num_heads", 8)
self.hidden_size = config.get("hidden_size", 8)
self.feat_dropout = config.get("feat_drop", 0.6)
self.attn_dropout = config.get("attn_drop", 0.6)
self.edge_dropout = config.get("edge_dropout", 0.0)
def forward(self, graph_wrapper, feature, phase):
if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout
for i in range(self.num_layers):
ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
feature = conv.gat(ngw,
feature,
self.hidden_size,
activation="elu",
name="gat_layer_%s" % i,
num_heads=self.num_heads,
feat_drop=self.feat_dropout,
attn_drop=self.attn_dropout)
ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
feature = conv.gat(ngw,
feature,
self.num_class,
num_heads=1,
activation=None,
feat_drop=self.feat_dropout,
attn_drop=self.attn_dropout,
name="output")
return feature
class APPNP(object):
"""Implement of APPNP"""
def __init__(self, config, num_class):
self.num_class = num_class
self.num_layers = config.get("num_layers", 1)
self.hidden_size = config.get("hidden_size", 64)
self.dropout = config.get("dropout", 0.5)
self.alpha = config.get("alpha", 0.1)
self.k_hop = config.get("k_hop", 10)
self.edge_dropout = config.get("edge_dropout", 0.0)
def forward(self, graph_wrapper, feature, phase):
if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout
for i in range(self.num_layers):
feature = L.dropout(
feature,
self.dropout,
dropout_implementation='upscale_in_train')
feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
feature = L.dropout(
feature,
self.dropout,
dropout_implementation='upscale_in_train')
feature = L.fc(feature, self.num_class, act=None, name="output")
feature = conv.appnp(graph_wrapper,
feature=feature,
edge_dropout=edge_dropout,
alpha=self.alpha,
k_hop=self.k_hop)
return feature
class SGC(object):
"""Implement of SGC"""
def __init__(self, config, num_class):
self.num_class = num_class
self.num_layers = config.get("num_layers", 1)
def forward(self, graph_wrapper, feature, phase):
feature = conv.appnp(graph_wrapper,
feature=feature,
edge_dropout=0,
alpha=0,
k_hop=self.num_layers)
feature.stop_gradient=True
feature = L.fc(feature, self.num_class, act=None, bias_attr=False, name="output")
return feature
class GCNII(object):
"""Implement of GCNII"""
def __init__(self, config, num_class):
self.num_class = num_class
self.num_layers = config.get("num_layers", 1)
self.hidden_size = config.get("hidden_size", 64)
self.dropout = config.get("dropout", 0.6)
self.alpha = config.get("alpha", 0.1)
self.lambda_l = config.get("lambda_l", 0.5)
self.k_hop = config.get("k_hop", 64)
self.edge_dropout = config.get("edge_dropout", 0.0)
def forward(self, graph_wrapper, feature, phase):
if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout
for i in range(self.num_layers):
feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
feature = L.dropout(
feature,
self.dropout,
dropout_implementation='upscale_in_train')
feature = conv.gcnii(graph_wrapper,
feature=feature,
name="gcnii",
activation="relu",
lambda_l=self.lambda_l,
alpha=self.alpha,
dropout=self.dropout,
k_hop=self.k_hop)
feature = L.fc(feature, self.num_class, act=None, name="output")
return feature
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pgl
import model# import LabelGraphGCN
from pgl import data_loader
from pgl.utils.logger import log
import paddle.fluid as fluid
import numpy as np
import time
import argparse
from build_model import build_model
import yaml
from easydict import EasyDict as edict
import tqdm
def normalize(feat):
return feat / np.maximum(np.sum(feat, -1, keepdims=True), 1)
def load(name, normalized_feature=True):
if name == 'cora':
dataset = data_loader.CoraDataset()
elif name == "pubmed":
dataset = data_loader.CitationDataset("pubmed", symmetry_edges=True)
elif name == "citeseer":
dataset = data_loader.CitationDataset("citeseer", symmetry_edges=True)
else:
raise ValueError(name + " dataset doesn't exists")
indegree = dataset.graph.indegree()
norm = np.maximum(indegree.astype("float32"), 1)
norm = np.power(norm, -0.5)
dataset.graph.node_feat["norm"] = np.expand_dims(norm, -1)
dataset.graph.node_feat["words"] = normalize(dataset.graph.node_feat["words"])
return dataset
def main(args, config):
dataset = load(args.dataset, args.feature_pre_normalize)
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
train_program = fluid.default_main_program()
startup_program = fluid.default_startup_program()
with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard():
gw, loss, acc = build_model(dataset,
config=config,
phase="train",
main_prog=train_program)
test_program = fluid.Program()
with fluid.program_guard(test_program, startup_program):
with fluid.unique_name.guard():
_gw, v_loss, v_acc = build_model(dataset,
config=config,
phase="test",
main_prog=test_program)
test_program = test_program.clone(for_test=True)
exe = fluid.Executor(place)
train_index = dataset.train_index
train_label = np.expand_dims(dataset.y[train_index], -1)
train_index = np.expand_dims(train_index, -1)
val_index = dataset.val_index
val_label = np.expand_dims(dataset.y[val_index], -1)
val_index = np.expand_dims(val_index, -1)
test_index = dataset.test_index
test_label = np.expand_dims(dataset.y[test_index], -1)
test_index = np.expand_dims(test_index, -1)
dur = []
# Feed data
feed_dict = gw.to_feed(dataset.graph)
best_test = []
for run in range(args.runs):
exe.run(startup_program)
cal_val_acc = []
cal_test_acc = []
cal_val_loss = []
cal_test_loss = []
for epoch in tqdm.tqdm(range(args.epoch)):
feed_dict["node_index"] = np.array(train_index, dtype="int64")
feed_dict["node_label"] = np.array(train_label, dtype="int64")
train_loss, train_acc = exe.run(train_program,
feed=feed_dict,
fetch_list=[loss, acc],
return_numpy=True)
feed_dict["node_index"] = np.array(val_index, dtype="int64")
feed_dict["node_label"] = np.array(val_label, dtype="int64")
val_loss, val_acc = exe.run(test_program,
feed=feed_dict,
fetch_list=[v_loss, v_acc],
return_numpy=True)
cal_val_acc.append(val_acc[0])
cal_val_loss.append(val_loss[0])
feed_dict["node_index"] = np.array(test_index, dtype="int64")
feed_dict["node_label"] = np.array(test_label, dtype="int64")
test_loss, test_acc = exe.run(test_program,
feed=feed_dict,
fetch_list=[v_loss, v_acc],
return_numpy=True)
cal_test_acc.append(test_acc[0])
cal_test_loss.append(test_loss[0])
log.info("Runs %s: Model: %s Best Test Accuracy: %f" % (run, config.model_name,
cal_test_acc[np.argmin(cal_val_loss)]))
best_test.append(cal_test_acc[np.argmin(cal_val_loss)])
log.info("Dataset: %s Best Test Accuracy: %f ( stddev: %f )" % (args.dataset, np.mean(best_test), np.std(best_test)))
print("Dataset: %s Best Test Accuracy: %f ( stddev: %f )" % (args.dataset, np.mean(best_test), np.std(best_test)))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Benchmarking Citation Network')
parser.add_argument(
"--dataset", type=str, default="cora", help="dataset (cora, pubmed)")
parser.add_argument("--use_cuda", action='store_true', help="use_cuda")
parser.add_argument("--conf", type=str, help="config file for models")
parser.add_argument("--epoch", type=int, default=200, help="Epoch")
parser.add_argument("--runs", type=int, default=10, help="runs")
parser.add_argument("--feature_pre_normalize", type=bool, default=True, help="pre_normalize feature")
args = parser.parse_args()
config = edict(yaml.load(open(args.conf), Loader=yaml.FullLoader))
log.info(args)
main(args, config)
......@@ -49,6 +49,8 @@ sh local_run.sh config/enriesage_v1_gpu.yaml
sh local_run.sh config/enriesage_v1_cpu.yaml
```
**NOTE**: To help users better understand the ERNIESage Model, we provide a running example in Baidu AIStudio. Please visit here: https://aistudio.baidu.com/aistudio/projectdetail/667443.
## Hyperparamters
- learner_type: `gpu` or `cpu`; gpu use fleet Collective mode, cpu use fleet Transpiler mode.
......
......@@ -50,6 +50,8 @@ sh local_run.sh config/erniesage_v2_gpu.yaml
sh local_run.sh config/erniesage_v2_cpu.yaml
```
**NOTE**:为了方便用户们学习使用ERNIESage,我们在百度AIStudio中提供了可以直接运行的ERNIESage实例,详情可见:https://aistudio.baidu.com/aistudio/projectdetail/667443.
## Hyperparamters
- learner_type: `gpu` or `cpu`; gpu 使用fleet Collective 模式, cpu 使用fleet Transpiler 模式.
......
# X-Transformer
Models based on Transformers are wildly successful for a wide variety of Natural Language Processing (NLP) tasks and consequently are a mainstay of modern NLP research. Transformer is constituted of a self-attention and a feed-forward module. The self-attention mechanism allows each token in the input sequence to attend independently to every other token in the sequence. From the view of graph representation, the generalized attention mechanism can be described by a Undirected Complete Graph whose vertex is the token. So, the attention module can be implemented by a graph library, especially recently the efficient attention implementation, e.g. [BigBird](https://arxiv.org/abs/2007.14062) \ [LongFormer](https://arxiv.org/abs/2004.05150) \ [Sparse Transformer](https://arxiv.org/abs/1904.10509).
We have showcased the [BigBird](https://arxiv.org/abs/2007.14062) implementation and tested the performence as show below, and the [LongFormer](https://arxiv.org/abs/2004.05150) \ [Sparse Transformer](https://arxiv.org/abs/1904.10509) can be easily implemented by revised the correspoding code.
## Dependencies
- [paddlepaddle >= 1.7](https://github.com/PaddlePaddle/paddle)
- [pgl 1.1](https://github.com/PaddlePaddle/PGL)
## Performance
We have evaluate the implemented method on a summarization dataset CNN/DM. The experiment was conducted on two P40 GPU cards.
| CNN/DM | BatchSize | R1 | R2 | R3 | speed(steps/s) |
| ------------------ | --------- | ----------------- | ----------------- | ----------------- | ------ |
| LEAD | - | 40.42 | 17.62 | 36.67 | - |
| Oracle | - | 52.59 | 31.24 | 48.87 | - |
| non-sparse, L=512 | 32 | 42.175 | 19.392 | 38.613 | 0.6359 |
| L=2048 | 10 | 41.334 | 18.369 | 37.752 | 0.8246 |
| L=1024 | 20 | 41.453 | 18.529 | 37.872 | 0.6432 |
| L=768 | 26 | 41.611 | 18.735 | 38.051 | 0.6517 |
| L=512 | 40 | 41.742 | 18.733 | 38.127 | 0.6213 |
**\**** For this task, we warm up from ERNIE 2.0 en directly rather than pretrain the model for the additional position embedding, so the embedding for the position which is larger than 512 is used repeatedly from ERNIE 2.0.
This may cause score degradation. But in the future, we will test the pre-trained model.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.layers as L
import paddle.fluid.layers as layers
from pgl.utils import paddle_helper
import pgl
def masked_select(input, mask):
"""masked_select
Slice the value from given Mask
Args:
input: Input tensor to be selected
mask: A bool tensor for sliced.
Return:
Part of inputs where mask is True.
"""
index = L.where(mask)
return L.gather(input, index, overwrite=False)
class BigBirdWrapper(pgl.graph_wrapper.BaseGraphWrapper):
"""Implement of Big Bird by PGL graph wrapper """
def __init__(self, input_mask):
super(BigBirdWrapper, self).__init__()
max_seqlen = L.shape(input_mask)[1]
input_mask = L.reshape(input_mask, [-1])
num_nodes = L.shape(input_mask)[0]
src, dst = build_edges(num_nodes, input_mask, max_seqlen)
self._edges_src = src
self._edges_dst = dst
self._edges_src.stop_gradient=True
self._edges_dst.stop_gradient=True
self._num_nodes = num_nodes
self._num_edges = L.shape(self._edges_src)[0]
self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32")
self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(self._edges_dst, dtype="int32")
self._edge_uniq_dst.stop_gradient=True
last = L.reduce_sum(uniq_count, keep_dim=True)
uniq_count = L.cumsum(uniq_count, exclusive=True)
self._edge_uniq_dst_count = L.concat([uniq_count, last])
self._edge_uniq_dst_count.stop_gradient=True
def select_edges(src, dst, input_mask, num_nodes, max_seqlen):
src = fluid.layers.elementwise_max(src, num_nodes * 0)
dst = fluid.layers.elementwise_max(dst, num_nodes * 0)
src = fluid.layers.elementwise_min(src, num_nodes - 1)
dst = fluid.layers.elementwise_min(dst, num_nodes - 1)
conditions = []
conditions.append(L.gather(input_mask, src) > 0.5)
conditions.append(L.gather(input_mask, dst) > 0.5)
block_src = src / max_seqlen
block_dst = dst / max_seqlen
conditions.append(block_src == block_dst)
mask = None
for cond in conditions:
if mask is None:
mask = cond
else:
mask = L.logical_and(mask, cond)
dst = masked_select(dst, mask)
src = masked_select(src, mask)
return src, dst
def uniq_edges(src, dst, num_nodes):
sorted_dst = L.cast(dst, dtype="int64")
sorted_src = L.cast(src, dtype="int64")
num_nodes = L.cast(num_nodes, dtype="int64")
edge_hash = sorted_dst * num_nodes + sorted_src
edge_hash, _ = L.argsort(edge_hash)
edge_hash, _ = L.unique(edge_hash, dtype="int64")
sorted_src = L.elementwise_mod(edge_hash, num_nodes)
sorted_dst = L.elementwise_div(edge_hash, num_nodes)
sorted_src = L.cast(sorted_src, dtype="int32")
sorted_dst = L.cast(sorted_dst, dtype="int32")
return sorted_src, sorted_dst
def build_edges(num_nodes, input_mask, max_seqlen):
edges = L.range(start=0, end=num_nodes, step=1, dtype="int32")
all_edges = []
# Window
filter_func = lambda x, y: select_edges(x, y, input_mask, num_nodes, max_seqlen)
all_edges.append(filter_func(edges - 1, edges)) # win-1
all_edges.append(filter_func(edges + 1, edges)) # win-2
all_edges.append(filter_func(edges, edges)) #self-loop
# Global Assume [CLS] is the first token.
# vertical cls-window attention
cls_position = edges / max_seqlen * max_seqlen
all_edges.append(filter_func(cls_position, edges))
# horizontal cls attention
all_edges.append(filter_func(edges, cls_position))
# Random
for i in range(2):
rand_edge = L.floor(L.uniform_random(min=0, max=1, shape=[num_nodes]) * L.cast(max_seqlen, dtype="float32"))
rand_edge = L.cast(rand_edge, dtype="int32") + cls_position
all_edges.append(filter_func(rand_edge, edges))
if len(all_edges) > 1:
src = L.concat([ s for s, d in all_edges], 0)
dst = L.concat([ d for s, d in all_edges], 0)
else:
src = all_edges[0][0]
dst = all_edges[0][1]
# sort edges
sorted_src, sorted_dst = uniq_edges(src, dst, num_nodes)
return sorted_src, sorted_dst
def sparse_scaled_dot_product_attention(q, k, v, input_mask, dropout_rate, n_head, d_key, d_value):
def send_q_k_spmm(src_feat, dst_feat, edge_feat):
# q [ num_edges, n_head * dim]
# k [ num_edges, n_head * dim]
# v [ num_edges, n_head * dim]
_q = dst_feat["q"]
_k = src_feat["k"]
_v = src_feat["v"]
_q = L.reshape(_q, [-1, n_head, _q.shape[-1] // n_head])
_k = L.reshape(_k, [-1, n_head, _k.shape[-1] // n_head])
score = L.reduce_sum(_q * _k, -1) # [num_edge, n_head]
return { "score": score, "value": _v}
def recv_score_v_spmm(msg):
score = msg["score"]
score = paddle_helper.sequence_softmax(score)
score = layers.dropout(
score,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
score = L.reshape(score, [-1, n_head, 1])
_v = msg["value"]
_new_v = L.reshape(_v, [-1, n_head, _v.shape[-1] // n_head])
_new_v = _new_v * score
_new_v = L.reshape(_new_v, [-1, _v.shape[-1]])
_new_v = L.lod_reset(_new_v, _v)
return L.sequence_pool(_new_v, "sum")
graph_wrapper = BigBirdWrapper(input_mask)
old_v = v
q = L.reshape(q, [-1, d_key * n_head])
k = L.reshape(k, [-1, d_key * n_head])
v = L.reshape(v, [-1, d_value * n_head])
q = L.scale(q, scale=d_key ** -0.5)
msg = graph_wrapper.send(send_q_k_spmm, nfeat_list=[("k", k), ("v", v), ("q", q)])
out = graph_wrapper.recv(msg, recv_score_v_spmm)
out = L.reshape(out, [-1, L.shape(old_v)[1], d_value * n_head])
return out, out
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import paddle.fluid as fluid
import paddle.fluid.layers as L
import paddle.fluid.layers as layers
from .sparse_scaled_dot_product_attention import sparse_scaled_dot_product_attention
to_3d = lambda a: a # will change later
to_2d = lambda a: a
def multi_head_attention(queries,
keys,
values,
attn_bias,
d_key,
d_value,
d_model,
input_mask,
n_head=1,
dropout_rate=0.,
cache=None,
param_initializer=None,
name='multi_head_att'):
"""
Multi-Head Attention. Note that attn_bias is added to the logit before
computing softmax activiation to mask certain selected positions so that
they will not considered in attention weights.
"""
keys = queries if keys is None else keys
values = keys if values is None else values
def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Add linear projection to queries, keys, and values.
"""
q = layers.fc(input=queries,
size=d_key * n_head,
num_flatten_dims=len(queries.shape) - 1,
param_attr=fluid.ParamAttr(
name=name + '_query_fc.w_0',
initializer=param_initializer),
bias_attr=name + '_query_fc.b_0')
k = layers.fc(input=keys,
size=d_key * n_head,
num_flatten_dims=len(keys.shape) - 1,
param_attr=fluid.ParamAttr(
name=name + '_key_fc.w_0',
initializer=param_initializer),
bias_attr=name + '_key_fc.b_0')
v = layers.fc(input=values,
size=d_value * n_head,
num_flatten_dims=len(values.shape) - 1,
param_attr=fluid.ParamAttr(
name=name + '_value_fc.w_0',
initializer=param_initializer),
bias_attr=name + '_value_fc.b_0')
return q, k, v
def __split_heads(x, n_head):
"""
Reshape the last dimension of inpunt tensor x so that it becomes two
dimensions and then transpose. Specifically, input a tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] then output a tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
hidden_size = x.shape[-1]
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
# permuate the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
def __combine_heads(x):
"""
Transpose and then reshape the last two dimensions of inpunt tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if len(x.shape) == 3: return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
#trans_x.desc.set_shape((-1, 1, n_head, d_value))
return layers.reshape(x=trans_x, shape=[0, 0, d_model], inplace=True)
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
q = to_3d(q)
k = to_3d(k)
v = to_3d(v)
if cache is not None: # use cache and concat time steps
# Since the inplace reshape in __split_heads changes the shape of k and
# v, which is the cache input for next time step, reshape the cache
# input from the previous time step first.
k = cache["k"] = layers.concat(
[layers.reshape(
cache["k"], shape=[0, 0, d_model]), k], axis=1)
v = cache["v"] = layers.concat(
[layers.reshape(
cache["v"], shape=[0, 0, d_model]), v], axis=1)
out, _ = sparse_scaled_dot_product_attention(q, k, v,
input_mask, dropout_rate, n_head, d_key, d_value)
out = to_2d(out)
# Project back to the model size.
proj_out = layers.fc(input=out,
size=d_model,
num_flatten_dims=len(out.shape) - 1,
param_attr=fluid.ParamAttr(
name=name + '_output_fc.w_0',
initializer=param_initializer),
bias_attr=name + '_output_fc.b_0')
return proj_out, _
def positionwise_feed_forward(x,
d_inner_hid,
d_hid,
dropout_rate,
hidden_act,
param_initializer=None,
name='ffn'):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(input=x,
size=d_inner_hid,
num_flatten_dims=len(x.shape) - 1,
act=hidden_act,
param_attr=fluid.ParamAttr(
name=name + '_fc_0.w_0',
initializer=param_initializer),
bias_attr=name + '_fc_0.b_0')
if dropout_rate:
hidden = layers.dropout(
hidden,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
out = layers.fc(input=hidden,
size=d_hid,
num_flatten_dims=len(hidden.shape) - 1,
param_attr=fluid.ParamAttr(
name=name + '_fc_1.w_0',
initializer=param_initializer),
bias_attr=name + '_fc_1.b_0')
return out
def pre_post_process_layer(prev_out,
out,
process_cmd,
dropout_rate=0.,
name=''):
"""
Add residual connection, layer normalization and droput to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out_dtype = out.dtype
if out_dtype == fluid.core.VarDesc.VarType.FP16:
out = layers.cast(x=out, dtype="float32")
out = layers.layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
param_attr=fluid.ParamAttr(
name=name + '_layer_norm_scale',
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
name=name + '_layer_norm_bias',
initializer=fluid.initializer.Constant(0.)))
if out_dtype == fluid.core.VarDesc.VarType.FP16:
out = layers.cast(x=out, dtype="float16")
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(
out,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
def encoder_layer(enc_input,
input_mask,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd="n",
postprocess_cmd="da",
param_initializer=None,
name=''):
"""The encoder layers that can be stacked to form a deep encoder.
This module consits of a multi-head (self) attention followed by
position-wise feed-forward networks and both the two components companied
with the post_process_layer to add residual connection, layer normalization
and droput.
"""
attn_output, ctx_multiheads_attn = multi_head_attention(
pre_process_layer(
enc_input,
preprocess_cmd,
prepostprocess_dropout,
name=name + '_pre_att'),
None,
None,
attn_bias,
d_key,
d_value,
d_model,
input_mask,
n_head,
attention_dropout,
param_initializer=param_initializer,
name=name + '_multi_head_att')
attn_output = post_process_layer(
enc_input,
attn_output,
postprocess_cmd,
prepostprocess_dropout,
name=name + '_post_att')
ffd_output = positionwise_feed_forward(
pre_process_layer(
attn_output,
preprocess_cmd,
prepostprocess_dropout,
name=name + '_pre_ffn'),
d_inner_hid,
d_model,
relu_dropout,
hidden_act,
param_initializer=param_initializer,
name=name + '_ffn')
ret = post_process_layer(
attn_output,
ffd_output,
postprocess_cmd,
prepostprocess_dropout,
name=name + '_post_ffn')
return ret, ctx_multiheads_attn, ffd_output
def build_pad_idx(input_mask):
pad_idx = L.where(L.cast(L.squeeze(input_mask, [2]), 'bool'))
return pad_idx
def build_attn_bias(input_mask, n_head, dtype):
attn_bias = L.matmul(
input_mask, input_mask, transpose_y=True) # [batch, seq, seq]
attn_bias = (1. - attn_bias) * -10000.
attn_bias = L.stack([attn_bias] * n_head, 1)
if attn_bias.dtype != dtype:
attn_bias = L.cast(attn_bias, dtype)
return attn_bias
def encoder(enc_input,
input_mask,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd="n",
postprocess_cmd="da",
param_initializer=None,
name=''):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
d_shape = L.shape(input_mask)
pad_idx = build_pad_idx(input_mask)
attn_bias = build_attn_bias(input_mask, n_head, enc_input.dtype)
enc_input = to_2d(enc_input)
all_hidden = []
all_attn = []
all_ffn = []
for i in range(n_layer):
enc_output, ctx_multiheads_attn, ffn_output = encoder_layer(
enc_input,
input_mask,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd,
postprocess_cmd,
param_initializer=param_initializer,
name=name + '_layer_' + str(i))
all_hidden.append(enc_output)
all_attn.append(ctx_multiheads_attn)
all_ffn.append(ffn_output)
enc_input = enc_output
enc_output = pre_process_layer(
enc_output,
preprocess_cmd,
prepostprocess_dropout,
name="post_encoder")
enc_output = to_3d(enc_output)
return enc_output, all_hidden, all_attn, all_ffn
......@@ -22,3 +22,4 @@ from pgl import heter_graph
from pgl import heter_graph_wrapper
from pgl import contrib
from pgl import message_passing
from pgl import sample
......@@ -19,6 +19,7 @@ for PaddlePaddle.
import warnings
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as L
from pgl.utils import op
from pgl.utils import paddle_helper
......@@ -26,12 +27,11 @@ from pgl.utils.logger import log
__all__ = ["BaseGraphWrapper", "GraphWrapper", "StaticGraphWrapper"]
def send(src, dst, nfeat, efeat, message_func):
"""Send message from src to dst.
"""
src_feat = op.read_rows(nfeat, src)
dst_feat = op.read_rows(nfeat, dst)
src_feat = op.RowReader(nfeat, src)
dst_feat = op.RowReader(nfeat, dst)
msg = message_func(src_feat, dst_feat, efeat)
return msg
......@@ -47,10 +47,10 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
try:
out_dim = msg.shape[-1]
init_output = fluid.layers.fill_constant(
init_output = L.fill_constant(
shape=[num_nodes, out_dim], value=0, dtype=msg.dtype)
init_output.stop_gradient = False
empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=msg.dtype)
empty_msg_flag = L.cast(num_edges > 0, dtype=msg.dtype)
msg = msg * empty_msg_flag
output = paddle_helper.scatter_add(init_output, dst, msg)
return output
......@@ -59,7 +59,7 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
"scatter_add is not supported with paddle version <= 1.5")
def sum_func(message):
return fluid.layers.sequence_pool(message, "sum")
return L.sequence_pool(message, "sum")
reduce_function = sum_func
......@@ -67,13 +67,13 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
output = reduce_function(bucketed_msg)
output_dim = output.shape[-1]
empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=output.dtype)
empty_msg_flag = L.cast(num_edges > 0, dtype=output.dtype)
output = output * empty_msg_flag
init_output = fluid.layers.fill_constant(
init_output = L.fill_constant(
shape=[num_nodes, output_dim], value=0, dtype=output.dtype)
init_output.stop_gradient = True
final_output = fluid.layers.scatter(init_output, uniq_dst, output)
final_output = L.scatter(init_output, uniq_dst, output)
return final_output
......@@ -104,6 +104,7 @@ class BaseGraphWrapper(object):
self._node_ids = None
self._graph_lod = None
self._num_graph = None
self._num_edges = None
self._data_name_prefix = ""
def __repr__(self):
......@@ -142,11 +143,13 @@ class BaseGraphWrapper(object):
"""
if efeat_list is None:
efeat_list = {}
if nfeat_list is None:
nfeat_list = {}
src, dst = self.edges
nfeat = {}
for feat in nfeat_list:
if isinstance(feat, str):
nfeat[feat] = self.node_feat[feat]
......@@ -470,7 +473,7 @@ class StaticGraphWrapper(BaseGraphWrapper):
class GraphWrapper(BaseGraphWrapper):
"""Implement a graph wrapper that creates a graph data holders
that attributes and features in the graph are :code:`fluid.layers.data`.
that attributes and features in the graph are :code:`L.data`.
And we provide interface :code:`to_feed` to help converting :code:`Graph`
data into :code:`feed_dict`.
......@@ -546,65 +549,65 @@ class GraphWrapper(BaseGraphWrapper):
def __create_graph_attr_holders(self):
"""Create data holders for graph attributes.
"""
self._num_edges = fluid.layers.data(
self._num_edges = L.data(
self._data_name_prefix + '/num_edges',
shape=[1],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._num_graph = fluid.layers.data(
self._num_graph = L.data(
self._data_name_prefix + '/num_graph',
shape=[1],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._edges_src = fluid.layers.data(
self._edges_src = L.data(
self._data_name_prefix + '/edges_src',
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._edges_dst = fluid.layers.data(
self._edges_dst = L.data(
self._data_name_prefix + '/edges_dst',
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._num_nodes = fluid.layers.data(
self._num_nodes = L.data(
self._data_name_prefix + '/num_nodes',
shape=[1],
append_batch_size=False,
dtype='int64',
stop_gradient=True)
self._edge_uniq_dst = fluid.layers.data(
self._edge_uniq_dst = L.data(
self._data_name_prefix + "/uniq_dst",
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._graph_lod = fluid.layers.data(
self._graph_lod = L.data(
self._data_name_prefix + "/graph_lod",
shape=[None],
append_batch_size=False,
dtype="int32",
stop_gradient=True)
self._edge_uniq_dst_count = fluid.layers.data(
self._edge_uniq_dst_count = L.data(
self._data_name_prefix + "/uniq_dst_count",
shape=[None],
append_batch_size=False,
dtype="int32",
stop_gradient=True)
self._node_ids = fluid.layers.data(
self._node_ids = L.data(
self._data_name_prefix + "/node_ids",
shape=[None],
append_batch_size=False,
dtype="int64",
stop_gradient=True)
self._indegree = fluid.layers.data(
self._indegree = L.data(
self._data_name_prefix + "/indegree",
shape=[None],
append_batch_size=False,
......@@ -627,7 +630,7 @@ class GraphWrapper(BaseGraphWrapper):
node_feat_dtype):
"""Create data holders for node features.
"""
feat_holder = fluid.layers.data(
feat_holder = L.data(
self._data_name_prefix + '/node_feat/' + node_feat_name,
shape=node_feat_shape,
append_batch_size=False,
......@@ -640,7 +643,7 @@ class GraphWrapper(BaseGraphWrapper):
edge_feat_dtype):
"""Create edge holders for edge features.
"""
feat_holder = fluid.layers.data(
feat_holder = L.data(
self._data_name_prefix + '/edge_feat/' + edge_feat_name,
shape=edge_feat_shape,
append_batch_size=False,
......@@ -719,3 +722,61 @@ class GraphWrapper(BaseGraphWrapper):
"""Return the holder list.
"""
return self._holder_list
def get_degree(edge, num_nodes):
init_output = L.fill_constant(
shape=[num_nodes], value=0, dtype="float32")
init_output.stop_gradient = True
final_output = L.scatter(init_output,
edge,
L.full_like(edge, 1, dtype="float32"),
overwrite=False)
return final_output
class DropEdgeWrapper(BaseGraphWrapper):
"""Implement of Edge Drop """
def __init__(self, graph_wrapper, dropout, keep_self_loop=True):
super(DropEdgeWrapper, self).__init__()
# Copy Node's information
for key, value in graph_wrapper.node_feat.items():
self.node_feat_tensor_dict[key] = value
self._num_nodes = graph_wrapper.num_nodes
self._graph_lod = graph_wrapper.graph_lod
self._num_graph = graph_wrapper.num_graph
self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32")
# Dropout Edges
src, dst = graph_wrapper.edges
u = L.uniform_random(shape=L.cast(L.shape(src), 'int64'), min=0., max=1.)
# Avoid Empty Edges
keeped = L.cast(u > dropout, dtype="float32")
self._num_edges = L.reduce_sum(L.cast(keeped, "int32"))
keeped = keeped + L.cast(self._num_edges == 0, dtype="float32")
if keep_self_loop:
self_loop = L.cast(src == dst, dtype="float32")
keeped = keeped + self_loop
keeped = (keeped > 0.5)
src = paddle_helper.masked_select(src, keeped)
dst = paddle_helper.masked_select(dst, keeped)
src.stop_gradient=True
dst.stop_gradient=True
self._edges_src = src
self._edges_dst = dst
for key, value in graph_wrapper.edge_feat.items():
self.edge_feat_tensor_dict[key] = paddle_helper.masked_select(value, keeped)
self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(dst, dtype="int32")
self._edge_uniq_dst.stop_gradient=True
last = L.reduce_sum(uniq_count, keep_dim=True)
uniq_count = L.cumsum(uniq_count, exclusive=True)
self._edge_uniq_dst_count = L.concat([uniq_count, last])
self._edge_uniq_dst_count.stop_gradient=True
self._indegree = get_degree(self._edges_dst, self._num_nodes)
......@@ -14,11 +14,14 @@
"""This package implements common layers to help building
graph neural networks.
"""
import pgl
import paddle.fluid as fluid
import paddle.fluid.layers as L
from pgl.utils import paddle_helper
from pgl import message_passing
import numpy as np
__all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv']
__all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv', 'appnp', 'gcnii']
def gcn(gw, feature, hidden_size, activation, name, norm=None):
......@@ -50,7 +53,7 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None):
size = feature.shape[-1]
if size > hidden_size:
feature = fluid.layers.fc(feature,
feature = L.fc(feature,
size=hidden_size,
bias_attr=False,
param_attr=fluid.ParamAttr(name=name))
......@@ -64,7 +67,7 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None):
output = gw.recv(msg, "sum")
else:
output = gw.recv(msg, "sum")
output = fluid.layers.fc(output,
output = L.fc(output,
size=hidden_size,
bias_attr=False,
param_attr=fluid.ParamAttr(name=name))
......@@ -72,12 +75,12 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None):
if norm is not None:
output = output * norm
bias = fluid.layers.create_parameter(
bias = L.create_parameter(
shape=[hidden_size],
dtype='float32',
is_bias=True,
name=name + '_bias')
output = fluid.layers.elementwise_add(output, bias, act=activation)
output = L.elementwise_add(output, bias, act=activation)
return output
......@@ -120,7 +123,7 @@ def gat(gw,
def send_attention(src_feat, dst_feat, edge_feat):
output = src_feat["left_a"] + dst_feat["right_a"]
output = fluid.layers.leaky_relu(
output = L.leaky_relu(
output, alpha=0.2) # (num_edges, num_heads)
return {"alpha": output, "h": src_feat["h"]}
......@@ -129,54 +132,54 @@ def gat(gw,
h = msg["h"]
alpha = paddle_helper.sequence_softmax(alpha)
old_h = h
h = fluid.layers.reshape(h, [-1, num_heads, hidden_size])
alpha = fluid.layers.reshape(alpha, [-1, num_heads, 1])
h = L.reshape(h, [-1, num_heads, hidden_size])
alpha = L.reshape(alpha, [-1, num_heads, 1])
if attn_drop > 1e-15:
alpha = fluid.layers.dropout(
alpha = L.dropout(
alpha,
dropout_prob=attn_drop,
is_test=is_test,
dropout_implementation="upscale_in_train")
h = h * alpha
h = fluid.layers.reshape(h, [-1, num_heads * hidden_size])
h = fluid.layers.lod_reset(h, old_h)
return fluid.layers.sequence_pool(h, "sum")
h = L.reshape(h, [-1, num_heads * hidden_size])
h = L.lod_reset(h, old_h)
return L.sequence_pool(h, "sum")
if feat_drop > 1e-15:
feature = fluid.layers.dropout(
feature = L.dropout(
feature,
dropout_prob=feat_drop,
is_test=is_test,
dropout_implementation='upscale_in_train')
ft = fluid.layers.fc(feature,
ft = L.fc(feature,
hidden_size * num_heads,
bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_weight'))
left_a = fluid.layers.create_parameter(
left_a = L.create_parameter(
shape=[num_heads, hidden_size],
dtype='float32',
name=name + '_gat_l_A')
right_a = fluid.layers.create_parameter(
right_a = L.create_parameter(
shape=[num_heads, hidden_size],
dtype='float32',
name=name + '_gat_r_A')
reshape_ft = fluid.layers.reshape(ft, [-1, num_heads, hidden_size])
left_a_value = fluid.layers.reduce_sum(reshape_ft * left_a, -1)
right_a_value = fluid.layers.reduce_sum(reshape_ft * right_a, -1)
reshape_ft = L.reshape(ft, [-1, num_heads, hidden_size])
left_a_value = L.reduce_sum(reshape_ft * left_a, -1)
right_a_value = L.reduce_sum(reshape_ft * right_a, -1)
msg = gw.send(
send_attention,
nfeat_list=[("h", ft), ("left_a", left_a_value),
("right_a", right_a_value)])
output = gw.recv(msg, reduce_attention)
bias = fluid.layers.create_parameter(
bias = L.create_parameter(
shape=[hidden_size * num_heads],
dtype='float32',
is_bias=True,
name=name + '_bias')
bias.stop_gradient = True
output = fluid.layers.elementwise_add(output, bias, act=activation)
output = L.elementwise_add(output, bias, act=activation)
return output
......@@ -219,7 +222,7 @@ def gin(gw,
def send_src_copy(src_feat, dst_feat, edge_feat):
return src_feat["h"]
epsilon = fluid.layers.create_parameter(
epsilon = L.create_parameter(
shape=[1, 1],
dtype="float32",
attr=fluid.ParamAttr(name="%s_eps" % name),
......@@ -232,13 +235,13 @@ def gin(gw,
msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
output = gw.recv(msg, "sum") + feature * (epsilon + 1.0)
output = fluid.layers.fc(output,
output = L.fc(output,
size=hidden_size,
act=None,
param_attr=fluid.ParamAttr(name="%s_w_0" % name),
bias_attr=fluid.ParamAttr(name="%s_b_0" % name))
output = fluid.layers.layer_norm(
output = L.layer_norm(
output,
begin_norm_axis=1,
param_attr=fluid.ParamAttr(
......@@ -249,9 +252,9 @@ def gin(gw,
initializer=fluid.initializer.Constant(0.0)), )
if activation is not None:
output = getattr(fluid.layers, activation)(output)
output = getattr(L, activation)(output)
output = fluid.layers.fc(output,
output = L.fc(output,
size=hidden_size,
act=activation,
param_attr=fluid.ParamAttr(name="%s_w_1" % name),
......@@ -269,10 +272,10 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
# E * M * D1
old = feat_query
feat_query = fluid.layers.reshape(feat_query, [-1, heads, hidden_size_a])
feat_key = fluid.layers.reshape(feat_key, [-1, heads, hidden_size_a])
feat_query = L.reshape(feat_query, [-1, heads, hidden_size_a])
feat_key = L.reshape(feat_key, [-1, heads, hidden_size_a])
# E * M
alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1)
alpha = L.reduce_sum(feat_key * feat_query, dim=-1)
return {'dst_node_feat': dst_feat['node_feat'],
'src_node_feat': src_feat['node_feat'],
......@@ -286,15 +289,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
# 每条边的出发点的特征
src_feat = message['src_node_feat']
# 每个中心点自己的特征
x = fluid.layers.sequence_pool(dst_feat, 'average')
x = L.sequence_pool(dst_feat, 'average')
# 每个中心点的邻居的特征的平均值
z = fluid.layers.sequence_pool(src_feat, 'average')
z = L.sequence_pool(src_feat, 'average')
# 计算 gate
feat_gate = message['feat_gate']
g_max = fluid.layers.sequence_pool(feat_gate, 'max')
g = fluid.layers.concat([x, g_max, z], axis=1)
g = fluid.layers.fc(g, heads, bias_attr=False, act="sigmoid")
g_max = L.sequence_pool(feat_gate, 'max')
g = L.concat([x, g_max, z], axis=1)
g = L.fc(g, heads, bias_attr=False, act="sigmoid")
# softmax
alpha = message['alpha']
......@@ -302,19 +305,19 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
feat_value = message['feat_value'] # E * (M * D2)
old = feat_value
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) # E * M * D2
feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0)
feat_value = fluid.layers.reshape(feat_value, [-1, heads*hidden_size_v]) # E * (M * D2)
feat_value = fluid.layers.lod_reset(feat_value, old)
feat_value = L.reshape(feat_value, [-1, heads, hidden_size_v]) # E * M * D2
feat_value = L.elementwise_mul(feat_value, alpha, axis=0)
feat_value = L.reshape(feat_value, [-1, heads*hidden_size_v]) # E * (M * D2)
feat_value = L.lod_reset(feat_value, old)
feat_value = fluid.layers.sequence_pool(feat_value, 'sum') # N * (M * D2)
feat_value = L.sequence_pool(feat_value, 'sum') # N * (M * D2)
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) # N * M * D2
feat_value = L.reshape(feat_value, [-1, heads, hidden_size_v]) # N * M * D2
output = fluid.layers.elementwise_mul(feat_value, g, axis=0)
output = fluid.layers.reshape(output, [-1, heads * hidden_size_v]) # N * (M * D2)
output = L.elementwise_mul(feat_value, g, axis=0)
output = L.reshape(output, [-1, heads * hidden_size_v]) # N * (M * D2)
output = fluid.layers.concat([x, output], axis=1)
output = L.concat([x, output], axis=1)
return output
......@@ -323,16 +326,16 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
# 计算每个点自己需要发送出去的内容
# 投影后的特征向量
# N * (D1 * M)
feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
feat_key = L.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_key'))
# N * (D2 * M)
feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False,
feat_value = L.fc(feature, hidden_size_v * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_value'))
# N * (D1 * M)
feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False,
feat_query = L.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_query'))
# N * Dm
feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False,
feat_gate = L.fc(feature, hidden_size_m, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_gate'))
# send 阶段
......@@ -346,10 +349,10 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
# 聚合邻居特征
output = gw.recv(message, recv_func)
output = fluid.layers.fc(output, hidden_size_o, bias_attr=False,
output = L.fc(output, hidden_size_o, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_output'))
output = fluid.layers.leaky_relu(output, alpha=0.1)
output = fluid.layers.dropout(output, dropout_prob=0.1)
output = L.leaky_relu(output, alpha=0.1)
output = L.dropout(output, dropout_prob=0.1)
return output
......@@ -376,7 +379,7 @@ def gen_conv(gw,
"""
if beta == "dynamic":
beta = fluid.layers.create_parameter(
beta = L.create_parameter(
shape=[1],
dtype='float32',
default_initializer=
......@@ -391,16 +394,132 @@ def gen_conv(gw,
output = message_passing.msg_norm(feature, output, name)
output = feature + output
output = fluid.layers.fc(output,
output = L.fc(output,
feature.shape[-1],
bias_attr=False,
act="relu",
param_attr=fluid.ParamAttr(name=name + '_weight1'))
output = fluid.layers.fc(output,
output = L.fc(output,
feature.shape[-1],
bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_weight2'))
return output
def get_norm(indegree):
"""Get Laplacian Normalization"""
float_degree = L.cast(indegree, dtype="float32")
float_degree = L.clamp(float_degree, min=1.0)
norm = L.pow(float_degree, factor=-0.5)
return norm
def appnp(gw, feature, edge_dropout=0, alpha=0.2, k_hop=10):
"""Implementation of APPNP of "Predict then Propagate: Graph Neural Networks
meet Personalized PageRank" (ICLR 2019).
Args:
gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
feature: A tensor with shape (num_nodes, feature_size).
edge_dropout: Edge dropout rate.
k_hop: K Steps for Propagation
Return:
A tensor with shape (num_nodes, hidden_size)
"""
def send_src_copy(src_feat, dst_feat, edge_feat):
feature = src_feat["h"]
return feature
h0 = feature
ngw = gw
norm = get_norm(ngw.indegree())
for i in range(k_hop):
if edge_dropout > 1e-5:
ngw = pgl.sample.edge_drop(gw, edge_dropout)
norm = get_norm(ngw.indegree())
feature = feature * norm
msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
feature = gw.recv(msg, "sum")
feature = feature * norm
feature = feature * (1 - alpha) + h0 * alpha
return feature
def gcnii(gw,
feature,
name,
activation=None,
alpha=0.5,
lambda_l=0.5,
k_hop=1,
dropout=0.5,
is_test=False):
"""Implementation of GCNII of "Simple and Deep Graph Convolutional Networks"
paper: https://arxiv.org/pdf/2007.02133.pdf
Args:
gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
feature: A tensor with shape (num_nodes, feature_size).
activation: The activation for the output.
k_hop: Number of layers for gcnii.
lambda_l: The hyperparameter of lambda in the paper.
alpha: The hyperparameter of alpha in the paper.
dropout: Feature dropout rate.
is_test: train / test phase.
Return:
A tensor with shape (num_nodes, hidden_size)
"""
def send_src_copy(src_feat, dst_feat, edge_feat):
feature = src_feat["h"]
return feature
h0 = feature
ngw = gw
norm = get_norm(ngw.indegree())
hidden_size = feature.shape[-1]
for i in range(k_hop):
beta_i = np.log(1.0 * lambda_l / (i + 1) + 1)
feature = L.dropout(
feature,
dropout_prob=dropout,
is_test=is_test,
dropout_implementation='upscale_in_train')
feature = feature * norm
msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
feature = gw.recv(msg, "sum")
feature = feature * norm
# appnp
feature = feature * (1 - alpha) + h0 * alpha
feature_transed = L.fc(feature, hidden_size,
act=None, bias_attr=False,
name=name+"_%s_w1" % i)
feature = feature_transed * beta_i + feature * (1 - beta_i)
if activation is not None:
feature = getattr(L, activation)(feature)
return feature
......@@ -516,3 +516,12 @@ def graph_saint_random_walk_sample(graph,
nodes=sample_nodes, eid=eids, with_node_feat=True, with_edge_feat=True)
subgraph.node_feat["index"] = np.array(sample_nodes, dtype="int64")
return subgraph
def edge_drop(graph_wrapper, dropout_rate, keep_self_loop=True):
if dropout_rate < 1e-5:
return graph_wrapper
else:
return pgl.graph_wrapper.DropEdgeWrapper(graph_wrapper,
dropout_rate,
keep_self_loop)
......@@ -68,3 +68,18 @@ def read_rows(data, index):
return new_data
else:
return paddle_helper.gather(data, index)
class RowReader(object):
"""Memory Efficient RowReader
"""
def __init__(self, nfeat, index):
self.nfeat = nfeat
self.loaded_nfeat = {}
self.index = index
def __getitem__(self, key):
if key not in self.loaded_nfeat:
self.loaded_nfeat[key] = read_rows(self.nfeat[key], self.index)
return self.loaded_nfeat[key]
......@@ -250,3 +250,20 @@ def scatter_max(input, index, updates):
output = fluid.layers.scatter(input, index, updates, mode='max')
return output
def masked_select(input, mask):
"""masked_select
Slice the value from given Mask
Args:
input: Input tensor to be selected
mask: A bool tensor for sliced.
Return:
Part of inputs where mask is True.
"""
index = fluid.layers.where(mask)
return fluid.layers.gather(input, index)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册