未验证 提交 16752f3e 编写于 作者: H Huang Zhengjie 提交者: GitHub

Merge pull request #108 from Yelrose/master

Add APPNP; GCNII; DropEdge
# Easy Paper Reproduction for Citation Network (Cora/Pubmed/Citeseer)
This page tries to reproduce all the **Graph Neural Network** paper for Citation Network (Cora/Pubmed/Citeseer), which is the **Hello world** dataset (**small** and **fast**) for graph neural networks. But it's very hard to achieve very high performance.
All datasets are runned with public split of **semi-supervised** settings. And we report the averarge accuracy by running 10 times.
# Experiment Results
| Model | Cora | Pubmed | Citeseer | Remarks |
| ------------------------------------------------------------ | ------------ | ------------ | ------------ | --------------------------------------------------------- |
| [Vanilla GCN (Kipf 2017)](https://openreview.net/pdf?id=SJU4ayYgl ) | 0.807(0.010) | 0.794(0.003) | 0.710(0.007) | |
| [GAT (Veličković 2017)](https://arxiv.org/pdf/1710.10903.pdf) | 0.834(0.004) | 0.772(0.004) | 0.700(0.006) | |
| [SGC(Wu 2019)](https://arxiv.org/pdf/1902.07153.pdf) | 0.818(0.000) | 0.782(0.000) | 0.708(0.000) | |
| [APPNP (Johannes 2018)](https://arxiv.org/abs/1810.05997) | 0.846(0.003) | 0.803(0.002) | 0.719(0.003) | Almost the same with the results reported in Appendix E. |
| [GCNII (64 Layers, 1500 Epochs, Chen 2020)](https://arxiv.org/pdf/2007.02133.pdf) | 0.846(0.003) | 0.798(0.003) | 0.724(0.006) | |
How to run the experiments?
```shell
# Device choose
export CUDA_VISIBLE_DEVICES=0
# GCN
python train.py --conf config/gcn.yaml --use_cuda --dataset cora
python train.py --conf config/gcn.yaml --use_cuda --dataset pubmed
python train.py --conf config/gcn.yaml --use_cuda --dataset citeseer
# GAT
python train.py --conf config/gat.yaml --use_cuda --dataset cora
python train.py --conf config/gat.yaml --use_cuda --dataset pubmed
python train.py --conf config/gat.yaml --use_cuda --dataset citeseer
# SGC (Slow version)
python train.py --conf config/sgc.yaml --use_cuda --dataset cora
python train.py --conf config/sgc.yaml --use_cuda --dataset pubmed
python train.py --conf config/sgc.yaml --use_cuda --dataset citeseer
# APPNP
python train.py --conf config/appnp.yaml --use_cuda --dataset cora
python train.py --conf config/appnp.yaml --use_cuda --dataset pubmed
python train.py --conf config/appnp.yaml --use_cuda --dataset citeseer
# GCNII (The original code use 1500 epochs.)
python train.py --conf config/gcnii.yaml --use_cuda --dataset cora --epoch 1500
python train.py --conf config/gcnii.yaml --use_cuda --dataset pubmed --epoch 1500
python train.py --conf config/gcnii.yaml --use_cuda --dataset citeseer --epoch 1500
```
import pgl
import model
from pgl import data_loader
import paddle.fluid as fluid
import numpy as np
import time
def build_model(dataset, config, phase, main_prog):
gw = pgl.graph_wrapper.GraphWrapper(
name="graph",
node_feat=dataset.graph.node_feat_info())
GraphModel = getattr(model, config.model_name)
m = GraphModel(config=config, num_class=dataset.num_classes)
logits = m.forward(gw, gw.node_feat["words"], phase)
# Take the last
node_index = fluid.layers.data(
"node_index",
shape=[None, 1],
dtype="int64",
append_batch_size=False)
node_label = fluid.layers.data(
"node_label",
shape=[None, 1],
dtype="int64",
append_batch_size=False)
pred = fluid.layers.gather(logits, node_index)
loss, pred = fluid.layers.softmax_with_cross_entropy(
logits=pred, label=node_label, return_softmax=True)
acc = fluid.layers.accuracy(input=pred, label=node_label, k=1)
loss = fluid.layers.mean(loss)
if phase == "train":
adam = fluid.optimizer.Adam(
learning_rate=config.learning_rate,
regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=config.weight_decay))
adam.minimize(loss)
return gw, loss, acc
model_name: APPNP
k_hop: 10
alpha: 0.1
num_layer: 1
learning_rate: 0.01
dropout: 0.5
hidden_size: 64
weight_decay: 0.0005
edge_dropout: 0.0
model_name: GAT
learning_rate: 0.005
weight_decay: 0.0005
num_layers: 1
feat_drop: 0.6
attn_drop: 0.6
num_heads: 8
hidden_size: 8
edge_dropout: 0.0
model_name: GCN
num_layers: 1
dropout: 0.5
hidden_size: 16
learning_rate: 0.01
weight_decay: 0.0005
edge_dropout: 0.0
model_name: GCNII
k_hop: 64
alpha: 0.1
num_layer: 1
learning_rate: 0.01
dropout: 0.6
hidden_size: 64
weight_decay: 0.0005
edge_dropout: 0.0
model_name: SGC
num_layers: 2
learning_rate: 0.2
weight_decay: 0.000005
feature_pre_normalize: False
import pgl
import paddle.fluid.layers as L
import pgl.layers.conv as conv
def get_norm(indegree):
float_degree = L.cast(indegree, dtype="float32")
float_degree = L.clamp(float_degree, min=1.0)
norm = L.pow(float_degree, factor=-0.5)
return norm
class GCN(object):
"""Implement of GCN
"""
def __init__(self, config, num_class):
self.num_class = num_class
self.num_layers = config.get("num_layers", 1)
self.hidden_size = config.get("hidden_size", 64)
self.dropout = config.get("dropout", 0.5)
self.edge_dropout = config.get("edge_dropout", 0.0)
def forward(self, graph_wrapper, feature, phase):
for i in range(self.num_layers):
if phase == "train":
ngw = pgl.sample.edge_drop(graph_wrapper, self.edge_dropout)
norm = get_norm(ngw.indegree())
else:
ngw = graph_wrapper
norm = graph_wrapper.node_feat["norm"]
feature = pgl.layers.gcn(ngw,
feature,
self.hidden_size,
activation="relu",
norm=norm,
name="layer_%s" % i)
feature = L.dropout(
feature,
self.dropout,
dropout_implementation='upscale_in_train')
if phase == "train":
ngw = pgl.sample.edge_drop(graph_wrapper, self.edge_dropout)
norm = get_norm(ngw.indegree())
else:
ngw = graph_wrapper
norm = graph_wrapper.node_feat["norm"]
feature = conv.gcn(ngw,
feature,
self.num_class,
activation=None,
norm=norm,
name="output")
return feature
class GAT(object):
"""Implement of GAT"""
def __init__(self, config, num_class):
self.num_class = num_class
self.num_layers = config.get("num_layers", 1)
self.num_heads = config.get("num_heads", 8)
self.hidden_size = config.get("hidden_size", 8)
self.feat_dropout = config.get("feat_drop", 0.6)
self.attn_dropout = config.get("attn_drop", 0.6)
self.edge_dropout = config.get("edge_dropout", 0.0)
def forward(self, graph_wrapper, feature, phase):
if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout
for i in range(self.num_layers):
ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
feature = conv.gat(ngw,
feature,
self.hidden_size,
activation="elu",
name="gat_layer_%s" % i,
num_heads=self.num_heads,
feat_drop=self.feat_dropout,
attn_drop=self.attn_dropout)
ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
feature = conv.gat(ngw,
feature,
self.num_class,
num_heads=1,
activation=None,
feat_drop=self.feat_dropout,
attn_drop=self.attn_dropout,
name="output")
return feature
class APPNP(object):
"""Implement of APPNP"""
def __init__(self, config, num_class):
self.num_class = num_class
self.num_layers = config.get("num_layers", 1)
self.hidden_size = config.get("hidden_size", 64)
self.dropout = config.get("dropout", 0.5)
self.alpha = config.get("alpha", 0.1)
self.k_hop = config.get("k_hop", 10)
self.edge_dropout = config.get("edge_dropout", 0.0)
def forward(self, graph_wrapper, feature, phase):
if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout
for i in range(self.num_layers):
feature = L.dropout(
feature,
self.dropout,
dropout_implementation='upscale_in_train')
feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
feature = L.dropout(
feature,
self.dropout,
dropout_implementation='upscale_in_train')
feature = L.fc(feature, self.num_class, act=None, name="output")
feature = conv.appnp(graph_wrapper,
feature=feature,
edge_dropout=edge_dropout,
alpha=self.alpha,
k_hop=self.k_hop)
return feature
class SGC(object):
"""Implement of SGC"""
def __init__(self, config, num_class):
self.num_class = num_class
self.num_layers = config.get("num_layers", 1)
def forward(self, graph_wrapper, feature, phase):
feature = conv.appnp(graph_wrapper,
feature=feature,
edge_dropout=0,
alpha=0,
k_hop=self.num_layers)
feature.stop_gradient=True
feature = L.fc(feature, self.num_class, act=None, bias_attr=False, name="output")
return feature
class GCNII(object):
"""Implement of GCNII"""
def __init__(self, config, num_class):
self.num_class = num_class
self.num_layers = config.get("num_layers", 1)
self.hidden_size = config.get("hidden_size", 64)
self.dropout = config.get("dropout", 0.6)
self.alpha = config.get("alpha", 0.1)
self.lambda_l = config.get("lambda_l", 0.5)
self.k_hop = config.get("k_hop", 64)
self.edge_dropout = config.get("edge_dropout", 0.0)
def forward(self, graph_wrapper, feature, phase):
if phase == "train":
edge_dropout = 0
else:
edge_dropout = self.edge_dropout
for i in range(self.num_layers):
feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
feature = L.dropout(
feature,
self.dropout,
dropout_implementation='upscale_in_train')
feature = conv.gcnii(graph_wrapper,
feature=feature,
name="gcnii",
activation="relu",
lambda_l=self.lambda_l,
alpha=self.alpha,
dropout=self.dropout,
k_hop=self.k_hop)
feature = L.fc(feature, self.num_class, act=None, name="output")
return feature
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pgl
import model# import LabelGraphGCN
from pgl import data_loader
from pgl.utils.logger import log
import paddle.fluid as fluid
import numpy as np
import time
import argparse
from build_model import build_model
import yaml
from easydict import EasyDict as edict
import tqdm
def normalize(feat):
return feat / np.maximum(np.sum(feat, -1, keepdims=True), 1)
def load(name, normalized_feature=True):
if name == 'cora':
dataset = data_loader.CoraDataset()
elif name == "pubmed":
dataset = data_loader.CitationDataset("pubmed", symmetry_edges=True)
elif name == "citeseer":
dataset = data_loader.CitationDataset("citeseer", symmetry_edges=True)
else:
raise ValueError(name + " dataset doesn't exists")
indegree = dataset.graph.indegree()
norm = np.maximum(indegree.astype("float32"), 1)
norm = np.power(norm, -0.5)
dataset.graph.node_feat["norm"] = np.expand_dims(norm, -1)
dataset.graph.node_feat["words"] = normalize(dataset.graph.node_feat["words"])
return dataset
def main(args, config):
dataset = load(args.dataset, args.feature_pre_normalize)
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
train_program = fluid.default_main_program()
startup_program = fluid.default_startup_program()
with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard():
gw, loss, acc = build_model(dataset,
config=config,
phase="train",
main_prog=train_program)
test_program = fluid.Program()
with fluid.program_guard(test_program, startup_program):
with fluid.unique_name.guard():
_gw, v_loss, v_acc = build_model(dataset,
config=config,
phase="test",
main_prog=test_program)
test_program = test_program.clone(for_test=True)
exe = fluid.Executor(place)
train_index = dataset.train_index
train_label = np.expand_dims(dataset.y[train_index], -1)
train_index = np.expand_dims(train_index, -1)
val_index = dataset.val_index
val_label = np.expand_dims(dataset.y[val_index], -1)
val_index = np.expand_dims(val_index, -1)
test_index = dataset.test_index
test_label = np.expand_dims(dataset.y[test_index], -1)
test_index = np.expand_dims(test_index, -1)
dur = []
# Feed data
feed_dict = gw.to_feed(dataset.graph)
best_test = []
for run in range(args.runs):
exe.run(startup_program)
cal_val_acc = []
cal_test_acc = []
cal_val_loss = []
cal_test_loss = []
for epoch in tqdm.tqdm(range(args.epoch)):
feed_dict["node_index"] = np.array(train_index, dtype="int64")
feed_dict["node_label"] = np.array(train_label, dtype="int64")
train_loss, train_acc = exe.run(train_program,
feed=feed_dict,
fetch_list=[loss, acc],
return_numpy=True)
feed_dict["node_index"] = np.array(val_index, dtype="int64")
feed_dict["node_label"] = np.array(val_label, dtype="int64")
val_loss, val_acc = exe.run(test_program,
feed=feed_dict,
fetch_list=[v_loss, v_acc],
return_numpy=True)
cal_val_acc.append(val_acc[0])
cal_val_loss.append(val_loss[0])
feed_dict["node_index"] = np.array(test_index, dtype="int64")
feed_dict["node_label"] = np.array(test_label, dtype="int64")
test_loss, test_acc = exe.run(test_program,
feed=feed_dict,
fetch_list=[v_loss, v_acc],
return_numpy=True)
cal_test_acc.append(test_acc[0])
cal_test_loss.append(test_loss[0])
log.info("Runs %s: Model: %s Best Test Accuracy: %f" % (run, config.model_name,
cal_test_acc[np.argmin(cal_val_loss)]))
best_test.append(cal_test_acc[np.argmin(cal_val_loss)])
log.info("Dataset: %s Best Test Accuracy: %f ( stddev: %f )" % (args.dataset, np.mean(best_test), np.std(best_test)))
print("Dataset: %s Best Test Accuracy: %f ( stddev: %f )" % (args.dataset, np.mean(best_test), np.std(best_test)))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Benchmarking Citation Network')
parser.add_argument(
"--dataset", type=str, default="cora", help="dataset (cora, pubmed)")
parser.add_argument("--use_cuda", action='store_true', help="use_cuda")
parser.add_argument("--conf", type=str, help="config file for models")
parser.add_argument("--epoch", type=int, default=200, help="Epoch")
parser.add_argument("--runs", type=int, default=10, help="runs")
parser.add_argument("--feature_pre_normalize", type=bool, default=True, help="pre_normalize feature")
args = parser.parse_args()
config = edict(yaml.load(open(args.conf), Loader=yaml.FullLoader))
log.info(args)
main(args, config)
...@@ -22,3 +22,4 @@ from pgl import heter_graph ...@@ -22,3 +22,4 @@ from pgl import heter_graph
from pgl import heter_graph_wrapper from pgl import heter_graph_wrapper
from pgl import contrib from pgl import contrib
from pgl import message_passing from pgl import message_passing
from pgl import sample
...@@ -19,6 +19,7 @@ for PaddlePaddle. ...@@ -19,6 +19,7 @@ for PaddlePaddle.
import warnings import warnings
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as L
from pgl.utils import op from pgl.utils import op
from pgl.utils import paddle_helper from pgl.utils import paddle_helper
...@@ -47,10 +48,10 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes, ...@@ -47,10 +48,10 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
try: try:
out_dim = msg.shape[-1] out_dim = msg.shape[-1]
init_output = fluid.layers.fill_constant( init_output = L.fill_constant(
shape=[num_nodes, out_dim], value=0, dtype=msg.dtype) shape=[num_nodes, out_dim], value=0, dtype=msg.dtype)
init_output.stop_gradient = False init_output.stop_gradient = False
empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=msg.dtype) empty_msg_flag = L.cast(num_edges > 0, dtype=msg.dtype)
msg = msg * empty_msg_flag msg = msg * empty_msg_flag
output = paddle_helper.scatter_add(init_output, dst, msg) output = paddle_helper.scatter_add(init_output, dst, msg)
return output return output
...@@ -59,7 +60,7 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes, ...@@ -59,7 +60,7 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
"scatter_add is not supported with paddle version <= 1.5") "scatter_add is not supported with paddle version <= 1.5")
def sum_func(message): def sum_func(message):
return fluid.layers.sequence_pool(message, "sum") return L.sequence_pool(message, "sum")
reduce_function = sum_func reduce_function = sum_func
...@@ -67,13 +68,13 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes, ...@@ -67,13 +68,13 @@ def recv(dst, uniq_dst, bucketing_index, msg, reduce_function, num_nodes,
output = reduce_function(bucketed_msg) output = reduce_function(bucketed_msg)
output_dim = output.shape[-1] output_dim = output.shape[-1]
empty_msg_flag = fluid.layers.cast(num_edges > 0, dtype=output.dtype) empty_msg_flag = L.cast(num_edges > 0, dtype=output.dtype)
output = output * empty_msg_flag output = output * empty_msg_flag
init_output = fluid.layers.fill_constant( init_output = L.fill_constant(
shape=[num_nodes, output_dim], value=0, dtype=output.dtype) shape=[num_nodes, output_dim], value=0, dtype=output.dtype)
init_output.stop_gradient = True init_output.stop_gradient = True
final_output = fluid.layers.scatter(init_output, uniq_dst, output) final_output = L.scatter(init_output, uniq_dst, output)
return final_output return final_output
...@@ -104,6 +105,7 @@ class BaseGraphWrapper(object): ...@@ -104,6 +105,7 @@ class BaseGraphWrapper(object):
self._node_ids = None self._node_ids = None
self._graph_lod = None self._graph_lod = None
self._num_graph = None self._num_graph = None
self._num_edges = None
self._data_name_prefix = "" self._data_name_prefix = ""
def __repr__(self): def __repr__(self):
...@@ -470,7 +472,7 @@ class StaticGraphWrapper(BaseGraphWrapper): ...@@ -470,7 +472,7 @@ class StaticGraphWrapper(BaseGraphWrapper):
class GraphWrapper(BaseGraphWrapper): class GraphWrapper(BaseGraphWrapper):
"""Implement a graph wrapper that creates a graph data holders """Implement a graph wrapper that creates a graph data holders
that attributes and features in the graph are :code:`fluid.layers.data`. that attributes and features in the graph are :code:`L.data`.
And we provide interface :code:`to_feed` to help converting :code:`Graph` And we provide interface :code:`to_feed` to help converting :code:`Graph`
data into :code:`feed_dict`. data into :code:`feed_dict`.
...@@ -546,65 +548,65 @@ class GraphWrapper(BaseGraphWrapper): ...@@ -546,65 +548,65 @@ class GraphWrapper(BaseGraphWrapper):
def __create_graph_attr_holders(self): def __create_graph_attr_holders(self):
"""Create data holders for graph attributes. """Create data holders for graph attributes.
""" """
self._num_edges = fluid.layers.data( self._num_edges = L.data(
self._data_name_prefix + '/num_edges', self._data_name_prefix + '/num_edges',
shape=[1], shape=[1],
append_batch_size=False, append_batch_size=False,
dtype="int64", dtype="int64",
stop_gradient=True) stop_gradient=True)
self._num_graph = fluid.layers.data( self._num_graph = L.data(
self._data_name_prefix + '/num_graph', self._data_name_prefix + '/num_graph',
shape=[1], shape=[1],
append_batch_size=False, append_batch_size=False,
dtype="int64", dtype="int64",
stop_gradient=True) stop_gradient=True)
self._edges_src = fluid.layers.data( self._edges_src = L.data(
self._data_name_prefix + '/edges_src', self._data_name_prefix + '/edges_src',
shape=[None], shape=[None],
append_batch_size=False, append_batch_size=False,
dtype="int64", dtype="int64",
stop_gradient=True) stop_gradient=True)
self._edges_dst = fluid.layers.data( self._edges_dst = L.data(
self._data_name_prefix + '/edges_dst', self._data_name_prefix + '/edges_dst',
shape=[None], shape=[None],
append_batch_size=False, append_batch_size=False,
dtype="int64", dtype="int64",
stop_gradient=True) stop_gradient=True)
self._num_nodes = fluid.layers.data( self._num_nodes = L.data(
self._data_name_prefix + '/num_nodes', self._data_name_prefix + '/num_nodes',
shape=[1], shape=[1],
append_batch_size=False, append_batch_size=False,
dtype='int64', dtype='int64',
stop_gradient=True) stop_gradient=True)
self._edge_uniq_dst = fluid.layers.data( self._edge_uniq_dst = L.data(
self._data_name_prefix + "/uniq_dst", self._data_name_prefix + "/uniq_dst",
shape=[None], shape=[None],
append_batch_size=False, append_batch_size=False,
dtype="int64", dtype="int64",
stop_gradient=True) stop_gradient=True)
self._graph_lod = fluid.layers.data( self._graph_lod = L.data(
self._data_name_prefix + "/graph_lod", self._data_name_prefix + "/graph_lod",
shape=[None], shape=[None],
append_batch_size=False, append_batch_size=False,
dtype="int32", dtype="int32",
stop_gradient=True) stop_gradient=True)
self._edge_uniq_dst_count = fluid.layers.data( self._edge_uniq_dst_count = L.data(
self._data_name_prefix + "/uniq_dst_count", self._data_name_prefix + "/uniq_dst_count",
shape=[None], shape=[None],
append_batch_size=False, append_batch_size=False,
dtype="int32", dtype="int32",
stop_gradient=True) stop_gradient=True)
self._node_ids = fluid.layers.data( self._node_ids = L.data(
self._data_name_prefix + "/node_ids", self._data_name_prefix + "/node_ids",
shape=[None], shape=[None],
append_batch_size=False, append_batch_size=False,
dtype="int64", dtype="int64",
stop_gradient=True) stop_gradient=True)
self._indegree = fluid.layers.data( self._indegree = L.data(
self._data_name_prefix + "/indegree", self._data_name_prefix + "/indegree",
shape=[None], shape=[None],
append_batch_size=False, append_batch_size=False,
...@@ -627,7 +629,7 @@ class GraphWrapper(BaseGraphWrapper): ...@@ -627,7 +629,7 @@ class GraphWrapper(BaseGraphWrapper):
node_feat_dtype): node_feat_dtype):
"""Create data holders for node features. """Create data holders for node features.
""" """
feat_holder = fluid.layers.data( feat_holder = L.data(
self._data_name_prefix + '/node_feat/' + node_feat_name, self._data_name_prefix + '/node_feat/' + node_feat_name,
shape=node_feat_shape, shape=node_feat_shape,
append_batch_size=False, append_batch_size=False,
...@@ -640,7 +642,7 @@ class GraphWrapper(BaseGraphWrapper): ...@@ -640,7 +642,7 @@ class GraphWrapper(BaseGraphWrapper):
edge_feat_dtype): edge_feat_dtype):
"""Create edge holders for edge features. """Create edge holders for edge features.
""" """
feat_holder = fluid.layers.data( feat_holder = L.data(
self._data_name_prefix + '/edge_feat/' + edge_feat_name, self._data_name_prefix + '/edge_feat/' + edge_feat_name,
shape=edge_feat_shape, shape=edge_feat_shape,
append_batch_size=False, append_batch_size=False,
...@@ -719,3 +721,61 @@ class GraphWrapper(BaseGraphWrapper): ...@@ -719,3 +721,61 @@ class GraphWrapper(BaseGraphWrapper):
"""Return the holder list. """Return the holder list.
""" """
return self._holder_list return self._holder_list
def get_degree(edge, num_nodes):
init_output = L.fill_constant(
shape=[num_nodes], value=0, dtype="float32")
init_output.stop_gradient = True
final_output = L.scatter(init_output,
edge,
L.full_like(edge, 1, dtype="float32"),
overwrite=False)
return final_output
class DropEdgeWrapper(BaseGraphWrapper):
"""Implement of Edge Drop """
def __init__(self, graph_wrapper, dropout, keep_self_loop=True):
super(DropEdgeWrapper, self).__init__()
# Copy Node's information
for key, value in graph_wrapper.node_feat.items():
self.node_feat_tensor_dict[key] = value
self._num_nodes = graph_wrapper.num_nodes
self._graph_lod = graph_wrapper.graph_lod
self._num_graph = graph_wrapper.num_graph
self._node_ids = L.range(0, self._num_nodes, step=1, dtype="int32")
# Dropout Edges
src, dst = graph_wrapper.edges
u = L.uniform_random(shape=L.cast(L.shape(src), 'int64'), min=0., max=1.)
# Avoid Empty Edges
keeped = L.cast(u > dropout, dtype="float32")
self._num_edges = L.reduce_sum(L.cast(keeped, "int32"))
keeped = keeped + L.cast(self._num_edges == 0, dtype="float32")
if keep_self_loop:
self_loop = L.cast(src == dst, dtype="float32")
keeped = keeped + self_loop
keeped = (keeped > 0.5)
src = paddle_helper.masked_select(src, keeped)
dst = paddle_helper.masked_select(dst, keeped)
src.stop_gradient=True
dst.stop_gradient=True
self._edges_src = src
self._edges_dst = dst
for key, value in graph_wrapper.edge_feat.items():
self.edge_feat_tensor_dict[key] = paddle_helper.masked_select(value, keeped)
self._edge_uniq_dst, _, uniq_count = L.unique_with_counts(dst, dtype="int32")
self._edge_uniq_dst.stop_gradient=True
last = L.reduce_sum(uniq_count, keep_dim=True)
uniq_count = L.cumsum(uniq_count, exclusive=True)
self._edge_uniq_dst_count = L.concat([uniq_count, last])
self._edge_uniq_dst_count.stop_gradient=True
self._indegree = get_degree(self._edges_dst, self._num_nodes)
...@@ -14,11 +14,14 @@ ...@@ -14,11 +14,14 @@
"""This package implements common layers to help building """This package implements common layers to help building
graph neural networks. graph neural networks.
""" """
import pgl
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as L
from pgl.utils import paddle_helper from pgl.utils import paddle_helper
from pgl import message_passing from pgl import message_passing
import numpy as np
__all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv'] __all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv', 'appnp', 'gcnii']
def gcn(gw, feature, hidden_size, activation, name, norm=None): def gcn(gw, feature, hidden_size, activation, name, norm=None):
...@@ -50,7 +53,7 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None): ...@@ -50,7 +53,7 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None):
size = feature.shape[-1] size = feature.shape[-1]
if size > hidden_size: if size > hidden_size:
feature = fluid.layers.fc(feature, feature = L.fc(feature,
size=hidden_size, size=hidden_size,
bias_attr=False, bias_attr=False,
param_attr=fluid.ParamAttr(name=name)) param_attr=fluid.ParamAttr(name=name))
...@@ -64,7 +67,7 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None): ...@@ -64,7 +67,7 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None):
output = gw.recv(msg, "sum") output = gw.recv(msg, "sum")
else: else:
output = gw.recv(msg, "sum") output = gw.recv(msg, "sum")
output = fluid.layers.fc(output, output = L.fc(output,
size=hidden_size, size=hidden_size,
bias_attr=False, bias_attr=False,
param_attr=fluid.ParamAttr(name=name)) param_attr=fluid.ParamAttr(name=name))
...@@ -72,12 +75,12 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None): ...@@ -72,12 +75,12 @@ def gcn(gw, feature, hidden_size, activation, name, norm=None):
if norm is not None: if norm is not None:
output = output * norm output = output * norm
bias = fluid.layers.create_parameter( bias = L.create_parameter(
shape=[hidden_size], shape=[hidden_size],
dtype='float32', dtype='float32',
is_bias=True, is_bias=True,
name=name + '_bias') name=name + '_bias')
output = fluid.layers.elementwise_add(output, bias, act=activation) output = L.elementwise_add(output, bias, act=activation)
return output return output
...@@ -120,7 +123,7 @@ def gat(gw, ...@@ -120,7 +123,7 @@ def gat(gw,
def send_attention(src_feat, dst_feat, edge_feat): def send_attention(src_feat, dst_feat, edge_feat):
output = src_feat["left_a"] + dst_feat["right_a"] output = src_feat["left_a"] + dst_feat["right_a"]
output = fluid.layers.leaky_relu( output = L.leaky_relu(
output, alpha=0.2) # (num_edges, num_heads) output, alpha=0.2) # (num_edges, num_heads)
return {"alpha": output, "h": src_feat["h"]} return {"alpha": output, "h": src_feat["h"]}
...@@ -129,54 +132,54 @@ def gat(gw, ...@@ -129,54 +132,54 @@ def gat(gw,
h = msg["h"] h = msg["h"]
alpha = paddle_helper.sequence_softmax(alpha) alpha = paddle_helper.sequence_softmax(alpha)
old_h = h old_h = h
h = fluid.layers.reshape(h, [-1, num_heads, hidden_size]) h = L.reshape(h, [-1, num_heads, hidden_size])
alpha = fluid.layers.reshape(alpha, [-1, num_heads, 1]) alpha = L.reshape(alpha, [-1, num_heads, 1])
if attn_drop > 1e-15: if attn_drop > 1e-15:
alpha = fluid.layers.dropout( alpha = L.dropout(
alpha, alpha,
dropout_prob=attn_drop, dropout_prob=attn_drop,
is_test=is_test, is_test=is_test,
dropout_implementation="upscale_in_train") dropout_implementation="upscale_in_train")
h = h * alpha h = h * alpha
h = fluid.layers.reshape(h, [-1, num_heads * hidden_size]) h = L.reshape(h, [-1, num_heads * hidden_size])
h = fluid.layers.lod_reset(h, old_h) h = L.lod_reset(h, old_h)
return fluid.layers.sequence_pool(h, "sum") return L.sequence_pool(h, "sum")
if feat_drop > 1e-15: if feat_drop > 1e-15:
feature = fluid.layers.dropout( feature = L.dropout(
feature, feature,
dropout_prob=feat_drop, dropout_prob=feat_drop,
is_test=is_test, is_test=is_test,
dropout_implementation='upscale_in_train') dropout_implementation='upscale_in_train')
ft = fluid.layers.fc(feature, ft = L.fc(feature,
hidden_size * num_heads, hidden_size * num_heads,
bias_attr=False, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_weight')) param_attr=fluid.ParamAttr(name=name + '_weight'))
left_a = fluid.layers.create_parameter( left_a = L.create_parameter(
shape=[num_heads, hidden_size], shape=[num_heads, hidden_size],
dtype='float32', dtype='float32',
name=name + '_gat_l_A') name=name + '_gat_l_A')
right_a = fluid.layers.create_parameter( right_a = L.create_parameter(
shape=[num_heads, hidden_size], shape=[num_heads, hidden_size],
dtype='float32', dtype='float32',
name=name + '_gat_r_A') name=name + '_gat_r_A')
reshape_ft = fluid.layers.reshape(ft, [-1, num_heads, hidden_size]) reshape_ft = L.reshape(ft, [-1, num_heads, hidden_size])
left_a_value = fluid.layers.reduce_sum(reshape_ft * left_a, -1) left_a_value = L.reduce_sum(reshape_ft * left_a, -1)
right_a_value = fluid.layers.reduce_sum(reshape_ft * right_a, -1) right_a_value = L.reduce_sum(reshape_ft * right_a, -1)
msg = gw.send( msg = gw.send(
send_attention, send_attention,
nfeat_list=[("h", ft), ("left_a", left_a_value), nfeat_list=[("h", ft), ("left_a", left_a_value),
("right_a", right_a_value)]) ("right_a", right_a_value)])
output = gw.recv(msg, reduce_attention) output = gw.recv(msg, reduce_attention)
bias = fluid.layers.create_parameter( bias = L.create_parameter(
shape=[hidden_size * num_heads], shape=[hidden_size * num_heads],
dtype='float32', dtype='float32',
is_bias=True, is_bias=True,
name=name + '_bias') name=name + '_bias')
bias.stop_gradient = True bias.stop_gradient = True
output = fluid.layers.elementwise_add(output, bias, act=activation) output = L.elementwise_add(output, bias, act=activation)
return output return output
...@@ -219,7 +222,7 @@ def gin(gw, ...@@ -219,7 +222,7 @@ def gin(gw,
def send_src_copy(src_feat, dst_feat, edge_feat): def send_src_copy(src_feat, dst_feat, edge_feat):
return src_feat["h"] return src_feat["h"]
epsilon = fluid.layers.create_parameter( epsilon = L.create_parameter(
shape=[1, 1], shape=[1, 1],
dtype="float32", dtype="float32",
attr=fluid.ParamAttr(name="%s_eps" % name), attr=fluid.ParamAttr(name="%s_eps" % name),
...@@ -232,13 +235,13 @@ def gin(gw, ...@@ -232,13 +235,13 @@ def gin(gw,
msg = gw.send(send_src_copy, nfeat_list=[("h", feature)]) msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
output = gw.recv(msg, "sum") + feature * (epsilon + 1.0) output = gw.recv(msg, "sum") + feature * (epsilon + 1.0)
output = fluid.layers.fc(output, output = L.fc(output,
size=hidden_size, size=hidden_size,
act=None, act=None,
param_attr=fluid.ParamAttr(name="%s_w_0" % name), param_attr=fluid.ParamAttr(name="%s_w_0" % name),
bias_attr=fluid.ParamAttr(name="%s_b_0" % name)) bias_attr=fluid.ParamAttr(name="%s_b_0" % name))
output = fluid.layers.layer_norm( output = L.layer_norm(
output, output,
begin_norm_axis=1, begin_norm_axis=1,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
...@@ -249,9 +252,9 @@ def gin(gw, ...@@ -249,9 +252,9 @@ def gin(gw,
initializer=fluid.initializer.Constant(0.0)), ) initializer=fluid.initializer.Constant(0.0)), )
if activation is not None: if activation is not None:
output = getattr(fluid.layers, activation)(output) output = getattr(L, activation)(output)
output = fluid.layers.fc(output, output = L.fc(output,
size=hidden_size, size=hidden_size,
act=activation, act=activation,
param_attr=fluid.ParamAttr(name="%s_w_1" % name), param_attr=fluid.ParamAttr(name="%s_w_1" % name),
...@@ -269,10 +272,10 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -269,10 +272,10 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key'] feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
# E * M * D1 # E * M * D1
old = feat_query old = feat_query
feat_query = fluid.layers.reshape(feat_query, [-1, heads, hidden_size_a]) feat_query = L.reshape(feat_query, [-1, heads, hidden_size_a])
feat_key = fluid.layers.reshape(feat_key, [-1, heads, hidden_size_a]) feat_key = L.reshape(feat_key, [-1, heads, hidden_size_a])
# E * M # E * M
alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1) alpha = L.reduce_sum(feat_key * feat_query, dim=-1)
return {'dst_node_feat': dst_feat['node_feat'], return {'dst_node_feat': dst_feat['node_feat'],
'src_node_feat': src_feat['node_feat'], 'src_node_feat': src_feat['node_feat'],
...@@ -286,15 +289,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -286,15 +289,15 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
# 每条边的出发点的特征 # 每条边的出发点的特征
src_feat = message['src_node_feat'] src_feat = message['src_node_feat']
# 每个中心点自己的特征 # 每个中心点自己的特征
x = fluid.layers.sequence_pool(dst_feat, 'average') x = L.sequence_pool(dst_feat, 'average')
# 每个中心点的邻居的特征的平均值 # 每个中心点的邻居的特征的平均值
z = fluid.layers.sequence_pool(src_feat, 'average') z = L.sequence_pool(src_feat, 'average')
# 计算 gate # 计算 gate
feat_gate = message['feat_gate'] feat_gate = message['feat_gate']
g_max = fluid.layers.sequence_pool(feat_gate, 'max') g_max = L.sequence_pool(feat_gate, 'max')
g = fluid.layers.concat([x, g_max, z], axis=1) g = L.concat([x, g_max, z], axis=1)
g = fluid.layers.fc(g, heads, bias_attr=False, act="sigmoid") g = L.fc(g, heads, bias_attr=False, act="sigmoid")
# softmax # softmax
alpha = message['alpha'] alpha = message['alpha']
...@@ -302,19 +305,19 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -302,19 +305,19 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
feat_value = message['feat_value'] # E * (M * D2) feat_value = message['feat_value'] # E * (M * D2)
old = feat_value old = feat_value
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) # E * M * D2 feat_value = L.reshape(feat_value, [-1, heads, hidden_size_v]) # E * M * D2
feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0) feat_value = L.elementwise_mul(feat_value, alpha, axis=0)
feat_value = fluid.layers.reshape(feat_value, [-1, heads*hidden_size_v]) # E * (M * D2) feat_value = L.reshape(feat_value, [-1, heads*hidden_size_v]) # E * (M * D2)
feat_value = fluid.layers.lod_reset(feat_value, old) feat_value = L.lod_reset(feat_value, old)
feat_value = fluid.layers.sequence_pool(feat_value, 'sum') # N * (M * D2) feat_value = L.sequence_pool(feat_value, 'sum') # N * (M * D2)
feat_value = fluid.layers.reshape(feat_value, [-1, heads, hidden_size_v]) # N * M * D2 feat_value = L.reshape(feat_value, [-1, heads, hidden_size_v]) # N * M * D2
output = fluid.layers.elementwise_mul(feat_value, g, axis=0) output = L.elementwise_mul(feat_value, g, axis=0)
output = fluid.layers.reshape(output, [-1, heads * hidden_size_v]) # N * (M * D2) output = L.reshape(output, [-1, heads * hidden_size_v]) # N * (M * D2)
output = fluid.layers.concat([x, output], axis=1) output = L.concat([x, output], axis=1)
return output return output
...@@ -323,16 +326,16 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -323,16 +326,16 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
# 计算每个点自己需要发送出去的内容 # 计算每个点自己需要发送出去的内容
# 投影后的特征向量 # 投影后的特征向量
# N * (D1 * M) # N * (D1 * M)
feat_key = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False, feat_key = L.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_key')) param_attr=fluid.ParamAttr(name=name + '_project_key'))
# N * (D2 * M) # N * (D2 * M)
feat_value = fluid.layers.fc(feature, hidden_size_v * heads, bias_attr=False, feat_value = L.fc(feature, hidden_size_v * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_value')) param_attr=fluid.ParamAttr(name=name + '_project_value'))
# N * (D1 * M) # N * (D1 * M)
feat_query = fluid.layers.fc(feature, hidden_size_a * heads, bias_attr=False, feat_query = L.fc(feature, hidden_size_a * heads, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_query')) param_attr=fluid.ParamAttr(name=name + '_project_query'))
# N * Dm # N * Dm
feat_gate = fluid.layers.fc(feature, hidden_size_m, bias_attr=False, feat_gate = L.fc(feature, hidden_size_m, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_gate')) param_attr=fluid.ParamAttr(name=name + '_project_gate'))
# send 阶段 # send 阶段
...@@ -346,10 +349,10 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o ...@@ -346,10 +349,10 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
# 聚合邻居特征 # 聚合邻居特征
output = gw.recv(message, recv_func) output = gw.recv(message, recv_func)
output = fluid.layers.fc(output, hidden_size_o, bias_attr=False, output = L.fc(output, hidden_size_o, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_project_output')) param_attr=fluid.ParamAttr(name=name + '_project_output'))
output = fluid.layers.leaky_relu(output, alpha=0.1) output = L.leaky_relu(output, alpha=0.1)
output = fluid.layers.dropout(output, dropout_prob=0.1) output = L.dropout(output, dropout_prob=0.1)
return output return output
...@@ -376,7 +379,7 @@ def gen_conv(gw, ...@@ -376,7 +379,7 @@ def gen_conv(gw,
""" """
if beta == "dynamic": if beta == "dynamic":
beta = fluid.layers.create_parameter( beta = L.create_parameter(
shape=[1], shape=[1],
dtype='float32', dtype='float32',
default_initializer= default_initializer=
...@@ -391,16 +394,132 @@ def gen_conv(gw, ...@@ -391,16 +394,132 @@ def gen_conv(gw,
output = message_passing.msg_norm(feature, output, name) output = message_passing.msg_norm(feature, output, name)
output = feature + output output = feature + output
output = fluid.layers.fc(output, output = L.fc(output,
feature.shape[-1], feature.shape[-1],
bias_attr=False, bias_attr=False,
act="relu", act="relu",
param_attr=fluid.ParamAttr(name=name + '_weight1')) param_attr=fluid.ParamAttr(name=name + '_weight1'))
output = fluid.layers.fc(output, output = L.fc(output,
feature.shape[-1], feature.shape[-1],
bias_attr=False, bias_attr=False,
param_attr=fluid.ParamAttr(name=name + '_weight2')) param_attr=fluid.ParamAttr(name=name + '_weight2'))
return output return output
def get_norm(indegree):
"""Get Laplacian Normalization"""
float_degree = L.cast(indegree, dtype="float32")
float_degree = L.clamp(float_degree, min=1.0)
norm = L.pow(float_degree, factor=-0.5)
return norm
def appnp(gw, feature, edge_dropout=0, alpha=0.2, k_hop=10):
"""Implementation of APPNP of "Predict then Propagate: Graph Neural Networks
meet Personalized PageRank" (ICLR 2019).
Args:
gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
feature: A tensor with shape (num_nodes, feature_size).
edge_dropout: Edge dropout rate.
k_hop: K Steps for Propagation
Return:
A tensor with shape (num_nodes, hidden_size)
"""
def send_src_copy(src_feat, dst_feat, edge_feat):
feature = src_feat["h"]
return feature
h0 = feature
ngw = gw
norm = get_norm(ngw.indegree())
for i in range(k_hop):
if edge_dropout > 1e-5:
ngw = pgl.sample.edge_drop(gw, edge_dropout)
norm = get_norm(ngw.indegree())
feature = feature * norm
msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
feature = gw.recv(msg, "sum")
feature = feature * norm
feature = feature * (1 - alpha) + h0 * alpha
return feature
def gcnii(gw,
feature,
name,
activation=None,
alpha=0.5,
lambda_l=0.5,
k_hop=1,
dropout=0.5,
is_test=False):
"""Implementation of GCNII of "Simple and Deep Graph Convolutional Networks"
paper: https://arxiv.org/pdf/2007.02133.pdf
Args:
gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
feature: A tensor with shape (num_nodes, feature_size).
activation: The activation for the output.
k_hop: Number of layers for gcnii.
lambda_l: The hyperparameter of lambda in the paper.
alpha: The hyperparameter of alpha in the paper.
dropout: Feature dropout rate.
is_test: train / test phase.
Return:
A tensor with shape (num_nodes, hidden_size)
"""
def send_src_copy(src_feat, dst_feat, edge_feat):
feature = src_feat["h"]
return feature
h0 = feature
ngw = gw
norm = get_norm(ngw.indegree())
hidden_size = feature.shape[-1]
for i in range(k_hop):
beta_i = np.log(1.0 * lambda_l / (i + 1) + 1)
feature = L.dropout(
feature,
dropout_prob=dropout,
is_test=is_test,
dropout_implementation='upscale_in_train')
feature = feature * norm
msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
feature = gw.recv(msg, "sum")
feature = feature * norm
# appnp
feature = feature * (1 - alpha) + h0 * alpha
feature_transed = L.fc(feature, hidden_size,
act=None, bias_attr=False,
name=name+"_%s_w1" % i)
feature = feature_transed * beta_i + feature * (1 - beta_i)
if activation is not None:
feature = getattr(L, activation)(feature)
return feature
...@@ -516,3 +516,12 @@ def graph_saint_random_walk_sample(graph, ...@@ -516,3 +516,12 @@ def graph_saint_random_walk_sample(graph,
nodes=sample_nodes, eid=eids, with_node_feat=True, with_edge_feat=True) nodes=sample_nodes, eid=eids, with_node_feat=True, with_edge_feat=True)
subgraph.node_feat["index"] = np.array(sample_nodes, dtype="int64") subgraph.node_feat["index"] = np.array(sample_nodes, dtype="int64")
return subgraph return subgraph
def edge_drop(graph_wrapper, dropout_rate, keep_self_loop=True):
if dropout_rate < 1e-5:
return graph_wrapper
else:
return pgl.graph_wrapper.DropEdgeWrapper(graph_wrapper,
dropout_rate,
keep_self_loop)
...@@ -250,3 +250,20 @@ def scatter_max(input, index, updates): ...@@ -250,3 +250,20 @@ def scatter_max(input, index, updates):
output = fluid.layers.scatter(input, index, updates, mode='max') output = fluid.layers.scatter(input, index, updates, mode='max')
return output return output
def masked_select(input, mask):
"""masked_select
Slice the value from given Mask
Args:
input: Input tensor to be selected
mask: A bool tensor for sliced.
Return:
Part of inputs where mask is True.
"""
index = fluid.layers.where(mask)
return fluid.layers.gather(input, index)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册