Commit d82d02cc
Authored June 16, 2020 by ZHUI
Parent: dd1cb348

add ogbn-arxiv example

Showing 13 changed files with 1475 additions and 0 deletions (+1475, -0)
ogb_examples/nodeproppred/ogbn-arxiv/args.py                               +50   -0
ogb_examples/nodeproppred/ogbn-arxiv/dataloader/__init__.py                +13   -0
ogb_examples/nodeproppred/ogbn-arxiv/dataloader/base_dataloader.py         +148  -0
ogb_examples/nodeproppred/ogbn-arxiv/dataloader/ogbn_arxiv_dataloader.py   +169  -0
ogb_examples/nodeproppred/ogbn-arxiv/model.py                              +416  -0
ogb_examples/nodeproppred/ogbn-arxiv/monitor/__init__.py                   +14   -0
ogb_examples/nodeproppred/ogbn-arxiv/monitor/train_monitor.py              +213  -0
ogb_examples/nodeproppred/ogbn-arxiv/run.sh                                +20   -0
ogb_examples/nodeproppred/ogbn-arxiv/train.py                              +191  -0
ogb_examples/nodeproppred/ogbn-arxiv/utils/__init__.py                     +14   -0
ogb_examples/nodeproppred/ogbn-arxiv/utils/args.py                         +97   -0
ogb_examples/nodeproppred/ogbn-arxiv/utils/init.py                         +97   -0
ogb_examples/nodeproppred/ogbn-arxiv/utils/to_undirected.py                +33   -0
ogb_examples/nodeproppred/ogbn-arxiv/args.py
0 → 100644
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""finetune args"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
absolute_import
import
os
import
time
import
argparse
from
utils.args
import
ArgumentGroup
# yapf: disable
parser
=
argparse
.
ArgumentParser
(
__doc__
)
model_g
=
ArgumentGroup
(
parser
,
"model"
,
"model configuration and paths."
)
model_g
.
add_arg
(
"init_checkpoint"
,
str
,
None
,
"Init checkpoint to resume training from."
)
model_g
.
add_arg
(
"init_pretraining_params"
,
str
,
None
,
"Init pre-training params which preforms fine-tuning from. If the "
"arg 'init_checkpoint' has been set, this argument wouldn't be valid."
)
train_g
=
ArgumentGroup
(
parser
,
"training"
,
"training options."
)
train_g
.
add_arg
(
"epoch"
,
int
,
3
,
"Number of epoches for fine-tuning."
)
train_g
.
add_arg
(
"learning_rate"
,
float
,
5e-5
,
"Learning rate used to train with warmup."
)
run_type_g
=
ArgumentGroup
(
parser
,
"run_type"
,
"running type options."
)
run_type_g
.
add_arg
(
"use_cuda"
,
bool
,
True
,
"If set, use GPU for training."
)
run_type_g
.
add_arg
(
"num_workers"
,
int
,
4
,
"use multiprocess to generate graph"
)
run_type_g
.
add_arg
(
"output_path"
,
str
,
None
,
"path to save model"
)
run_type_g
.
add_arg
(
"model"
,
str
,
None
,
"model to run"
)
run_type_g
.
add_arg
(
"hidden_size"
,
int
,
256
,
"model hidden-size"
)
run_type_g
.
add_arg
(
"drop_rate"
,
float
,
0.5
,
"Dropout rate"
)
run_type_g
.
add_arg
(
"batch_size"
,
int
,
1024
,
"batch_size"
)
run_type_g
.
add_arg
(
"full_batch"
,
bool
,
False
,
"use static graph wrapper, if full_batch is true, batch_size will take no effect."
)
run_type_g
.
add_arg
(
"samples"
,
type
=
int
,
nargs
=
'+'
,
default
=
[
30
,
30
],
help
=
"sample nums of k-hop."
)
run_type_g
.
add_arg
(
"test_batch_size"
,
int
,
512
,
help
=
"sample nums of k-hop of test phase."
)
run_type_g
.
add_arg
(
"test_samples"
,
type
=
int
,
nargs
=
'+'
,
default
=
[
30
,
30
],
help
=
"sample nums of k-hop."
)
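A quick sketch of how this parser behaves (illustrative values; the flag names come from the definitions above, and boolean flags go through str2bool from utils/args.py):

args = parser.parse_args([
    "--model", "gcn",
    "--use_cuda", "false",
    "--samples", "30", "30",
])
assert args.samples == [30, 30]   # nargs='+' collects ints into a list
assert args.use_cuda is False     # str2bool accepts "false"/"0"/"t"/...
assert args.hidden_size == 256    # defaults fill in everything else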
ogb_examples/nodeproppred/ogbn-arxiv/dataloader/__init__.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ogb_examples/nodeproppred/ogbn-arxiv/dataloader/base_dataloader.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base DataLoader
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
absolute_import
import
os
import
sys
import
six
from
io
import
open
from
collections
import
namedtuple
import
numpy
as
np
import
tqdm
import
paddle
from
pgl.utils
import
mp_reader
import
collections
import
time
import
pgl
if
six
.
PY3
:
import
io
sys
.
stdout
=
io
.
TextIOWrapper
(
sys
.
stdout
.
buffer
,
encoding
=
'utf-8'
)
sys
.
stderr
=
io
.
TextIOWrapper
(
sys
.
stderr
.
buffer
,
encoding
=
'utf-8'
)
def
batch_iter
(
data
,
perm
,
batch_size
,
fid
,
num_workers
):
"""node_batch_iter
"""
size
=
len
(
data
)
start
=
0
cc
=
0
while
start
<
size
:
index
=
perm
[
start
:
start
+
batch_size
]
start
+=
batch_size
cc
+=
1
if
cc
%
num_workers
!=
fid
:
continue
yield
data
[
index
]
def
scan_batch_iter
(
data
,
batch_size
,
fid
,
num_workers
):
"""node_batch_iter
"""
batch
=
[]
cc
=
0
for
line_example
in
data
.
scan
():
cc
+=
1
if
cc
%
num_workers
!=
fid
:
continue
batch
.
append
(
line_example
)
if
len
(
batch
)
==
batch_size
:
yield
batch
batch
=
[]
if
len
(
batch
)
>
0
:
yield
batch
class
BaseDataGenerator
(
object
):
"""Base Data Geneartor"""
def
__init__
(
self
,
buf_size
,
batch_size
,
num_workers
,
shuffle
=
True
):
self
.
num_workers
=
num_workers
self
.
batch_size
=
batch_size
self
.
line_examples
=
[]
self
.
buf_size
=
buf_size
self
.
shuffle
=
shuffle
def
batch_fn
(
self
,
batch_examples
):
""" batch_fn batch producer"""
raise
NotImplementedError
(
"No defined Batch Fn"
)
def
batch_iter
(
self
,
fid
,
perm
):
""" batch iterator"""
if
self
.
shuffle
:
for
batch
in
batch_iter
(
self
,
perm
,
self
.
batch_size
,
fid
,
self
.
num_workers
):
yield
batch
else
:
for
batch
in
scan_batch_iter
(
self
,
self
.
batch_size
,
fid
,
self
.
num_workers
):
yield
batch
def
__len__
(
self
):
return
len
(
self
.
line_examples
)
def
__getitem__
(
self
,
idx
):
if
isinstance
(
idx
,
collections
.
Iterable
):
return
[
self
[
bidx
]
for
bidx
in
idx
]
else
:
return
self
.
line_examples
[
idx
]
def
generator
(
self
):
"""batch dict generator"""
def
worker
(
filter_id
,
perm
):
""" multiprocess worker"""
def
func_run
():
""" func_run """
pid
=
os
.
getpid
()
np
.
random
.
seed
(
pid
+
int
(
time
.
time
()))
for
batch_examples
in
self
.
batch_iter
(
filter_id
,
perm
):
batch_dict
=
self
.
batch_fn
(
batch_examples
)
yield
batch_dict
return
func_run
# consume a seed
np
.
random
.
rand
()
if
self
.
shuffle
:
perm
=
np
.
arange
(
0
,
len
(
self
))
np
.
random
.
shuffle
(
perm
)
else
:
perm
=
None
if
self
.
num_workers
==
1
:
r
=
paddle
.
reader
.
buffered
(
worker
(
0
,
perm
),
self
.
buf_size
)
else
:
worker_pool
=
[
worker
(
wid
,
perm
)
for
wid
in
range
(
self
.
num_workers
)
]
worker
=
mp_reader
.
multiprocess_reader
(
worker_pool
,
use_pipe
=
True
,
queue_size
=
1000
)
r
=
paddle
.
reader
.
buffered
(
worker
,
self
.
buf_size
)
for
batch
in
r
():
yield
batch
def
scan
(
self
):
for
line_example
in
self
.
line_examples
:
yield
line_example
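A subclass only has to fill self.line_examples and implement batch_fn; sharding across workers, shuffling, and buffering all come from the base class. A minimal sketch with a hypothetical toy generator:

import numpy as np
from collections import namedtuple

Pair = namedtuple('Pair', ['x', 'y'])

class ToyDataGenerator(BaseDataGenerator):
    """Hypothetical subclass, for illustration only."""

    def __init__(self, xs, ys, **kwargs):
        super(ToyDataGenerator, self).__init__(**kwargs)
        # line_examples drives __len__, __getitem__ and both batch iterators.
        self.line_examples = [Pair(x=x, y=y) for x, y in zip(xs, ys)]

    def batch_fn(self, batch_examples):
        # Called once per batch, inside each worker process.
        return {"x": np.array([ex.x for ex in batch_examples], dtype="float32"),
                "y": np.array([ex.y for ex in batch_examples], dtype="int64")}

gen = ToyDataGenerator(list(range(10)), list(range(10)),
                       buf_size=100, batch_size=4, num_workers=1, shuffle=True)
for feed_dict in gen.generator():
    print(feed_dict["x"].shape)  # (4,) for full batches, (2,) for the tail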
ogb_examples/nodeproppred/ogbn-arxiv/dataloader/ogbn_arxiv_dataloader.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from dataloader.base_dataloader import BaseDataGenerator
from utils.to_undirected import to_undirected

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
#from pgl.sample import graph_saint_random_walk_sample
from ogb.nodeproppred import Evaluator

import tqdm
from collections import namedtuple
import pgl
import numpy as np
import copy

"""
dict_keys(['edge_index', 'edge_feat', 'node_feat', 'node_year', 'num_nodes'])
edge_index shape: (2, 1166243)
edge_index type: <class 'numpy.ndarray'>
[[104447  15858 107156 ...  45118  45118  45118]
 [ 13091  47283  69161 ... 162473 162537  72717]]
edge_feat: None
node_feat shape: (169343, 128)
node_year shape: (169343, 1)
num_nodes: 169343
label shape: (169343, 1)
"""


def traverse(item):
    """traverse
    """
    if isinstance(item, list) or isinstance(item, np.ndarray):
        for i in iter(item):
            for j in traverse(i):
                yield j
    else:
        yield item


def flat_node_and_edge(nodes):
    """flat_node_and_edge
    """
    nodes = list(set(traverse(nodes)))
    return nodes


def k_hop_sampler(graph, samples, batch_nodes):
    # for batch_train_samples, batch_train_labels in batch_info:
    start_nodes = copy.deepcopy(batch_nodes)
    nodes = start_nodes
    edges = []
    for max_deg in samples:
        pred_nodes = graph.sample_predecessor(start_nodes, max_degree=max_deg)

        for dst_node, src_nodes in zip(start_nodes, pred_nodes):
            for src_node in src_nodes:
                edges.append((src_node, dst_node))

        last_nodes = nodes
        nodes = [nodes, pred_nodes]
        nodes = flat_node_and_edge(nodes)
        # Find new nodes
        start_nodes = list(set(nodes) - set(last_nodes))
        if len(start_nodes) == 0:
            break
    subgraph = graph.subgraph(
        nodes=nodes, edges=edges, with_node_feat=True, with_edge_feat=True)
    sub_node_index = subgraph.reindex_from_parrent_nodes(batch_nodes)

    return subgraph, sub_node_index


#def graph_saint_randomwalk_sampler(graph, batch_nodes, max_depth=3):
#    subgraph = graph_saint_random_walk_sample(graph, batch_nodes, max_depth)
#    sub_node_index = subgraph.reindex_from_parrent_nodes(batch_nodes)
#    return subgraph, sub_node_index


class ArxivDataGenerator(BaseDataGenerator):
    def __init__(self,
                 graph_wrapper=None,
                 buf_size=1000,
                 batch_size=128,
                 num_workers=1,
                 samples=[30, 30],
                 shuffle=True,
                 phase="train"):
        super(ArxivDataGenerator, self).__init__(
            buf_size=buf_size,
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=shuffle)
        self.samples = samples
        self.d_name = "ogbn-arxiv"
        self.graph_wrapper = graph_wrapper
        dataset = PglNodePropPredDataset(name=self.d_name)
        splitted_idx = dataset.get_idx_split()
        self.phase = phase
        graph, label = dataset[0]
        graph = to_undirected(graph)
        self.graph = graph
        self.num_nodes = graph.num_nodes
        if self.phase == 'train':
            nodes_idx = splitted_idx["train"]
            labels = label[nodes_idx]
        elif self.phase == "valid":
            nodes_idx = splitted_idx["valid"]
            labels = label[nodes_idx]
        elif self.phase == "test":
            nodes_idx = splitted_idx["test"]
            labels = label[nodes_idx]
        self.nodes_idx = nodes_idx
        self.labels = labels
        #self.static_gw_based_line_example(nodes_idx, labels)
        self.sample_based_line_example(nodes_idx, labels)

    def sample_based_line_example(self, nodes_idx, labels):
        self.line_examples = []
        Example = namedtuple('Example', ["node", "label"])
        for node, label in zip(nodes_idx, labels):
            self.line_examples.append(Example(node=node, label=label))
        print("Phase", self.phase)
        print("Len Examples", len(self.line_examples))

    def batch_fn2(self, batch_ex):
        feed_dict = {}
        feed_dict["batch_nodes"] = np.array(batch_ex[0]['node'], dtype="int64")
        feed_dict["labels"] = np.array(batch_ex[0]['label'], dtype="int64")
        return feed_dict

    def batch_fn(self, batch_ex):
        batch_nodes = []
        cc = 0
        batch_node_id = []
        batch_labels = []
        for ex in batch_ex:
            batch_nodes.append(ex.node)
            batch_labels.append(ex.label)

        _graph_wrapper = copy.copy(self.graph_wrapper)
        #if self.phase == "train":
        #    subgraph, sub_node_index = graph_saint_randomwalk_sampler(self.graph, batch_nodes)
        #else:
        subgraph, sub_node_index = k_hop_sampler(self.graph, self.samples,
                                                 batch_nodes)

        feed_dict = _graph_wrapper.to_feed(subgraph)
        feed_dict["batch_nodes"] = sub_node_index
        feed_dict["labels"] = np.array(batch_labels, dtype="int64")
        return feed_dict
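In this example the generator is always driven from train.py with a GraphWrapper that matches the model; a hedged sketch of standalone use (it assumes the ogb package can fetch ogbn-arxiv and that gw is a pgl GraphWrapper built over the same graph):

train_gen = ArxivDataGenerator(phase="train", graph_wrapper=gw,
                               batch_size=1024, samples=[30, 30])
for feed_dict in train_gen.generator():
    # feed_dict carries the sampled subgraph tensors from gw.to_feed(...)
    # plus "batch_nodes" (re-indexed seed nodes) and "labels".
    break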
ogb_examples/nodeproppred/ogbn-arxiv/model.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# encoding=utf-8
"""lbs_model"""
import os
import re
import time
from random import random
from functools import reduce, partial

import numpy as np
import multiprocessing

import paddle
import paddle.fluid as F
import paddle.fluid as fluid
import paddle.fluid.layers as L

from pgl.graph_wrapper import GraphWrapper
from pgl.layers.conv import gcn, gat, gin  # gin is used by GinModel.forward below
from pgl.utils import paddle_helper


class BaseGraph(object):
    """Base Graph Model"""

    def __init__(self, args, graph_wrapper=None):
        self.hidden_size = args.hidden_size
        self.num_nodes = args.num_nodes
        self.drop_rate = args.drop_rate
        node_feature = [('feat', [None, 128], "float32")]
        if graph_wrapper is None:
            self.graph_wrapper = GraphWrapper(
                name="graph", place=F.CPUPlace(), node_feat=node_feature)
        else:
            self.graph_wrapper = graph_wrapper
        self.build_model(args)

    def build_model(self, args):
        """build graph model"""
        self.batch_nodes = L.data(name="batch_nodes", shape=[-1], dtype="int64")
        self.labels = L.data(name="labels", shape=[-1], dtype="int64")
        self.batch_nodes = L.reshape(self.batch_nodes, [-1, 1])
        self.labels = L.reshape(self.labels, [-1, 1])
        self.batch_nodes.stop_gradients = True
        self.labels.stop_gradients = True
        feat = self.graph_wrapper.node_feat['feat']

        if self.graph_wrapper is not None:
            feat = self.neighbor_aggregator(feat)
        assert feat is not None
        feat = L.gather(feat, self.batch_nodes)
        self.logits = L.fc(feat, size=40, act=None, name="node_predictor_logits")
        # note: self.loss() rebinds self.loss from a method to the loss Variable
        self.loss()

    def mlp(self, feat):
        for i in range(3):
            feat = L.fc(feat, size=self.hidden_size, name="simple_mlp_{}".format(i))
            feat = L.batch_norm(feat)
            feat = L.relu(feat)
            feat = L.dropout(feat, dropout_prob=0.5)
        return feat

    def loss(self):
        self.loss = L.softmax_with_cross_entropy(self.logits, self.labels)
        self.loss = L.reduce_mean(self.loss)
        self.metrics = {"loss": self.loss, }

    def neighbor_aggregator(self, feature):
        """neighbor aggregation"""
        raise NotImplementedError(
            "Please implement this method when you are using a graph wrapper for GNNs.")


class MLPModel(BaseGraph):
    def __init__(self, args, gw):
        super(MLPModel, self).__init__(args, gw)

    def neighbor_aggregator(self, feature):
        for i in range(3):
            feature = L.fc(feature, size=self.hidden_size,
                           name="simple_mlp_{}".format(i))
            #feature = L.batch_norm(feature)
            feature = L.relu(feature)
            feature = L.dropout(feature, dropout_prob=self.drop_rate)
        return feature


class SAGEModel(BaseGraph):
    def __init__(self, args, gw):
        super(SAGEModel, self).__init__(args, gw)

    def neighbor_aggregator(self, feature):
        sage = GraphSageModel(40, 3, 256)
        feature = sage.forward(self.graph_wrapper, feature, self.drop_rate)
        return feature


class GAANModel(BaseGraph):
    def __init__(self, args, gw):
        super(GAANModel, self).__init__(args, gw)

    def neighbor_aggregator(self, feature):
        gaan = GaANModel(
            40, 3,
            hidden_size_a=48,
            hidden_size_v=64,
            hidden_size_m=128,
            hidden_size_o=256)
        feature = gaan.forward(self.graph_wrapper, feature, self.drop_rate)
        return feature


class GINModel(BaseGraph):
    def __init__(self, args, gw):
        super(GINModel, self).__init__(args, gw)

    def neighbor_aggregator(self, feature):
        gin_model = GinModel(40, 2, 256)
        # GinModel.forward takes no dropout rate.
        feature = gin_model.forward(self.graph_wrapper, feature)
        return feature


class GATModel(BaseGraph):
    def __init__(self, args, gw):
        super(GATModel, self).__init__(args, gw)

    def neighbor_aggregator(self, feature):
        feature = gat(self.graph_wrapper,
                      feature,
                      hidden_size=self.hidden_size,
                      activation='relu',
                      name="GAT_1")
        feature = gat(self.graph_wrapper,
                      feature,
                      hidden_size=self.hidden_size,
                      activation='relu',
                      name="GAT_2")
        return feature


class GCNModel(BaseGraph):
    def __init__(self, args, gw):
        super(GCNModel, self).__init__(args, gw)

    def neighbor_aggregator(self, feature):
        feature = gcn(self.graph_wrapper,
                      feature,
                      hidden_size=self.hidden_size,
                      activation='relu',
                      name="GCN_1")
        feature = fluid.layers.dropout(feature, dropout_prob=self.drop_rate)
        feature = gcn(self.graph_wrapper,
                      feature,
                      hidden_size=self.hidden_size,
                      activation='relu',
                      name="GCN_2")
        feature = fluid.layers.dropout(feature, dropout_prob=self.drop_rate)
        return feature


class GinModel(object):
    def __init__(self, num_class, num_layers, hidden_size, act='relu',
                 name="GINModel"):
        self.num_class = num_class
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.act = act
        self.name = name

    def forward(self, gw, feature):
        for i in range(self.num_layers):
            feature = gin(gw, feature, self.hidden_size, self.act,
                          self.name + '_' + str(i))
            feature = fluid.layers.layer_norm(
                feature,
                begin_norm_axis=1,
                param_attr=fluid.ParamAttr(
                    name="norm_scale_%s" % (i),
                    initializer=fluid.initializer.Constant(1.0)),
                bias_attr=fluid.ParamAttr(
                    name="norm_bias_%s" % (i),
                    initializer=fluid.initializer.Constant(0.0)))
            feature = fluid.layers.relu(feature)
        return feature


class GaANModel(object):
    def __init__(self, num_class, num_layers, hidden_size_a=24,
                 hidden_size_v=32, hidden_size_m=64, hidden_size_o=128,
                 heads=8, act='relu', name="GaAN"):
        self.num_class = num_class
        self.num_layers = num_layers
        self.hidden_size_a = hidden_size_a
        self.hidden_size_v = hidden_size_v
        self.hidden_size_m = hidden_size_m
        self.hidden_size_o = hidden_size_o
        self.act = act
        self.name = name
        self.heads = heads

    def GaANConv(self, gw, feature, name):
        feat_key = fluid.layers.fc(
            feature,
            self.hidden_size_a * self.heads,
            bias_attr=False,
            param_attr=fluid.ParamAttr(name=name + '_project_key'))
        # N * (D2 * M)
        feat_value = fluid.layers.fc(
            feature,
            self.hidden_size_v * self.heads,
            bias_attr=False,
            param_attr=fluid.ParamAttr(name=name + '_project_value'))
        # N * (D1 * M)
        feat_query = fluid.layers.fc(
            feature,
            self.hidden_size_a * self.heads,
            bias_attr=False,
            param_attr=fluid.ParamAttr(name=name + '_project_query'))
        # N * Dm
        feat_gate = fluid.layers.fc(
            feature,
            self.hidden_size_m,
            bias_attr=False,
            param_attr=fluid.ParamAttr(name=name + '_project_gate'))

        # send
        message = gw.send(
            self.send_func,
            nfeat_list=[('node_feat', feature), ('feat_key', feat_key),
                        ('feat_value', feat_value), ('feat_query', feat_query),
                        ('feat_gate', feat_gate)],
            efeat_list=None)

        # recv
        output = gw.recv(message, self.recv_func)
        output = fluid.layers.fc(
            output,
            self.hidden_size_o,
            bias_attr=False,
            param_attr=fluid.ParamAttr(name=name + '_project_output'))
        output = fluid.layers.leaky_relu(output, alpha=0.1)
        output = fluid.layers.dropout(output, dropout_prob=0.1)
        return output

    def forward(self, gw, feature, drop_rate):
        for i in range(self.num_layers):
            feature = self.GaANConv(gw, feature, self.name + '_' + str(i))
            feature = fluid.layers.dropout(feature, dropout_prob=drop_rate)
        return feature

    def send_func(self, src_feat, dst_feat, edge_feat):
        # E * (M * D1)
        feat_query, feat_key = dst_feat['feat_query'], src_feat['feat_key']
        # E * M * D1
        old = feat_query
        feat_query = fluid.layers.reshape(feat_query,
                                          [-1, self.heads, self.hidden_size_a])
        feat_key = fluid.layers.reshape(feat_key,
                                        [-1, self.heads, self.hidden_size_a])
        # E * M
        alpha = fluid.layers.reduce_sum(feat_key * feat_query, dim=-1)

        return {'dst_node_feat': dst_feat['node_feat'],
                'src_node_feat': src_feat['node_feat'],
                'feat_value': src_feat['feat_value'],
                'alpha': alpha,
                'feat_gate': src_feat['feat_gate']}

    def recv_func(self, message):
        dst_feat = message['dst_node_feat']
        src_feat = message['src_node_feat']
        x = fluid.layers.sequence_pool(dst_feat, 'average')
        z = fluid.layers.sequence_pool(src_feat, 'average')

        feat_gate = message['feat_gate']
        g_max = fluid.layers.sequence_pool(feat_gate, 'max')

        g = fluid.layers.concat([x, g_max, z], axis=1)
        g = fluid.layers.fc(g, self.heads, bias_attr=False, act="sigmoid")

        # softmax
        alpha = message['alpha']
        alpha = paddle_helper.sequence_softmax(alpha)  # E * M

        feat_value = message['feat_value']  # E * (M * D2)
        old = feat_value
        feat_value = fluid.layers.reshape(
            feat_value, [-1, self.heads, self.hidden_size_v])  # E * M * D2
        feat_value = fluid.layers.elementwise_mul(feat_value, alpha, axis=0)
        feat_value = fluid.layers.reshape(
            feat_value, [-1, self.heads * self.hidden_size_v])  # E * (M * D2)

        feat_value = fluid.layers.lod_reset(feat_value, old)
        feat_value = fluid.layers.sequence_pool(feat_value, 'sum')  # N * (M * D2)

        feat_value = fluid.layers.reshape(
            feat_value, [-1, self.heads, self.hidden_size_v])  # N * M * D2
        output = fluid.layers.elementwise_mul(feat_value, g, axis=0)
        output = fluid.layers.reshape(
            output, [-1, self.heads * self.hidden_size_v])  # N * (M * D2)
        output = fluid.layers.concat([x, output], axis=1)

        return output


class GraphSageModel(object):
    def __init__(self, num_class, num_layers, hidden_size, act='relu',
                 name="GraphSage"):
        self.num_class = num_class
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.act = act
        self.name = name

    def GraphSageConv(self, gw, feature, name):
        message = gw.send(
            self.send_func,
            nfeat_list=[('node_feat', feature)],
            efeat_list=None)
        neighbor_feat = gw.recv(message, self.recv_func)
        neighbor_feat = fluid.layers.fc(neighbor_feat, self.hidden_size,
                                        act=self.act, name=name + '_n')
        self_feature = fluid.layers.fc(feature, self.hidden_size,
                                       act=self.act, name=name + '_s')
        output = self_feature + neighbor_feat
        output = fluid.layers.l2_normalize(output, axis=1)
        return output

    def SageConv(self, gw, feature, name, hidden_size, act):
        message = gw.send(
            self.send_func,
            nfeat_list=[('node_feat', feature)],
            efeat_list=None)
        neighbor_feat = gw.recv(message, self.recv_func)
        neighbor_feat = fluid.layers.fc(neighbor_feat, hidden_size,
                                        act=None, name=name + '_n')
        self_feature = fluid.layers.fc(feature, hidden_size,
                                       act=None, name=name + '_s')
        output = self_feature + neighbor_feat
        # output = fluid.layers.concat([self_feature, neighbor_feat], axis=1)
        output = fluid.layers.l2_normalize(output, axis=1)
        if act is not None:
            output = L.relu(output)
        return output

    def bn_drop(self, feat, drop_rate):
        #feat = L.batch_norm(feat)
        feat = L.dropout(feat, dropout_prob=drop_rate)
        return feat

    def forward(self, gw, feature, drop_rate):
        for i in range(self.num_layers):
            final = (i == (self.num_layers - 1))
            feature = self.SageConv(gw, feature, self.name + '_' + str(i),
                                    self.hidden_size,
                                    None if final else self.act)
            if not final:
                feature = self.bn_drop(feature, drop_rate)
        return feature

    def send_func(self, src_feat, dst_feat, edge_feat):
        return src_feat["node_feat"]

    def recv_func(self, feat):
        return fluid.layers.sequence_pool(feat, pool_type="average")
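Each concrete class only overrides neighbor_aggregator; BaseGraph supplies the inputs, the gather over batch_nodes, the 40-way classifier, and the loss. A sketch of a hypothetical single-layer GCN variant:

class OneLayerGCN(BaseGraph):
    """Hypothetical variant, for illustration only."""

    def neighbor_aggregator(self, feature):
        # BaseGraph.build_model gathers batch_nodes from whatever is returned
        # here and stacks the final fc + softmax_with_cross_entropy on top.
        return gcn(self.graph_wrapper, feature,
                   hidden_size=self.hidden_size,
                   activation='relu', name="GCN_only")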
ogb_examples/nodeproppred/ogbn-arxiv/monitor/__init__.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""init"""
ogb_examples/nodeproppred/ogbn-arxiv/monitor/train_monitor.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""train and evaluate"""
import
tqdm
import
json
import
numpy
as
np
import
sys
import
os
import
paddle.fluid
as
F
from
tensorboardX
import
SummaryWriter
from
ogb.nodeproppred
import
Evaluator
from
ogb.nodeproppred
import
NodePropPredDataset
def
multi_device
(
reader
,
dev_count
):
"""multi device"""
if
dev_count
==
1
:
for
batch
in
reader
:
yield
batch
else
:
batches
=
[]
for
batch
in
reader
:
batches
.
append
(
batch
)
if
len
(
batches
)
==
dev_count
:
yield
batches
batches
=
[]
class
OgbEvaluator
(
object
):
def
__init__
(
self
):
d_name
=
"ogbn-arxiv"
dataset
=
NodePropPredDataset
(
name
=
d_name
)
graph
,
label
=
dataset
[
0
]
self
.
num_nodes
=
graph
[
"num_nodes"
]
self
.
ogb_evaluator
=
Evaluator
(
name
=
"ogbn-arxiv"
)
def
eval
(
self
,
scores
,
labels
,
phase
):
pred
=
(
np
.
argmax
(
scores
,
axis
=
1
)).
reshape
([
-
1
,
1
])
ret
=
{}
ret
[
'%s_acc'
%
(
phase
)]
=
self
.
ogb_evaluator
.
eval
({
'y_true'
:
labels
,
'y_pred'
:
pred
,
})[
'acc'
]
return
ret
def
evaluate
(
model
,
valid_exe
,
valid_ds
,
valid_prog
,
dev_count
,
evaluator
,
phase
,
full_batch
):
"""evaluate """
cc
=
0
scores
=
[]
labels
=
[]
if
full_batch
:
valid_iter
=
_full_batch_wapper
(
valid_ds
)
else
:
valid_iter
=
valid_ds
.
generator
for
feed_dict
in
tqdm
.
tqdm
(
multi_device
(
valid_iter
(),
dev_count
),
desc
=
'evaluating'
):
if
dev_count
>
1
:
output
=
valid_exe
.
run
(
feed
=
feed_dict
,
fetch_list
=
[
model
.
logits
,
model
.
labels
])
else
:
output
=
valid_exe
.
run
(
valid_prog
,
feed
=
feed_dict
,
fetch_list
=
[
model
.
logits
,
model
.
labels
])
scores
.
append
(
output
[
0
])
labels
.
append
(
output
[
1
])
scores
=
np
.
vstack
(
scores
)
labels
=
np
.
vstack
(
labels
)
ret
=
evaluator
.
eval
(
scores
,
labels
,
phase
)
return
ret
def
_create_if_not_exist
(
path
):
basedir
=
os
.
path
.
dirname
(
path
)
if
not
os
.
path
.
exists
(
basedir
):
os
.
makedirs
(
basedir
)
def
_full_batch_wapper
(
ds
):
feed_dict
=
{}
feed_dict
[
"batch_nodes"
]
=
np
.
array
(
ds
.
nodes_idx
,
dtype
=
"int64"
)
feed_dict
[
"labels"
]
=
np
.
array
(
ds
.
labels
,
dtype
=
"int64"
)
def
r
():
yield
feed_dict
return
r
def
train_and_evaluate
(
exe
,
train_exe
,
valid_exe
,
train_ds
,
valid_ds
,
test_ds
,
train_prog
,
valid_prog
,
full_batch
,
model
,
metric
,
epoch
=
20
,
dev_count
=
1
,
train_log_step
=
5
,
eval_step
=
10000
,
evaluator
=
None
,
output_path
=
None
):
"""train and evaluate"""
global_step
=
0
log_path
=
os
.
path
.
join
(
output_path
,
"log"
)
_create_if_not_exist
(
log_path
)
writer
=
SummaryWriter
(
log_path
)
best_model
=
0
if
full_batch
:
train_iter
=
_full_batch_wapper
(
train_ds
)
else
:
train_iter
=
train_ds
.
generator
for
e
in
range
(
epoch
):
ret_sum_loss
=
0
per_step
=
0
scores
=
[]
labels
=
[]
for
feed_dict
in
tqdm
.
tqdm
(
multi_device
(
train_iter
(),
dev_count
),
desc
=
'Epoch %s'
%
e
):
if
dev_count
>
1
:
ret
=
train_exe
.
run
(
feed
=
feed_dict
,
fetch_list
=
metric
.
vars
)
ret
=
[[
np
.
mean
(
v
)]
for
v
in
ret
]
else
:
ret
=
train_exe
.
run
(
train_prog
,
feed
=
feed_dict
,
fetch_list
=
[
model
.
loss
,
model
.
logits
,
model
.
labels
]
#fetch_list=metric.vars
)
scores
.
append
(
ret
[
1
])
labels
.
append
(
ret
[
2
])
ret
=
[
ret
[
0
]]
ret
=
metric
.
parse
(
ret
)
if
global_step
%
train_log_step
==
0
:
for
key
,
value
in
ret
.
items
():
writer
.
add_scalar
(
'train_'
+
key
,
value
,
global_step
=
global_step
)
ret_sum_loss
+=
ret
[
'loss'
]
per_step
+=
1
global_step
+=
1
if
global_step
%
eval_step
==
0
:
eval_ret
=
evaluate
(
model
,
exe
,
valid_ds
,
valid_prog
,
1
,
evaluator
,
"valid"
,
full_batch
)
test_eval_ret
=
evaluate
(
model
,
exe
,
test_ds
,
valid_prog
,
1
,
evaluator
,
"test"
,
full_batch
)
eval_ret
.
update
(
test_eval_ret
)
sys
.
stderr
.
write
(
json
.
dumps
(
eval_ret
,
indent
=
4
)
+
"
\n
"
)
for
key
,
value
in
eval_ret
.
items
():
writer
.
add_scalar
(
key
,
value
,
global_step
=
global_step
)
if
eval_ret
[
"valid_acc"
]
>
best_model
:
F
.
io
.
save_persistables
(
exe
,
os
.
path
.
join
(
output_path
,
"checkpoint"
),
train_prog
)
eval_ret
[
"epoch"
]
=
e
#eval_ret["step"] = global_step
with
open
(
os
.
path
.
join
(
output_path
,
"best.txt"
),
"w"
)
as
f
:
f
.
write
(
json
.
dumps
(
eval_ret
,
indent
=
2
)
+
'
\n
'
)
best_model
=
eval_ret
[
"valid_acc"
]
scores
=
np
.
vstack
(
scores
)
labels
=
np
.
vstack
(
labels
)
ret
=
evaluator
.
eval
(
scores
,
labels
,
"train"
)
sys
.
stderr
.
write
(
json
.
dumps
(
ret
,
indent
=
4
)
+
"
\n
"
)
#print(json.dumps(ret, indent=4) + "\n")
# Epoch End
sys
.
stderr
.
write
(
"epoch:{}, average loss {}
\n
"
.
format
(
e
,
ret_sum_loss
/
per_step
))
eval_ret
=
evaluate
(
model
,
exe
,
valid_ds
,
valid_prog
,
1
,
evaluator
,
"valid"
,
full_batch
)
test_eval_ret
=
evaluate
(
model
,
exe
,
test_ds
,
valid_prog
,
1
,
evaluator
,
"test"
,
full_batch
)
eval_ret
.
update
(
test_eval_ret
)
sys
.
stderr
.
write
(
json
.
dumps
(
eval_ret
,
indent
=
4
)
+
"
\n
"
)
for
key
,
value
in
eval_ret
.
items
():
writer
.
add_scalar
(
key
,
value
,
global_step
=
global_step
)
if
eval_ret
[
"valid_acc"
]
>
best_model
:
F
.
io
.
save_persistables
(
exe
,
os
.
path
.
join
(
output_path
,
"checkpoint"
),
train_prog
)
#eval_ret["step"] = global_step
eval_ret
[
"epoch"
]
=
e
with
open
(
os
.
path
.
join
(
output_path
,
"best.txt"
),
"w"
)
as
f
:
f
.
write
(
json
.
dumps
(
eval_ret
,
indent
=
2
)
+
'
\n
'
)
best_model
=
eval_ret
[
"valid_acc"
]
writer
.
close
()
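multi_device regroups a flat feed stream into lists of dev_count feed dicts (dropping any incomplete tail group); a quick check:

feeds = [{"i": i} for i in range(4)]
assert list(multi_device(iter(feeds), 1)) == feeds
assert list(multi_device(iter(feeds), 2)) == [[{"i": 0}, {"i": 1}],
                                              [{"i": 2}, {"i": 3}]]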
ogb_examples/nodeproppred/ogbn-arxiv/run.sh
0 → 100644
device=0
model='gaan'
lr=0.001
drop=0.5

CUDA_VISIBLE_DEVICES=${device} \
python -u train.py \
    --use_cuda 1 \
    --num_workers 4 \
    --output_path ./output/model \
    --batch_size 1024 \
    --test_batch_size 512 \
    --epoch 100 \
    --learning_rate ${lr} \
    --full_batch 0 \
    --model ${model} \
    --drop_rate ${drop} \
    --samples 8 8 8 \
    --test_samples 20 20 20 \
    --hidden_size 256
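As configured, run.sh trains the GaAN variant on GPU 0 with three-hop neighbor sampling (8 neighbors per hop for training, 20 for evaluation); switching architectures only requires editing the model variable, since train.py dispatches on --model (mlp, sage, gat, gcn, gaan, or gin).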
ogb_examples/nodeproppred/ogbn-arxiv/train.py
0 → 100644
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""listwise model
"""
import
torch
import
os
import
re
import
time
import
logging
from
random
import
random
from
functools
import
reduce
,
partial
# For downloading ogb
import
ssl
ssl
.
_create_default_https_context
=
ssl
.
_create_unverified_context
# SSL
import
numpy
as
np
import
multiprocessing
import
pgl
import
paddle
import
paddle.fluid
as
F
import
paddle.fluid.layers
as
L
from
args
import
parser
from
utils.args
import
print_arguments
,
check_cuda
from
utils.init
import
init_checkpoint
,
init_pretraining_params
from
utils.to_undirected
import
to_undirected
from
model
import
BaseGraph
,
MLPModel
,
SAGEModel
,
GAANModel
,
GATModel
,
GCNModel
,
GINModel
from
dataloader.ogbn_arxiv_dataloader
import
ArxivDataGenerator
from
monitor.train_monitor
import
train_and_evaluate
,
OgbEvaluator
from
pgl.contrib.ogb.nodeproppred.dataset_pgl
import
PglNodePropPredDataset
log
=
logging
.
getLogger
(
__name__
)
class
Metric
(
object
):
"""Metric"""
def
__init__
(
self
,
**
args
):
self
.
args
=
args
@
property
def
vars
(
self
):
""" fetch metric vars"""
values
=
[
self
.
args
[
k
]
for
k
in
self
.
args
.
keys
()]
return
values
def
parse
(
self
,
fetch_list
):
"""parse"""
tup
=
list
(
zip
(
self
.
args
.
keys
(),
[
float
(
v
[
0
])
for
v
in
fetch_list
]))
return
dict
(
tup
)
if
__name__
==
'__main__'
:
args
=
parser
.
parse_args
()
print_arguments
(
args
)
evaluator
=
OgbEvaluator
()
train_prog
=
F
.
Program
()
startup_prog
=
F
.
Program
()
args
.
num_nodes
=
evaluator
.
num_nodes
if
args
.
use_cuda
:
dev_list
=
F
.
cuda_places
()
place
=
dev_list
[
0
]
dev_count
=
len
(
dev_list
)
else
:
place
=
F
.
CPUPlace
()
dev_count
=
int
(
os
.
environ
.
get
(
'CPU_NUM'
,
multiprocessing
.
cpu_count
()))
assert
dev_count
==
1
,
"The program not support multi devices now!"
dataset
=
PglNodePropPredDataset
(
name
=
"ogbn-arxiv"
)
graph
,
label
=
dataset
[
0
]
graph
=
to_undirected
(
graph
)
if
args
.
model
is
None
:
Model
=
BaseGraph
elif
args
.
model
.
upper
()
==
"MLP"
:
Model
=
MLPModel
elif
args
.
model
.
upper
()
==
"SAGE"
:
Model
=
SAGEModel
elif
args
.
model
.
upper
()
==
"GAT"
:
Model
=
GATModel
elif
args
.
model
.
upper
()
==
"GCN"
:
Model
=
GCNModel
elif
args
.
model
.
upper
()
==
"GAAN"
:
Model
=
GAANModel
elif
args
.
model
.
upper
()
==
"GIN"
:
Model
=
GINModel
else
:
raise
ValueError
(
"Not support {} model!"
.
format
(
args
.
model
))
with
F
.
program_guard
(
train_prog
,
startup_prog
):
with
F
.
unique_name
.
guard
():
if
args
.
full_batch
:
gw
=
pgl
.
graph_wrapper
.
StaticGraphWrapper
(
name
=
"graph"
,
graph
=
graph
,
place
=
place
)
else
:
gw
=
pgl
.
graph_wrapper
.
GraphWrapper
(
name
=
"graph"
,
node_feat
=
graph
.
node_feat_info
(),
edge_feat
=
graph
.
edge_feat_info
())
log
.
info
(
gw
.
node_feat
.
keys
())
graph_model
=
Model
(
args
,
gw
)
test_prog
=
train_prog
.
clone
(
for_test
=
True
)
opt
=
F
.
optimizer
.
Adam
(
learning_rate
=
args
.
learning_rate
)
opt
.
minimize
(
graph_model
.
loss
)
train_ds
=
ArxivDataGenerator
(
phase
=
"train"
,
graph_wrapper
=
graph_model
.
graph_wrapper
,
num_workers
=
args
.
num_workers
,
batch_size
=
args
.
batch_size
,
samples
=
args
.
samples
)
valid_ds
=
ArxivDataGenerator
(
phase
=
"valid"
,
graph_wrapper
=
graph_model
.
graph_wrapper
,
num_workers
=
args
.
num_workers
,
batch_size
=
args
.
test_batch_size
,
samples
=
args
.
test_samples
)
test_ds
=
ArxivDataGenerator
(
phase
=
"test"
,
graph_wrapper
=
graph_model
.
graph_wrapper
,
num_workers
=
args
.
num_workers
,
batch_size
=
args
.
test_batch_size
,
samples
=
args
.
test_samples
)
exe
=
F
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
if
args
.
full_batch
:
gw
.
initialize
(
place
)
if
args
.
init_pretraining_params
is
not
None
:
init_pretraining_params
(
exe
,
args
.
init_pretraining_params
,
main_program
=
startup_prog
)
metric
=
Metric
(
**
graph_model
.
metrics
)
nccl2_num_trainers
=
1
nccl2_trainer_id
=
0
if
dev_count
>
1
:
exec_strategy
=
F
.
ExecutionStrategy
()
exec_strategy
.
num_threads
=
dev_count
train_exe
=
F
.
ParallelExecutor
(
use_cuda
=
args
.
use_cuda
,
loss_name
=
graph_model
.
loss
.
name
,
exec_strategy
=
exec_strategy
,
main_program
=
train_prog
,
num_trainers
=
nccl2_num_trainers
,
trainer_id
=
nccl2_trainer_id
)
test_exe
=
exe
else
:
train_exe
,
test_exe
=
exe
,
exe
train_and_evaluate
(
exe
=
exe
,
train_exe
=
train_exe
,
valid_exe
=
test_exe
,
train_ds
=
train_ds
,
valid_ds
=
valid_ds
,
test_ds
=
test_ds
,
train_prog
=
train_prog
,
valid_prog
=
test_prog
,
full_batch
=
args
.
full_batch
,
train_log_step
=
5
,
output_path
=
args
.
output_path
,
dev_count
=
dev_count
,
model
=
graph_model
,
epoch
=
args
.
epoch
,
eval_step
=
1000000
,
evaluator
=
evaluator
,
metric
=
metric
)
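Metric simply pairs fetched values back with the metric names collected from graph_model.metrics; a quick standalone check (the placeholder value stands in for a fluid Variable):

m = Metric(loss=None)            # the real code passes the loss Variable here
assert m.parse([[0.693]]) == {"loss": 0.693}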
ogb_examples/nodeproppred/ogbn-arxiv/utils/__init__.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""utils"""
ogb_examples/nodeproppred/ogbn-arxiv/utils/args.py
0 → 100644
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Arguments for configuration."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
absolute_import
import
six
import
os
import
sys
import
argparse
import
logging
import
paddle.fluid
as
fluid
log
=
logging
.
getLogger
(
__name__
)
def
prepare_logger
(
logger
,
debug
=
False
,
save_to_file
=
None
):
"""doc"""
formatter
=
logging
.
Formatter
(
fmt
=
'[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:
\t
%(message)s'
)
#console_hdl = logging.StreamHandler()
#console_hdl.setFormatter(formatter)
#logger.addHandler(console_hdl)
if
save_to_file
is
not
None
and
not
os
.
path
.
exists
(
save_to_file
):
file_hdl
=
logging
.
FileHandler
(
save_to_file
)
file_hdl
.
setFormatter
(
formatter
)
logger
.
addHandler
(
file_hdl
)
logger
.
setLevel
(
logging
.
DEBUG
)
logger
.
propagate
=
False
def
str2bool
(
v
):
"""doc"""
# because argparse does not support to parse "true, False" as python
# boolean directly
return
v
.
lower
()
in
(
"true"
,
"t"
,
"1"
)
class
ArgumentGroup
(
object
):
"""doc"""
def
__init__
(
self
,
parser
,
title
,
des
):
self
.
_group
=
parser
.
add_argument_group
(
title
=
title
,
description
=
des
)
def
add_arg
(
self
,
name
,
type
,
default
,
help
,
positional_arg
=
False
,
**
kwargs
):
"""doc"""
prefix
=
""
if
positional_arg
else
"--"
type
=
str2bool
if
type
==
bool
else
type
self
.
_group
.
add_argument
(
prefix
+
name
,
default
=
default
,
type
=
type
,
help
=
help
+
' Default: %(default)s.'
,
**
kwargs
)
def
print_arguments
(
args
):
"""doc"""
log
.
info
(
'----------- Configuration Arguments -----------'
)
for
arg
,
value
in
sorted
(
six
.
iteritems
(
vars
(
args
))):
log
.
info
(
'%s: %s'
%
(
arg
,
value
))
log
.
info
(
'------------------------------------------------'
)
def
check_cuda
(
use_cuda
,
err
=
\
"
\n
You can not set use_cuda=True in the model because you are using paddlepaddle-cpu.
\n
\
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda=False to run models on CPU.
\n
"
):
"""doc"""
try
:
if
use_cuda
==
True
and
fluid
.
is_compiled_with_cuda
()
==
False
:
log
.
error
(
err
)
sys
.
exit
(
1
)
except
Exception
as
e
:
pass
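str2bool treats only "true", "t" and "1" (case-insensitively) as truthy, so boolean flags can be passed as --use_cuda True or --use_cuda 0:

assert str2bool("True") and str2bool("t") and str2bool("1")
assert not str2bool("False") and not str2bool("0") and not str2bool("yes")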
ogb_examples/nodeproppred/ogbn-arxiv/utils/init.py
0 → 100644
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""paddle init"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
unicode_literals
from
__future__
import
absolute_import
import
os
import
six
import
ast
import
copy
import
logging
import
numpy
as
np
import
paddle.fluid
as
fluid
log
=
logging
.
getLogger
(
__name__
)
def
cast_fp32_to_fp16
(
exe
,
main_program
):
"""doc"""
log
.
info
(
"Cast parameters to float16 data format."
)
for
param
in
main_program
.
global_block
().
all_parameters
():
if
not
param
.
name
.
endswith
(
".master"
):
param_t
=
fluid
.
global_scope
().
find_var
(
param
.
name
).
get_tensor
()
data
=
np
.
array
(
param_t
)
if
param
.
name
.
startswith
(
"encoder_layer"
)
\
and
"layer_norm"
not
in
param
.
name
:
param_t
.
set
(
np
.
float16
(
data
).
view
(
np
.
uint16
),
exe
.
place
)
#load fp32
master_param_var
=
fluid
.
global_scope
().
find_var
(
param
.
name
+
".master"
)
if
master_param_var
is
not
None
:
master_param_var
.
get_tensor
().
set
(
data
,
exe
.
place
)
def
init_checkpoint
(
exe
,
init_checkpoint_path
,
main_program
,
use_fp16
=
False
):
"""init"""
assert
os
.
path
.
exists
(
init_checkpoint_path
),
"[%s] cann't be found."
%
init_checkpoint_path
def
existed_persitables
(
var
):
"""existed"""
if
not
fluid
.
io
.
is_persistable
(
var
):
return
False
return
os
.
path
.
exists
(
os
.
path
.
join
(
init_checkpoint_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
init_checkpoint_path
,
main_program
=
main_program
,
predicate
=
existed_persitables
)
log
.
info
(
"Load model from {}"
.
format
(
init_checkpoint_path
))
if
use_fp16
:
cast_fp32_to_fp16
(
exe
,
main_program
)
def
init_pretraining_params
(
exe
,
pretraining_params_path
,
main_program
,
use_fp16
=
False
):
"""init"""
assert
os
.
path
.
exists
(
pretraining_params_path
),
"[%s] cann't be found."
%
pretraining_params_path
def
existed_params
(
var
):
"""doc"""
if
not
isinstance
(
var
,
fluid
.
framework
.
Parameter
):
return
False
return
os
.
path
.
exists
(
os
.
path
.
join
(
pretraining_params_path
,
var
.
name
))
fluid
.
io
.
load_vars
(
exe
,
pretraining_params_path
,
main_program
=
main_program
,
predicate
=
existed_params
)
log
.
info
(
"Load pretraining parameters from {}."
.
format
(
pretraining_params_path
))
if
use_fp16
:
cast_fp32_to_fp16
(
exe
,
main_program
)
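Both loaders only restore variables whose names match files under the given directory, so a partially overlapping checkpoint is fine. A hedged usage sketch (the path is illustrative; exe and startup_prog come from train.py):

init_pretraining_params(exe, "./output/model/checkpoint",
                        main_program=startup_prog)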
ogb_examples/nodeproppred/ogbn-arxiv/utils/to_undirected.py
0 → 100644
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Arguments for configuration."""
from
__future__
import
absolute_import
from
__future__
import
unicode_literals
import
paddle.fluid
as
fluid
import
pgl
import
numpy
as
np
def
to_undirected
(
graph
):
inv_edges
=
np
.
zeros
(
graph
.
edges
.
shape
)
inv_edges
[:,
0
]
=
graph
.
edges
[:,
1
]
inv_edges
[:,
1
]
=
graph
.
edges
[:,
0
]
edges
=
np
.
vstack
((
graph
.
edges
,
inv_edges
))
g
=
pgl
.
graph
.
Graph
(
num_nodes
=
graph
.
num_nodes
,
edges
=
edges
)
for
k
,
v
in
graph
.
_edge_feat
.
items
():
g
.
_edge_feat
[
k
]
=
np
.
vstack
((
v
,
v
))
for
k
,
v
in
graph
.
_node_feat
.
items
():
g
.
_node_feat
[
k
]
=
v
return
g
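A quick numeric check of the edge reversal (plain numpy; note that np.zeros defaults to float, so the stacked array comes out as float64):

import numpy as np

edges = np.array([[0, 1], [2, 3]])
inv_edges = np.zeros(edges.shape)
inv_edges[:, 0] = edges[:, 1]
inv_edges[:, 1] = edges[:, 0]
print(np.vstack((edges, inv_edges)))
# [[0. 1.]
#  [2. 3.]
#  [1. 0.]
#  [3. 2.]]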