Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
d96c9759
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d96c9759
编写于
7月 27, 2020
作者:
Y
Yelrose
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add models
上级
90ff1f7f
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
371 addition
and
0 deletion
+371
-0
examples/citation_benchmark/README.md
examples/citation_benchmark/README.md
+1
-0
examples/citation_benchmark/build_model.py
examples/citation_benchmark/build_model.py
+47
-0
examples/citation_benchmark/config/appnp.yaml
examples/citation_benchmark/config/appnp.yaml
+8
-0
examples/citation_benchmark/config/gat.yaml
examples/citation_benchmark/config/gat.yaml
+8
-0
examples/citation_benchmark/config/gcn.yaml
examples/citation_benchmark/config/gcn.yaml
+6
-0
examples/citation_benchmark/model.py
examples/citation_benchmark/model.py
+100
-0
examples/citation_benchmark/optimization.py
examples/citation_benchmark/optimization.py
+55
-0
examples/citation_benchmark/train.py
examples/citation_benchmark/train.py
+146
-0
未找到文件。
examples/citation_benchmark/README.md
0 → 100644
浏览文件 @
d96c9759
# Easy Paper Reproduction for Citation Network (Cora/Pubmed/Citeseer)
examples/citation_benchmark/build_model.py
0 → 100644
浏览文件 @
d96c9759
import
pgl
import
model
from
pgl
import
data_loader
import
paddle.fluid
as
fluid
import
numpy
as
np
import
time
from
optimization
import
AdamW
def
build_model
(
dataset
,
config
,
phase
,
main_prog
):
gw
=
pgl
.
graph_wrapper
.
GraphWrapper
(
name
=
"graph"
,
node_feat
=
dataset
.
graph
.
node_feat_info
())
GraphModel
=
getattr
(
model
,
config
.
model_name
)
m
=
GraphModel
(
config
=
config
,
num_class
=
dataset
.
num_classes
)
logits
=
m
.
forward
(
gw
,
gw
.
node_feat
[
"words"
])
node_index
=
fluid
.
layers
.
data
(
"node_index"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
,
append_batch_size
=
False
)
node_label
=
fluid
.
layers
.
data
(
"node_label"
,
shape
=
[
None
,
1
],
dtype
=
"int64"
,
append_batch_size
=
False
)
pred
=
fluid
.
layers
.
gather
(
logits
,
node_index
)
loss
,
pred
=
fluid
.
layers
.
softmax_with_cross_entropy
(
logits
=
pred
,
label
=
node_label
,
return_softmax
=
True
)
acc
=
fluid
.
layers
.
accuracy
(
input
=
pred
,
label
=
node_label
,
k
=
1
)
loss
=
fluid
.
layers
.
mean
(
loss
)
if
phase
==
"train"
:
#adam = fluid.optimizer.Adam(
# learning_rate=config.learning_rate,
# regularization=fluid.regularizer.L2DecayRegularizer(
# regularization_coeff=config.weight_decay))
#adam.minimize(loss)
AdamW
(
loss
=
loss
,
learning_rate
=
config
.
learning_rate
,
weight_decay
=
config
.
weight_decay
,
train_program
=
main_prog
)
return
gw
,
loss
,
acc
examples/citation_benchmark/config/appnp.yaml
0 → 100644
浏览文件 @
d96c9759
model_name
:
APPNP
k_hop
:
10
alpha
:
0.1
num_layer2
:
1
learning_rate
:
0.01
dropout
:
0.5
hidden_size
:
64
weight_decay
:
0.0005
examples/citation_benchmark/config/gat.yaml
0 → 100644
浏览文件 @
d96c9759
model_name
:
GAT
learning_rate
:
0.005
weight_decay
:
0.0005
num_layers
:
1
feat_drop
:
0.6
attn_drop
:
0.6
num_heads
:
8
hidden_size
:
8
examples/citation_benchmark/config/gcn.yaml
0 → 100644
浏览文件 @
d96c9759
model_name
:
GCN
num_layers
:
1
dropout
:
0.5
hidden_size
:
64
learning_rate
:
0.01
weight_decay
:
0.0005
examples/citation_benchmark/model.py
0 → 100644
浏览文件 @
d96c9759
import
pgl
import
paddle.fluid.layers
as
L
import
pgl.layers.conv
as
conv
class
GCN
(
object
):
"""Implement of GCN
"""
def
__init__
(
self
,
config
,
num_class
):
self
.
num_class
=
num_class
self
.
num_layers
=
config
.
get
(
"num_layers"
,
1
)
self
.
hidden_size
=
config
.
get
(
"hidden_size"
,
64
)
self
.
dropout
=
config
.
get
(
"dropout"
,
0.5
)
def
forward
(
self
,
graph_wrapper
,
feature
):
for
i
in
range
(
self
.
num_layers
):
feature
=
pgl
.
layers
.
gcn
(
graph_wrapper
,
feature
,
self
.
hidden_size
,
activation
=
"relu"
,
norm
=
graph_wrapper
.
node_feat
[
"norm"
],
name
=
"layer_%s"
%
i
)
feature
=
L
.
dropout
(
feature
,
self
.
dropout
,
dropout_implementation
=
'upscale_in_train'
)
feature
=
conv
.
gcn
(
graph_wrapper
,
feature
,
self
.
num_class
,
activation
=
None
,
norm
=
graph_wrapper
.
node_feat
[
"norm"
],
name
=
"output"
)
return
feature
class
GAT
(
object
):
"""Implement of GAT"""
def
__init__
(
self
,
config
,
num_class
):
self
.
num_class
=
num_class
self
.
num_layers
=
config
.
get
(
"num_layers"
,
1
)
self
.
num_heads
=
config
.
get
(
"num_heads"
,
8
)
self
.
hidden_size
=
config
.
get
(
"hidden_size"
,
8
)
self
.
feat_dropout
=
config
.
get
(
"feat_drop"
,
0.6
)
self
.
attn_dropout
=
config
.
get
(
"attn_drop"
,
0.6
)
def
forward
(
self
,
graph_wrapper
,
feature
):
for
i
in
range
(
self
.
num_layers
):
feature
=
conv
.
gat
(
graph_wrapper
,
feature
,
self
.
hidden_size
,
activation
=
"elu"
,
name
=
"gat_layer_%s"
%
i
,
num_heads
=
self
.
num_heads
,
feat_drop
=
self
.
feat_dropout
,
attn_drop
=
self
.
attn_dropout
)
feature
=
conv
.
gat
(
graph_wrapper
,
feature
,
self
.
num_class
,
num_heads
=
1
,
activation
=
None
,
feat_drop
=
self
.
feat_dropout
,
attn_drop
=
self
.
attn_dropout
,
name
=
"output"
)
return
feature
class
APPNP
(
object
):
"""Implement of APPNP"""
def
__init__
(
self
,
config
,
num_class
):
self
.
num_class
=
num_class
self
.
num_layers
=
config
.
get
(
"num_layers"
,
1
)
self
.
hidden_size
=
config
.
get
(
"hidden_size"
,
64
)
self
.
dropout
=
config
.
get
(
"dropout"
,
0.5
)
self
.
alpha
=
config
.
get
(
"alpha"
,
0.1
)
self
.
k_hop
=
config
.
get
(
"k_hop"
,
10
)
def
forward
(
self
,
graph_wrapper
,
feature
):
for
i
in
range
(
self
.
num_layers
):
feature
=
L
.
dropout
(
feature
,
self
.
dropout
,
dropout_implementation
=
'upscale_in_train'
)
feature
=
L
.
fc
(
feature
,
self
.
hidden_size
,
act
=
"relu"
,
name
=
"lin%s"
%
i
)
feature
=
L
.
dropout
(
feature
,
self
.
dropout
,
dropout_implementation
=
'upscale_in_train'
)
feature
=
L
.
fc
(
feature
,
self
.
num_class
,
act
=
None
,
name
=
"output"
)
feature
=
conv
.
appnp
(
graph_wrapper
,
feature
=
feature
,
norm
=
graph_wrapper
.
node_feat
[
"norm"
],
alpha
=
self
.
alpha
,
k_hop
=
self
.
k_hop
)
return
feature
examples/citation_benchmark/optimization.py
0 → 100644
浏览文件 @
d96c9759
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization and learning rate scheduling."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
paddle.fluid
as
fluid
def
AdamW
(
loss
,
learning_rate
,
train_program
,
weight_decay
):
optimizer
=
fluid
.
optimizer
.
Adam
(
learning_rate
=
learning_rate
)
def
exclude_from_weight_decay
(
name
):
if
name
.
find
(
"layer_norm"
)
>
-
1
:
return
True
bias_suffix
=
[
"_bias"
,
"_b"
,
".b_0"
]
for
suffix
in
bias_suffix
:
if
name
.
endswith
(
suffix
):
return
True
return
False
param_list
=
dict
()
for
param
in
train_program
.
global_block
().
all_parameters
():
param_list
[
param
.
name
]
=
param
*
1.0
param_list
[
param
.
name
].
stop_gradient
=
True
_
,
param_grads
=
optimizer
.
minimize
(
loss
)
if
weight_decay
>
0
:
for
param
,
grad
in
param_grads
:
if
exclude_from_weight_decay
(
param
.
name
):
continue
with
param
.
block
.
program
.
_optimized_guard
(
[
param
,
grad
]),
fluid
.
framework
.
name_scope
(
"weight_decay"
):
updated_param
=
param
-
param_list
[
param
.
name
]
*
weight_decay
*
learning_rate
fluid
.
layers
.
assign
(
output
=
param
,
input
=
updated_param
)
examples/citation_benchmark/train.py
0 → 100644
浏览文件 @
d96c9759
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
pgl
import
model
# import LabelGraphGCN
from
pgl
import
data_loader
from
pgl.utils.logger
import
log
import
paddle.fluid
as
fluid
import
numpy
as
np
import
time
import
argparse
from
build_model
import
build_model
import
yaml
from
easydict
import
EasyDict
as
edict
def
load
(
name
):
if
name
==
'cora'
:
dataset
=
data_loader
.
CoraDataset
()
elif
name
==
"pubmed"
:
dataset
=
data_loader
.
CitationDataset
(
"pubmed"
,
symmetry_edges
=
False
)
elif
name
==
"citeseer"
:
dataset
=
data_loader
.
CitationDataset
(
"citeseer"
,
symmetry_edges
=
False
)
else
:
raise
ValueError
(
name
+
" dataset doesn't exists"
)
return
dataset
def
main
(
args
,
config
):
dataset
=
load
(
args
.
dataset
)
indegree
=
dataset
.
graph
.
indegree
()
norm
=
np
.
zeros_like
(
indegree
,
dtype
=
"float32"
)
norm
[
indegree
>
0
]
=
np
.
power
(
indegree
[
indegree
>
0
],
-
0.5
)
dataset
.
graph
.
node_feat
[
"norm"
]
=
np
.
expand_dims
(
norm
,
-
1
)
place
=
fluid
.
CUDAPlace
(
0
)
if
args
.
use_cuda
else
fluid
.
CPUPlace
()
train_program
=
fluid
.
default_main_program
()
startup_program
=
fluid
.
default_startup_program
()
with
fluid
.
program_guard
(
train_program
,
startup_program
):
with
fluid
.
unique_name
.
guard
():
gw
,
loss
,
acc
=
build_model
(
dataset
,
config
=
config
,
phase
=
"train"
,
main_prog
=
train_program
)
test_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
test_program
,
startup_program
):
with
fluid
.
unique_name
.
guard
():
_gw
,
v_loss
,
v_acc
=
build_model
(
dataset
,
config
=
config
,
phase
=
"test"
,
main_prog
=
test_program
)
test_program
=
test_program
.
clone
(
for_test
=
True
)
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_program
)
train_index
=
dataset
.
train_index
train_label
=
np
.
expand_dims
(
dataset
.
y
[
train_index
],
-
1
)
train_index
=
np
.
expand_dims
(
train_index
,
-
1
)
log
.
info
(
"Number of Train %s"
%
len
(
train_index
))
val_index
=
dataset
.
val_index
val_label
=
np
.
expand_dims
(
dataset
.
y
[
val_index
],
-
1
)
val_index
=
np
.
expand_dims
(
val_index
,
-
1
)
test_index
=
dataset
.
test_index
test_label
=
np
.
expand_dims
(
dataset
.
y
[
test_index
],
-
1
)
test_index
=
np
.
expand_dims
(
test_index
,
-
1
)
dur
=
[]
cal_val_acc
=
[]
cal_test_acc
=
[]
for
epoch
in
range
(
300
):
if
epoch
>=
3
:
t0
=
time
.
time
()
feed_dict
=
gw
.
to_feed
(
dataset
.
graph
)
feed_dict
[
"node_index"
]
=
np
.
array
(
train_index
,
dtype
=
"int64"
)
feed_dict
[
"node_label"
]
=
np
.
array
(
train_label
,
dtype
=
"int64"
)
train_loss
,
train_acc
=
exe
.
run
(
train_program
,
feed
=
feed_dict
,
fetch_list
=
[
loss
,
acc
],
return_numpy
=
True
)
if
epoch
>=
3
:
time_per_epoch
=
1.0
*
(
time
.
time
()
-
t0
)
dur
.
append
(
time_per_epoch
)
feed_dict
=
gw
.
to_feed
(
dataset
.
graph
)
feed_dict
[
"node_index"
]
=
np
.
array
(
val_index
,
dtype
=
"int64"
)
feed_dict
[
"node_label"
]
=
np
.
array
(
val_label
,
dtype
=
"int64"
)
val_loss
,
val_acc
=
exe
.
run
(
test_program
,
feed
=
feed_dict
,
fetch_list
=
[
v_loss
,
v_acc
],
return_numpy
=
True
)
val_loss
=
val_loss
[
0
]
val_acc
=
val_acc
[
0
]
cal_val_acc
.
append
(
val_acc
)
feed_dict
[
"node_index"
]
=
np
.
array
(
test_index
,
dtype
=
"int64"
)
feed_dict
[
"node_label"
]
=
np
.
array
(
test_label
,
dtype
=
"int64"
)
test_loss
,
test_acc
=
exe
.
run
(
test_program
,
feed
=
feed_dict
,
fetch_list
=
[
v_loss
,
v_acc
],
return_numpy
=
True
)
test_loss
=
test_loss
[
0
]
test_acc
=
test_acc
[
0
]
cal_test_acc
.
append
(
test_acc
)
if
epoch
%
10
==
0
:
log
.
info
(
"Epoch %d "
%
epoch
+
"Train Loss: %f "
%
train_loss
+
"Train Acc: %f "
%
train_acc
+
"Val Loss: %f "
%
val_loss
+
"Val Acc: %f "
%
val_acc
+
" Test Loss: %f "
%
test_loss
+
" Test Acc: %f "
%
test_acc
)
cal_val_acc
=
np
.
array
(
cal_val_acc
)
log
.
info
(
"Model: %s Best Test Accuracy: %f"
%
(
config
.
model_name
,
cal_test_acc
[
np
.
argmax
(
cal_val_acc
)]))
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
(
description
=
'Benchmarking Citation Network'
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
"cora"
,
help
=
"dataset (cora, pubmed)"
)
parser
.
add_argument
(
"--use_cuda"
,
action
=
'store_true'
,
help
=
"use_cuda"
)
parser
.
add_argument
(
"--conf"
,
type
=
str
,
help
=
"config file for models"
)
args
=
parser
.
parse_args
()
config
=
edict
(
yaml
.
load
(
open
(
args
.
conf
),
Loader
=
yaml
.
FullLoader
))
log
.
info
(
args
)
main
(
args
,
config
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录