PaddlePaddle / PGL

Commit cd30e61c, authored April 29, 2020 by yelrose

Add linkprediction for ogb

Parent: 6c4a0850

14 changed files with 1229 additions and 272 deletions (+1229, -272):
ogb_examples/linkproppred/main_pgl.py                                   +0    -272
ogb_examples/linkproppred/ogbl-ppa/args.py                              +44   -0
ogb_examples/linkproppred/ogbl-ppa/dataloader/__init__.py               +13   -0
ogb_examples/linkproppred/ogbl-ppa/dataloader/base_dataloader.py        +148  -0
ogb_examples/linkproppred/ogbl-ppa/dataloader/ogbl_ppa_dataloader.py    +118  -0
ogb_examples/linkproppred/ogbl-ppa/model.py                             +110  -0
ogb_examples/linkproppred/ogbl-ppa/monitor/__init__.py                  +14   -0
ogb_examples/linkproppred/ogbl-ppa/monitor/train_monitor.py             +185  -0
ogb_examples/linkproppred/ogbl-ppa/train.py                             +157  -0
ogb_examples/linkproppred/ogbl-ppa/utils/__init__.py                    +14   -0
ogb_examples/linkproppred/ogbl-ppa/utils/args.py                        +97   -0
ogb_examples/linkproppred/ogbl-ppa/utils/cards.py                       +31   -0
ogb_examples/linkproppred/ogbl-ppa/utils/fp16.py                        +201  -0
ogb_examples/linkproppred/ogbl-ppa/utils/init.py                        +97   -0
ogb_examples/linkproppred/main_pgl.py (deleted, mode 100644 → 0)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""test ogb
"""
import argparse
import time
import logging
import numpy as np

import paddle.fluid as fluid
import pgl
from pgl.contrib.ogb.linkproppred.dataset_pgl import PglLinkPropPredDataset
from pgl.utils import paddle_helper
from ogb.linkproppred import Evaluator


def send_func(src_feat, dst_feat, edge_feat):
    """send_func"""
    return src_feat["h"]


def recv_func(feat):
    """recv_func"""
    return fluid.layers.sequence_pool(feat, pool_type="sum")


class GNNModel(object):
    """GNNModel"""

    def __init__(self, name, num_nodes, emb_dim, num_layers):
        self.num_nodes = num_nodes
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.name = name

        self.src_nodes = fluid.layers.data(
            name='src_nodes',
            shape=[None],
            dtype='int64', )

        self.dst_nodes = fluid.layers.data(
            name='dst_nodes',
            shape=[None],
            dtype='int64', )

        self.edge_label = fluid.layers.data(
            name='edge_label',
            shape=[None, 1],
            dtype='float32', )

    def forward(self, graph):
        """forward"""
        h = fluid.layers.create_parameter(
            shape=[self.num_nodes, self.emb_dim],
            dtype="float32",
            name=self.name + "_embedding")

        for layer in range(self.num_layers):
            msg = graph.send(
                send_func,
                nfeat_list=[("h", h)], )
            h = graph.recv(msg, recv_func)
            h = fluid.layers.fc(
                h,
                size=self.emb_dim,
                bias_attr=False,
                param_attr=fluid.ParamAttr(name=self.name + '_%s' % layer))
            h = h * graph.node_feat["norm"]
            bias = fluid.layers.create_parameter(
                shape=[self.emb_dim],
                dtype='float32',
                is_bias=True,
                name=self.name + '_bias_%s' % layer)
            h = fluid.layers.elementwise_add(h, bias, act="relu")

        src = fluid.layers.gather(h, self.src_nodes, overwrite=False)
        dst = fluid.layers.gather(h, self.dst_nodes, overwrite=False)
        edge_embed = src * dst
        pred = fluid.layers.fc(input=edge_embed,
                               size=1,
                               name=self.name + "_pred_output")

        prob = fluid.layers.sigmoid(pred)

        loss = fluid.layers.sigmoid_cross_entropy_with_logits(pred,
                                                              self.edge_label)
        loss = fluid.layers.reduce_mean(loss)

        return pred, prob, loss


def main():
    """main
    """
    # Training settings
    parser = argparse.ArgumentParser(description='Graph Dataset')
    parser.add_argument(
        '--epochs',
        type=int,
        default=4,
        help='number of epochs to train (default: 100)')
    parser.add_argument(
        '--dataset',
        type=str,
        default="ogbl-ppa",
        help='dataset name (default: protein protein associations)')
    parser.add_argument('--use_cuda', action='store_true')
    parser.add_argument('--batch_size', type=int, default=5120)
    parser.add_argument('--embed_dim', type=int, default=64)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.001)
    args = parser.parse_args()
    print(args)

    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()

    ### automatic dataloading and splitting
    print("loading dataset")
    dataset = PglLinkPropPredDataset(name=args.dataset)
    splitted_edge = dataset.get_edge_split()
    print(splitted_edge['train_edge'].shape)
    print(splitted_edge['train_edge_label'].shape)

    print("building evaluator")
    ### automatic evaluator. takes dataset name as input
    evaluator = Evaluator(args.dataset)

    graph_data = dataset[0]
    print("num_nodes: %d" % graph_data.num_nodes)

    train_program = fluid.Program()
    startup_program = fluid.Program()

    # degree normalize
    indegree = graph_data.indegree()
    norm = np.zeros_like(indegree, dtype="float32")
    norm[indegree > 0] = np.power(indegree[indegree > 0], -0.5)
    graph_data.node_feat["norm"] = np.expand_dims(norm, -1).astype("float32")
    # graph_data.node_feat["index"] = np.array([i for i in range(graph_data.num_nodes)], dtype=np.int64).reshape(-1,1)

    with fluid.program_guard(train_program, startup_program):
        model = GNNModel(
            name="gnn",
            num_nodes=graph_data.num_nodes,
            emb_dim=args.embed_dim,
            num_layers=args.num_layers)
        gw = pgl.graph_wrapper.GraphWrapper(
            "graph",
            place,
            node_feat=graph_data.node_feat_info(),
            edge_feat=graph_data.edge_feat_info())
        pred, prob, loss = model.forward(gw)

    val_program = train_program.clone(for_test=True)

    with fluid.program_guard(train_program, startup_program):
        global_steps = int(splitted_edge['train_edge'].shape[0] /
                           args.batch_size * 2)
        learning_rate = fluid.layers.polynomial_decay(args.lr, global_steps,
                                                      0.00005)

        adam = fluid.optimizer.Adam(
            learning_rate=learning_rate,
            regularization=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=0.0005))
        adam.minimize(loss)

    exe = fluid.Executor(place)
    exe.run(startup_program)

    feed = gw.to_feed(graph_data)

    print("evaluate result before training: ")
    result = test(exe, val_program, prob, evaluator, feed, splitted_edge)
    print(result)

    print("training")
    cc = 0
    for epoch in range(1, args.epochs + 1):
        for batch_data, batch_label in data_generator(
                graph_data,
                splitted_edge["train_edge"],
                splitted_edge["train_edge_label"],
                batch_size=args.batch_size):
            feed['src_nodes'] = batch_data[:, 0].reshape(-1, 1)
            feed['dst_nodes'] = batch_data[:, 1].reshape(-1, 1)
            feed['edge_label'] = batch_label.astype("float32")

            res_loss, y_pred, b_lr = exe.run(
                train_program,
                feed=feed,
                fetch_list=[loss, prob, learning_rate])
            if cc % 1 == 0:
                print("epoch %d | step %d | lr %s | Loss %s" %
                      (epoch, cc, b_lr[0], res_loss[0]))
            cc += 1

            if cc % 20 == 0:
                print("Evaluating...")
                result = test(exe, val_program, prob, evaluator, feed,
                              splitted_edge)
                print("epoch %d | step %d" % (epoch, cc))
                print(result)


def test(exe, val_program, prob, evaluator, feed, splitted_edge):
    """Evaluation"""
    result = {}
    feed['src_nodes'] = splitted_edge["valid_edge"][:, 0].reshape(-1, 1)
    feed['dst_nodes'] = splitted_edge["valid_edge"][:, 1].reshape(-1, 1)
    feed['edge_label'] = splitted_edge["valid_edge_label"].astype(
        "float32").reshape(-1, 1)
    y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
    input_dict = {
        "y_true": splitted_edge["valid_edge_label"],
        "y_pred": y_pred.reshape(-1, ),
    }
    result["valid"] = evaluator.eval(input_dict)

    feed['src_nodes'] = splitted_edge["test_edge"][:, 0].reshape(-1, 1)
    feed['dst_nodes'] = splitted_edge["test_edge"][:, 1].reshape(-1, 1)
    feed['edge_label'] = splitted_edge["test_edge_label"].astype(
        "float32").reshape(-1, 1)
    y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
    input_dict = {
        "y_true": splitted_edge["test_edge_label"],
        "y_pred": y_pred.reshape(-1, ),
    }
    result["test"] = evaluator.eval(input_dict)
    return result


def data_generator(graph, data, label_data, batch_size, shuffle=True):
    """Data Generator"""
    perm = np.arange(0, len(data))
    if shuffle:
        np.random.shuffle(perm)

    offset = 0
    while offset < len(perm):
        batch_index = perm[offset:(offset + batch_size)]
        offset += batch_size
        pos_data = data[batch_index]
        pos_label = label_data[batch_index]

        neg_src_node = pos_data[:, 0]
        neg_dst_node = np.random.choice(
            pos_data.reshape(-1, ), size=len(neg_src_node))
        neg_data = np.hstack(
            [neg_src_node.reshape(-1, 1), neg_dst_node.reshape(-1, 1)])
        exists = graph.has_edges_between(neg_src_node, neg_dst_node)
        neg_data = neg_data[np.invert(exists)]
        neg_label = np.zeros(shape=len(neg_data), dtype=np.int64)

        batch_data = np.vstack([pos_data, neg_data])
        label = np.vstack(
            [pos_label.reshape(-1, 1), neg_label.reshape(-1, 1)])
        yield batch_data, label


if __name__ == "__main__":
    main()
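The deleted script's data_generator builds negatives by corrupting positive edges: it keeps the batch's source nodes, redraws destinations from the node ids appearing in the batch, and drops any corrupted pair that is a real edge. A minimal NumPy sketch of that corruption step, with a toy stand-in for the pgl Graph.has_edges_between call used above:

import numpy as np

pos_data = np.array([[0, 1], [2, 3], [4, 5]])  # toy positive edges
neg_src_node = pos_data[:, 0]                  # keep the source endpoints
neg_dst_node = np.random.choice(
    pos_data.reshape(-1, ), size=len(neg_src_node))

def has_edges_between(src, dst):
    """Toy stand-in for pgl.graph.Graph.has_edges_between."""
    real = {(0, 1), (2, 3), (4, 5)}
    return np.array([(s, d) in real for s, d in zip(src, dst)])

neg_data = np.hstack(
    [neg_src_node.reshape(-1, 1), neg_dst_node.reshape(-1, 1)])
neg_data = neg_data[np.invert(has_edges_between(neg_src_node, neg_dst_node))]
print(neg_data)  # corrupted pairs that do not exist in the toy graph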
ogb_examples/linkproppred/ogbl-ppa/args.py (new file, mode 100644)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""finetune args"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import

import os
import time
import argparse

from utils.args import ArgumentGroup

# yapf: disable
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("init_checkpoint", str, None,
                "Init checkpoint to resume training from.")
model_g.add_arg("init_pretraining_params", str, None,
                "Init pre-training params which performs fine-tuning from. If the "
                "arg 'init_checkpoint' has been set, this argument wouldn't be valid.")

train_g = ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 3, "Number of epochs for fine-tuning.")
train_g.add_arg("learning_rate", float, 5e-5,
                "Learning rate used to train with warmup.")

run_type_g = ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, True, "If set, use GPU for training.")
run_type_g.add_arg("num_workers", int, 1, "use multiprocess to generate graph")
run_type_g.add_arg("output_path", str, None, "path to save model")
run_type_g.add_arg("hidden_size", int, 128, "model hidden-size")
run_type_g.add_arg("batch_size", int, 128, "batch_size")
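Since every option above declares a default, the parser resolves an empty command line without error. A minimal sketch (assuming this module is importable as args):

from args import parser

args = parser.parse_args([])
print(args.epoch)          # 3
print(args.learning_rate)  # 5e-05
print(args.use_cuda)       # True (bool args go through utils.args.str2bool)
print(args.hidden_size)    # 128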
ogb_examples/linkproppred/ogbl-ppa/dataloader/__init__.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
ogb_examples/linkproppred/ogbl-ppa/dataloader/base_dataloader.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base DataLoader
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import

import os
import sys
import six
from io import open
from collections import namedtuple
import numpy as np
import tqdm
import paddle
from pgl.utils import mp_reader
import collections
import time

import pgl

if six.PY3:
    import io
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')


def batch_iter(data, perm, batch_size, fid, num_workers):
    """node_batch_iter
    """
    size = len(data)
    start = 0
    cc = 0
    while start < size:
        index = perm[start:start + batch_size]
        start += batch_size
        cc += 1
        if cc % num_workers != fid:
            continue
        yield data[index]


def scan_batch_iter(data, batch_size, fid, num_workers):
    """node_batch_iter
    """
    batch = []
    cc = 0
    for line_example in data.scan():
        cc += 1
        if cc % num_workers != fid:
            continue
        batch.append(line_example)
        if len(batch) == batch_size:
            yield batch
            batch = []

    if len(batch) > 0:
        yield batch


class BaseDataGenerator(object):
    """Base Data Generator"""

    def __init__(self, buf_size, batch_size, num_workers, shuffle=True):
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.line_examples = []
        self.buf_size = buf_size
        self.shuffle = shuffle

    def batch_fn(self, batch_examples):
        """ batch_fn batch producer"""
        raise NotImplementedError("No defined Batch Fn")

    def batch_iter(self, fid, perm):
        """ batch iterator"""
        if self.shuffle:
            for batch in batch_iter(self, perm, self.batch_size, fid,
                                    self.num_workers):
                yield batch
        else:
            for batch in scan_batch_iter(self, self.batch_size, fid,
                                         self.num_workers):
                yield batch

    def __len__(self):
        return len(self.line_examples)

    def __getitem__(self, idx):
        if isinstance(idx, collections.Iterable):
            return [self[bidx] for bidx in idx]
        else:
            return self.line_examples[idx]

    def generator(self):
        """batch dict generator"""

        def worker(filter_id, perm):
            """ multiprocess worker"""

            def func_run():
                """ func_run """
                pid = os.getpid()
                np.random.seed(pid + int(time.time()))
                for batch_examples in self.batch_iter(filter_id, perm):
                    batch_dict = self.batch_fn(batch_examples)
                    yield batch_dict

            return func_run

        # consume a seed
        np.random.rand()
        if self.shuffle:
            perm = np.arange(0, len(self))
            np.random.shuffle(perm)
        else:
            perm = None

        if self.num_workers == 1:
            r = paddle.reader.buffered(worker(0, perm), self.buf_size)
        else:
            worker_pool = [
                worker(wid, perm) for wid in range(self.num_workers)
            ]
            worker = mp_reader.multiprocess_reader(
                worker_pool, use_pipe=True, queue_size=1000)
            r = paddle.reader.buffered(worker, self.buf_size)

        for batch in r():
            yield batch

    def scan(self):
        for line_example in self.line_examples:
            yield line_example
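The intended contract: a subclass fills line_examples with picklable records and overrides batch_fn; generator() then handles shuffling, sharding across num_workers, and buffering through paddle.reader. A minimal toy sketch of that contract (the real subclass is PPADataGenerator in the next file):

import numpy as np

class ToyDataGenerator(BaseDataGenerator):
    def __init__(self):
        super(ToyDataGenerator, self).__init__(
            buf_size=10, batch_size=4, num_workers=1, shuffle=True)
        self.line_examples = list(range(20))  # any picklable records

    def batch_fn(self, batch_examples):
        # turn a list of raw examples into one feed dict
        return {"x": np.array(batch_examples, dtype="int64")}

for batch_dict in ToyDataGenerator().generator():
    print(batch_dict["x"])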
ogb_examples/linkproppred/ogbl-ppa/dataloader/ogbl_ppa_dataloader.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import

from dataloader.base_dataloader import BaseDataGenerator

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
from ogb.linkproppred import LinkPropPredDataset
from ogb.linkproppred import Evaluator
import tqdm
from collections import namedtuple
import pgl
import numpy as np


class PPADataGenerator(BaseDataGenerator):
    def __init__(self,
                 graph_wrapper=None,
                 buf_size=1000,
                 batch_size=128,
                 num_workers=1,
                 shuffle=True,
                 phase="train"):
        super(PPADataGenerator, self).__init__(
            buf_size=buf_size,
            num_workers=num_workers,
            batch_size=batch_size,
            shuffle=shuffle)

        self.d_name = "ogbl-ppa"
        self.graph_wrapper = graph_wrapper
        dataset = LinkPropPredDataset(name=self.d_name)
        splitted_edge = dataset.get_edge_split()
        self.phase = phase
        graph = dataset[0]
        edges = graph["edge_index"].T
        #self.graph = pgl.graph.Graph(num_nodes=graph["num_nodes"],
        #       edges=edges,
        #       node_feat={"nfeat": graph["node_feat"],
        #                  "node_id": np.arange(0, graph["num_nodes"], dtype="int64").reshape(-1, 1) })
        #self.graph.indegree()
        self.num_nodes = graph["num_nodes"]
        if self.phase == 'train':
            edges = splitted_edge["train"]["edge"]
            labels = np.ones(len(edges))
        elif self.phase == "valid":
            # Compute the embedding for all the nodes
            pos_edges = splitted_edge["valid"]["edge"]
            neg_edges = splitted_edge["valid"]["edge_neg"]
            pos_labels = np.ones(len(pos_edges))
            neg_labels = np.zeros(len(neg_edges))
            edges = np.vstack([pos_edges, neg_edges])
            labels = pos_labels.tolist() + neg_labels.tolist()
        elif self.phase == "test":
            # Compute the embedding for all the nodes
            pos_edges = splitted_edge["test"]["edge"]
            neg_edges = splitted_edge["test"]["edge_neg"]
            pos_labels = np.ones(len(pos_edges))
            neg_labels = np.zeros(len(neg_edges))
            edges = np.vstack([pos_edges, neg_edges])
            labels = pos_labels.tolist() + neg_labels.tolist()

        self.line_examples = []
        Example = namedtuple('Example', ['src', "dst", "label"])
        for edge, label in zip(edges, labels):
            self.line_examples.append(
                Example(
                    src=edge[0], dst=edge[1], label=label))
        print("Phase", self.phase)
        print("Len Examples", len(self.line_examples))

    def batch_fn(self, batch_ex):
        batch_src = []
        batch_dst = []
        join_graph = []
        cc = 0
        batch_node_id = []
        batch_labels = []
        for ex in batch_ex:
            batch_src.append(ex.src)
            batch_dst.append(ex.dst)
            batch_labels.append(ex.label)

        if self.phase == "train":
            for num in range(1):
                rand_src = np.random.randint(
                    low=0, high=self.num_nodes, size=len(batch_ex))
                rand_dst = np.random.randint(
                    low=0, high=self.num_nodes, size=len(batch_ex))
                batch_src = batch_src + rand_src.tolist()
                batch_dst = batch_dst + rand_dst.tolist()
                batch_labels = batch_labels + np.zeros_like(
                    rand_src, dtype="int64").tolist()

        feed_dict = {}
        feed_dict["batch_src"] = np.array(batch_src, dtype="int64")
        feed_dict["batch_dst"] = np.array(batch_dst, dtype="int64")
        feed_dict["labels"] = np.array(batch_labels, dtype="int64")
        return feed_dict
ogb_examples/linkproppred/ogbl-ppa/model.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""lbs_model"""
import os
import re
import time
from random import random
from functools import reduce, partial

import numpy as np
import multiprocessing

import paddle
import paddle.fluid as F
import paddle.fluid.layers as L

from pgl.graph_wrapper import GraphWrapper
from pgl.layers.conv import gcn, gat


class BaseGraph(object):
    """Base Graph Model"""

    def __init__(self, args):
        node_feature = [('nfeat', [None, 58], "float32"),
                        ('node_id', [None, 1], "int64")]
        self.hidden_size = args.hidden_size
        self.num_nodes = args.num_nodes

        self.graph_wrapper = None  # GraphWrapper(
        #name="graph", place=F.CPUPlace(), node_feat=node_feature)

        self.build_model(args)

    def build_model(self, args):
        """ build graph model"""
        self.batch_src = L.data(name="batch_src", shape=[-1], dtype="int64")
        self.batch_src = L.reshape(self.batch_src, [-1, 1])
        self.batch_dst = L.data(name="batch_dst", shape=[-1], dtype="int64")
        self.batch_dst = L.reshape(self.batch_dst, [-1, 1])
        self.labels = L.data(name="labels", shape=[-1], dtype="int64")
        self.labels = L.reshape(self.labels, [-1, 1])
        self.labels.stop_gradients = True

        self.src_repr = L.embedding(
            self.batch_src,
            size=(self.num_nodes, self.hidden_size),
            param_attr=F.ParamAttr(
                name="node_embeddings",
                initializer=F.initializer.NormalInitializer(
                    loc=0.0, scale=1.0)))

        self.dst_repr = L.embedding(
            self.batch_dst,
            size=(self.num_nodes, self.hidden_size),
            param_attr=F.ParamAttr(
                name="node_embeddings",
                initializer=F.initializer.NormalInitializer(
                    loc=0.0, scale=1.0)))

        self.link_predictor(self.src_repr, self.dst_repr)

        self.bce_loss()

    def link_predictor(self, x, y):
        """ siamese network"""
        feat = x * y

        feat = L.fc(feat, size=self.hidden_size, name="link_predictor_1")
        feat = L.relu(feat)

        feat = L.fc(feat, size=self.hidden_size, name="link_predictor_2")
        feat = L.relu(feat)

        self.logits = L.fc(feat,
                           size=1,
                           act="sigmoid",
                           name="link_predictor_logits")

    def bce_loss(self):
        """listwise model"""
        mask = L.cast(self.labels > 0.5, dtype="float32")
        mask.stop_gradients = True

        self.loss = L.log_loss(self.logits, mask, epsilon=1e-15)
        self.loss = L.reduce_mean(self.loss) * 2
        proba = L.sigmoid(self.logits)
        proba = L.concat([proba * -1 + 1, proba], axis=1)
        auc_out, batch_auc_out, _ = \
            L.auc(input=proba, label=self.labels, curve='ROC', slide_steps=1)

        self.metrics = {
            "loss": self.loss,
            "top1": batch_auc_out,
            "max": L.reduce_max(self.logits),
            "min": L.reduce_min(self.logits)
        }

    def neighbor_aggregator(self, node_repr):
        """neighbor aggregation"""
        return node_repr
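The head above scores an edge (u, v) as an MLP over the Hadamard product of the two endpoint embeddings. A NumPy sketch of the same computation with toy weights (biases omitted), mirroring the fc/relu/fc/relu/fc(act="sigmoid") stack:

import numpy as np

hidden_size = 4
src_repr = np.random.randn(3, hidden_size)      # 3 source embeddings
dst_repr = np.random.randn(3, hidden_size)      # 3 destination embeddings
w1 = np.random.randn(hidden_size, hidden_size)  # link_predictor_1
w2 = np.random.randn(hidden_size, hidden_size)  # link_predictor_2
w3 = np.random.randn(hidden_size, 1)            # link_predictor_logits

feat = src_repr * dst_repr            # Hadamard edge representation
h = np.maximum(feat @ w1, 0)          # fc + relu
h = np.maximum(h @ w2, 0)             # fc + relu
logits = 1 / (1 + np.exp(-(h @ w3)))  # fc with sigmoid activation
print(logits.shape)                   # (3, 1): one score per candidate edge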
ogb_examples/linkproppred/ogbl-ppa/monitor/__init__.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""init"""
ogb_examples/linkproppred/ogbl-ppa/monitor/train_monitor.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""train and evaluate"""
import tqdm
import json
import numpy as np
import sys
import os
import paddle.fluid as F
from tensorboardX import SummaryWriter
from ogb.linkproppred import Evaluator
from ogb.linkproppred import LinkPropPredDataset


def multi_device(reader, dev_count):
    """multi device"""
    if dev_count == 1:
        for batch in reader:
            yield batch
    else:
        batches = []
        for batch in reader:
            batches.append(batch)
            if len(batches) == dev_count:
                yield batches
                batches = []


class OgbEvaluator(object):
    def __init__(self):
        d_name = "ogbl-ppa"
        dataset = LinkPropPredDataset(name=d_name)
        splitted_edge = dataset.get_edge_split()
        graph = dataset[0]
        self.num_nodes = graph["num_nodes"]
        self.ogb_evaluator = Evaluator(name="ogbl-ppa")

    def eval(self, scores, labels, phase):
        labels = np.reshape(labels, [-1])
        ret = {}
        pos = scores[labels > 0.5].squeeze(-1)
        neg = scores[labels < 0.5].squeeze(-1)
        for K in [10, 50, 100]:
            self.ogb_evaluator.K = K
            ret['%s_hits@%s' % (phase, K)] = self.ogb_evaluator.eval({
                'y_pred_pos': pos,
                'y_pred_neg': neg,
            })[f'hits@{K}']
        return ret


def evaluate(model, valid_exe, valid_ds, valid_prog, dev_count, evaluator,
             phase):
    """evaluate """
    cc = 0
    scores = []
    labels = []
    for feed_dict in tqdm.tqdm(
            multi_device(valid_ds.generator(), dev_count),
            desc='evaluating'):
        if dev_count > 1:
            output = valid_exe.run(
                feed=feed_dict, fetch_list=[model.logits, model.labels])
        else:
            output = valid_exe.run(
                valid_prog,
                feed=feed_dict,
                fetch_list=[model.logits, model.labels])
        scores.append(output[0])
        labels.append(output[1])

    scores = np.vstack(scores)
    labels = np.vstack(labels)
    ret = evaluator.eval(scores, labels, phase)
    return ret


def _create_if_not_exist(path):
    basedir = os.path.dirname(path)
    if not os.path.exists(basedir):
        os.makedirs(basedir)


def train_and_evaluate(exe,
                       train_exe,
                       valid_exe,
                       train_ds,
                       valid_ds,
                       test_ds,
                       train_prog,
                       valid_prog,
                       model,
                       metric,
                       epoch=20,
                       dev_count=1,
                       train_log_step=5,
                       eval_step=10000,
                       evaluator=None,
                       output_path=None):
    """train and evaluate"""
    global_step = 0

    log_path = os.path.join(output_path, "log")
    _create_if_not_exist(log_path)

    writer = SummaryWriter(log_path)

    best_model = 0
    for e in range(epoch):
        for feed_dict in tqdm.tqdm(
                multi_device(train_ds.generator(), dev_count),
                desc='Epoch %s' % e):
            if dev_count > 1:
                ret = train_exe.run(feed=feed_dict, fetch_list=metric.vars)
                ret = [[np.mean(v)] for v in ret]
            else:
                ret = train_exe.run(
                    train_prog, feed=feed_dict, fetch_list=metric.vars)

            ret = metric.parse(ret)
            if global_step % train_log_step == 0:
                sys.stderr.write(json.dumps(ret) + '\n')
                for key, value in ret.items():
                    writer.add_scalar(
                        'train_' + key, value, global_step=global_step)

            global_step += 1
            if global_step % eval_step == 0:
                eval_ret = evaluate(model, exe, valid_ds, valid_prog, 1,
                                    evaluator, "valid")
                test_eval_ret = evaluate(model, exe, test_ds, valid_prog, 1,
                                         evaluator, "test")
                eval_ret.update(test_eval_ret)
                sys.stderr.write(json.dumps(eval_ret, indent=4) + "\n")

                for key, value in eval_ret.items():
                    writer.add_scalar(key, value, global_step=global_step)

                if eval_ret["valid_hits@100"] > best_model:
                    F.io.save_persistables(
                        exe,
                        os.path.join(output_path, "checkpoint"), train_prog)
                    eval_ret["step"] = global_step
                    with open(os.path.join(output_path, "best.txt"),
                              "w") as f:
                        f.write(json.dumps(eval_ret, indent=2) + '\n')
                    best_model = eval_ret["valid_hits@100"]

        # Epoch End
        eval_ret = evaluate(model, exe, valid_ds, valid_prog, 1, evaluator,
                            "valid")
        test_eval_ret = evaluate(model, exe, test_ds, valid_prog, 1,
                                 evaluator, "test")
        eval_ret.update(test_eval_ret)
        sys.stderr.write(json.dumps(eval_ret, indent=4) + "\n")

        for key, value in eval_ret.items():
            writer.add_scalar(key, value, global_step=global_step)

        if eval_ret["valid_hits@100"] > best_model:
            F.io.save_persistables(exe,
                                   os.path.join(output_path, "checkpoint"),
                                   train_prog)
            eval_ret["step"] = global_step
            with open(os.path.join(output_path, "best.txt"), "w") as f:
                f.write(json.dumps(eval_ret, indent=2) + '\n')
            best_model = eval_ret["valid_hits@100"]

    writer.close()
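OgbEvaluator.eval splits the fetched scores by label and reuses the official OGB metric, resetting K before each query. A minimal sketch of that pattern against the ogb package, with dummy scores in place of real model outputs:

import numpy as np
from ogb.linkproppred import Evaluator

evaluator = Evaluator(name="ogbl-ppa")
y_pred_pos = np.random.rand(100)   # dummy scores for positive edges
y_pred_neg = np.random.rand(1000)  # dummy scores for negative edges

for K in [10, 50, 100]:
    evaluator.K = K
    hits = evaluator.eval({
        'y_pred_pos': y_pred_pos,
        'y_pred_neg': y_pred_neg,
    })[f'hits@{K}']
    print('valid_hits@%s' % K, hits)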
ogb_examples/linkproppred/ogbl-ppa/train.py (new file, mode 100644)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""listwise model
"""
import torch
import os
import re
import time
import logging
from random import random
from functools import reduce, partial

# For downloading ogb
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# SSL

import numpy as np
import multiprocessing

import paddle
import paddle.fluid as F
import paddle.fluid.layers as L

import pgl
from utils.args import print_arguments, check_cuda
from utils.init import init_checkpoint, init_pretraining_params
from args import parser
from model import BaseGraph
from dataloader.ogbl_ppa_dataloader import PPADataGenerator
from monitor.train_monitor import train_and_evaluate, OgbEvaluator

log = logging.getLogger(__name__)


class Metric(object):
    """Metric"""

    def __init__(self, **args):
        self.args = args

    @property
    def vars(self):
        """ fetch metric vars"""
        values = [self.args[k] for k in self.args.keys()]
        return values

    def parse(self, fetch_list):
        """parse"""
        tup = list(
            zip(self.args.keys(), [float(v[0]) for v in fetch_list]))
        return dict(tup)


if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)
    evaluator = OgbEvaluator()

    train_prog = F.Program()
    startup_prog = F.Program()
    args.num_nodes = evaluator.num_nodes

    if args.use_cuda:
        dev_list = F.cuda_places()
        place = dev_list[0]
        dev_count = len(dev_list)
    else:
        place = F.CPUPlace()
        dev_count = int(
            os.environ.get('CPU_NUM', multiprocessing.cpu_count()))

    with F.program_guard(train_prog, startup_prog):
        with F.unique_name.guard():
            graph_model = BaseGraph(args)
            test_prog = train_prog.clone(for_test=True)
            opt = F.optimizer.Adam(learning_rate=args.learning_rate)
            opt.minimize(graph_model.loss)

    #test_prog = F.Program()
    #with F.program_guard(test_prog, startup_prog):
    #    with F.unique_name.guard():
    #        _graph_model = BaseGraph(args)

    train_ds = PPADataGenerator(
        phase="train",
        graph_wrapper=graph_model.graph_wrapper,
        num_workers=args.num_workers,
        batch_size=args.batch_size)

    valid_ds = PPADataGenerator(
        phase="valid",
        graph_wrapper=graph_model.graph_wrapper,
        num_workers=args.num_workers,
        batch_size=args.batch_size)

    test_ds = PPADataGenerator(
        phase="test",
        graph_wrapper=graph_model.graph_wrapper,
        num_workers=args.num_workers,
        batch_size=args.batch_size)

    exe = F.Executor(place)
    exe.run(startup_prog)

    if args.init_pretraining_params is not None:
        init_pretraining_params(
            exe, args.init_pretraining_params, main_program=startup_prog)

    metric = Metric(**graph_model.metrics)

    nccl2_num_trainers = 1
    nccl2_trainer_id = 0

    if dev_count > 1:
        exec_strategy = F.ExecutionStrategy()
        exec_strategy.num_threads = dev_count

        train_exe = F.ParallelExecutor(
            use_cuda=args.use_cuda,
            loss_name=graph_model.loss.name,
            exec_strategy=exec_strategy,
            main_program=train_prog,
            num_trainers=nccl2_num_trainers,
            trainer_id=nccl2_trainer_id)

        test_exe = exe
    else:
        train_exe, test_exe = exe, exe

    train_and_evaluate(
        exe=exe,
        train_exe=train_exe,
        valid_exe=test_exe,
        train_ds=train_ds,
        valid_ds=valid_ds,
        test_ds=test_ds,
        train_prog=train_prog,
        valid_prog=test_prog,
        train_log_step=5,
        output_path=args.output_path,
        dev_count=dev_count,
        model=graph_model,
        epoch=args.epoch,
        eval_step=1000000,
        evaluator=evaluator,
        metric=metric)
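The Metric helper is plain Python, so its fetch/parse round trip can be checked in isolation; a small sketch with dummy fetch values standing in for what exe.run returns (normally the args are fluid Variables):

metric = Metric(loss="loss_var", top1="auc_var")
print(metric.vars)                     # ['loss_var', 'auc_var']
print(metric.parse([[0.69], [0.51]]))  # {'loss': 0.69, 'top1': 0.51}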
ogb_examples/linkproppred/ogbl-ppa/utils/__init__.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""utils"""
ogb_examples/linkproppred/ogbl-ppa/utils/args.py (new file, mode 100644)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Arguments for configuration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import

import six
import os
import sys
import argparse
import logging

import paddle.fluid as fluid

log = logging.getLogger(__name__)


def prepare_logger(logger, debug=False, save_to_file=None):
    """doc"""
    formatter = logging.Formatter(
        fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s'
    )
    #console_hdl = logging.StreamHandler()
    #console_hdl.setFormatter(formatter)
    #logger.addHandler(console_hdl)
    if save_to_file is not None and not os.path.exists(save_to_file):
        file_hdl = logging.FileHandler(save_to_file)
        file_hdl.setFormatter(formatter)
        logger.addHandler(file_hdl)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False


def str2bool(v):
    """doc"""
    # because argparse does not support to parse "true, False" as python
    # boolean directly
    return v.lower() in ("true", "t", "1")


class ArgumentGroup(object):
    """doc"""

    def __init__(self, parser, title, des):
        self._group = parser.add_argument_group(title=title, description=des)

    def add_arg(self,
                name,
                type,
                default,
                help,
                positional_arg=False,
                **kwargs):
        """doc"""
        prefix = "" if positional_arg else "--"
        type = str2bool if type == bool else type
        self._group.add_argument(
            prefix + name,
            default=default,
            type=type,
            help=help + ' Default: %(default)s.',
            **kwargs)


def print_arguments(args):
    """doc"""
    log.info('----------- Configuration Arguments -----------')
    for arg, value in sorted(six.iteritems(vars(args))):
        log.info('%s: %s' % (arg, value))
    log.info('------------------------------------------------')


def check_cuda(use_cuda,
               err="\nYou can not set use_cuda=True in the model because you are using paddlepaddle-cpu.\n"
               "Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda=False to run models on CPU.\n"
               ):
    """doc"""
    try:
        if use_cuda == True and fluid.is_compiled_with_cuda() == False:
            log.error(err)
            sys.exit(1)
    except Exception as e:
        pass
ogb_examples/linkproppred/ogbl-ppa/utils/cards.py (new file, mode 100644)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""cards"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import

import os


def get_cards():
    """
    get gpu cards number
    """
    num = 0
    cards = os.environ.get('CUDA_VISIBLE_DEVICES', '')
    if cards != '':
        num = len(cards.split(","))
    return num
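A short usage sketch: get_cards only inspects the environment at call time, so the visible device count can be checked without initializing Paddle:

import os
from utils.cards import get_cards

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
print(get_cards())  # 4
os.environ['CUDA_VISIBLE_DEVICES'] = ''
print(get_cards())  # 0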
ogb_examples/linkproppred/ogbl-ppa/utils/fp16.py (new file, mode 100644)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid


def append_cast_op(i, o, prog):
    """
    Append a cast op in a given Program to cast input `i` to data type `o.dtype`.

    Args:
        i (Variable): The input Variable.
        o (Variable): The output Variable.
        prog (Program): The Program to append cast op.
    """
    prog.global_block().append_op(
        type="cast",
        inputs={"X": i},
        outputs={"Out": o},
        attrs={"in_dtype": i.dtype,
               "out_dtype": o.dtype})


def copy_to_master_param(p, block):
    v = block.vars.get(p.name, None)
    if v is None:
        raise ValueError("no param name %s found!" % p.name)
    new_p = fluid.framework.Parameter(
        block=block,
        shape=v.shape,
        dtype=fluid.core.VarDesc.VarType.FP32,
        type=v.type,
        lod_level=v.lod_level,
        stop_gradient=p.stop_gradient,
        trainable=p.trainable,
        optimize_attr=p.optimize_attr,
        regularizer=p.regularizer,
        gradient_clip_attr=p.gradient_clip_attr,
        error_clip=p.error_clip,
        name=v.name + ".master")
    return new_p


def apply_dynamic_loss_scaling(loss_scaling, master_params_grads,
                               incr_every_n_steps, decr_every_n_nan_or_inf,
                               incr_ratio, decr_ratio):
    _incr_every_n_steps = fluid.layers.fill_constant(
        shape=[1], dtype='int32', value=incr_every_n_steps)
    _decr_every_n_nan_or_inf = fluid.layers.fill_constant(
        shape=[1], dtype='int32', value=decr_every_n_nan_or_inf)

    _num_good_steps = fluid.layers.create_global_var(
        name=fluid.unique_name.generate("num_good_steps"),
        shape=[1],
        value=0,
        dtype='int32',
        persistable=True)
    _num_bad_steps = fluid.layers.create_global_var(
        name=fluid.unique_name.generate("num_bad_steps"),
        shape=[1],
        value=0,
        dtype='int32',
        persistable=True)

    grads = [fluid.layers.reduce_sum(g) for [_, g] in master_params_grads]
    all_grads = fluid.layers.concat(grads)
    all_grads_sum = fluid.layers.reduce_sum(all_grads)
    is_overall_finite = fluid.layers.isfinite(all_grads_sum)

    update_loss_scaling(is_overall_finite, loss_scaling, _num_good_steps,
                        _num_bad_steps, _incr_every_n_steps,
                        _decr_every_n_nan_or_inf, incr_ratio, decr_ratio)

    # apply_gradient append all ops in global block, thus we shouldn't
    # apply gradient in the switch branch.
    with fluid.layers.Switch() as switch:
        with switch.case(is_overall_finite):
            pass
        with switch.default():
            for _, g in master_params_grads:
                fluid.layers.assign(fluid.layers.zeros_like(g), g)


def create_master_params_grads(params_grads, main_prog, startup_prog,
                               loss_scaling):
    master_params_grads = []
    for p, g in params_grads:
        with main_prog._optimized_guard([p, g]):
            # create master parameters
            master_param = copy_to_master_param(p, main_prog.global_block())
            startup_master_param = startup_prog.global_block(
            )._clone_variable(master_param)
            startup_p = startup_prog.global_block().var(p.name)
            append_cast_op(startup_p, startup_master_param, startup_prog)
            # cast fp16 gradients to fp32 before apply gradients
            if g.name.find("layer_norm") > -1:
                scaled_g = g / loss_scaling
                master_params_grads.append([p, scaled_g])
                continue
            master_grad = fluid.layers.cast(g, "float32")
            master_grad = master_grad / loss_scaling
            master_params_grads.append([master_param, master_grad])

    return master_params_grads


def master_param_to_train_param(master_params_grads, params_grads,
                                main_prog):
    for idx, m_p_g in enumerate(master_params_grads):
        train_p, _ = params_grads[idx]
        if train_p.name.find("layer_norm") > -1:
            continue
        with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
            append_cast_op(m_p_g[0], train_p, main_prog)


def update_loss_scaling(is_overall_finite, prev_loss_scaling, num_good_steps,
                        num_bad_steps, incr_every_n_steps,
                        decr_every_n_nan_or_inf, incr_ratio, decr_ratio):
    """
    Update loss scaling according to overall gradients. If all gradients is
    finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
    Otherwise, loss scaling will decrease by decr_ratio after
    decr_every_n_nan_or_inf steps and each step some gradients are infinite.

    Args:
        is_overall_finite (Variable): A boolean variable indicates whether
                                      all gradients are finite.
        prev_loss_scaling (Variable): Previous loss scaling.
        num_good_steps (Variable): A variable accumulates good steps in which
                                   all gradients are finite.
        num_bad_steps (Variable): A variable accumulates bad steps in which
                                  some gradients are infinite.
        incr_every_n_steps (Variable): A variable represents increasing loss
                                       scaling every n consecutive steps with
                                       finite gradients.
        decr_every_n_nan_or_inf (Variable): A variable represents decreasing
                                            loss scaling every n accumulated
                                            steps with nan or inf gradients.
        incr_ratio(float): The multiplier to use when increasing the loss
                           scaling.
        decr_ratio(float): The less-than-one-multiplier to use when decreasing
                           loss scaling.
    """
    zero_steps = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
    with fluid.layers.Switch() as switch:
        with switch.case(is_overall_finite):
            should_incr_loss_scaling = fluid.layers.less_than(
                incr_every_n_steps, num_good_steps + 1)
            with fluid.layers.Switch() as switch1:
                with switch1.case(should_incr_loss_scaling):
                    new_loss_scaling = prev_loss_scaling * incr_ratio
                    loss_scaling_is_finite = fluid.layers.isfinite(
                        new_loss_scaling)
                    with fluid.layers.Switch() as switch2:
                        with switch2.case(loss_scaling_is_finite):
                            fluid.layers.assign(new_loss_scaling,
                                                prev_loss_scaling)
                        with switch2.default():
                            pass
                    fluid.layers.assign(zero_steps, num_good_steps)
                    fluid.layers.assign(zero_steps, num_bad_steps)
                with switch1.default():
                    fluid.layers.increment(num_good_steps)
                    fluid.layers.assign(zero_steps, num_bad_steps)
        with switch.default():
            should_decr_loss_scaling = fluid.layers.less_than(
                decr_every_n_nan_or_inf, num_bad_steps + 1)
            with fluid.layers.Switch() as switch3:
                with switch3.case(should_decr_loss_scaling):
                    new_loss_scaling = prev_loss_scaling * decr_ratio
                    static_loss_scaling = \
                        fluid.layers.fill_constant(
                            shape=[1], dtype='float32', value=1.0)
                    less_than_one = fluid.layers.less_than(
                        new_loss_scaling, static_loss_scaling)
                    with fluid.layers.Switch() as switch4:
                        with switch4.case(less_than_one):
                            fluid.layers.assign(static_loss_scaling,
                                                prev_loss_scaling)
                        with switch4.default():
                            fluid.layers.assign(new_loss_scaling,
                                                prev_loss_scaling)
                    fluid.layers.assign(zero_steps, num_good_steps)
                    fluid.layers.assign(zero_steps, num_bad_steps)
                with switch3.default():
                    fluid.layers.assign(zero_steps, num_good_steps)
                    fluid.layers.increment(num_bad_steps)
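The nested Switch blocks in update_loss_scaling encode a simple counter schedule. A plain-Python simulation of the same logic (simplified: it skips the finiteness check on the increased scale), useful for seeing how a hypothetical run of good and bad steps moves the scale:

def simulate(step_is_finite, incr_every_n_steps=2, decr_every_n_nan_or_inf=1,
             incr_ratio=2.0, decr_ratio=0.5, loss_scaling=32.0):
    """step_is_finite: booleans, True = all gradients finite that step."""
    num_good, num_bad = 0, 0
    for finite in step_is_finite:
        if finite:
            if num_good >= incr_every_n_steps:  # enough good steps: grow
                loss_scaling *= incr_ratio
                num_good, num_bad = 0, 0
            else:
                num_good, num_bad = num_good + 1, 0
        else:
            if num_bad >= decr_every_n_nan_or_inf:  # enough bad steps: shrink
                loss_scaling = max(loss_scaling * decr_ratio, 1.0)  # floor 1.0
                num_good, num_bad = 0, 0
            else:
                num_good, num_bad = 0, num_bad + 1
        print(finite, loss_scaling)

simulate([True, True, True, False, False, True])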
ogb_examples/linkproppred/ogbl-ppa/utils/init.py (new file, mode 100644)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""paddle init"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import absolute_import

import os
import six
import ast
import copy
import logging

import numpy as np
import paddle.fluid as fluid

log = logging.getLogger(__name__)


def cast_fp32_to_fp16(exe, main_program):
    """doc"""
    log.info("Cast parameters to float16 data format.")
    for param in main_program.global_block().all_parameters():
        if not param.name.endswith(".master"):
            param_t = fluid.global_scope().find_var(
                param.name).get_tensor()
            data = np.array(param_t)
            if param.name.startswith("encoder_layer") \
                    and "layer_norm" not in param.name:
                param_t.set(np.float16(data).view(np.uint16), exe.place)

            #load fp32
            master_param_var = fluid.global_scope().find_var(param.name +
                                                             ".master")
            if master_param_var is not None:
                master_param_var.get_tensor().set(data, exe.place)


def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False):
    """init"""
    assert os.path.exists(
        init_checkpoint_path), "[%s] can't be found." % init_checkpoint_path

    def existed_persitables(var):
        """existed"""
        if not fluid.io.is_persistable(var):
            return False
        return os.path.exists(os.path.join(init_checkpoint_path, var.name))

    fluid.io.load_vars(
        exe,
        init_checkpoint_path,
        main_program=main_program,
        predicate=existed_persitables)
    log.info("Load model from {}".format(init_checkpoint_path))

    if use_fp16:
        cast_fp32_to_fp16(exe, main_program)


def init_pretraining_params(exe,
                            pretraining_params_path,
                            main_program,
                            use_fp16=False):
    """init"""
    assert os.path.exists(pretraining_params_path
                          ), "[%s] can't be found." % pretraining_params_path

    def existed_params(var):
        """doc"""
        if not isinstance(var, fluid.framework.Parameter):
            return False
        return os.path.exists(
            os.path.join(pretraining_params_path, var.name))

    fluid.io.load_vars(
        exe,
        pretraining_params_path,
        main_program=main_program,
        predicate=existed_params)
    log.info("Load pretraining parameters from {}.".format(
        pretraining_params_path))

    if use_fp16:
        cast_fp32_to_fp16(exe, main_program)
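A minimal sketch of how train.py wires these helpers in: after running the startup program, parameters are warm-started from any matching files on disk. The checkpoint directory here is hypothetical; the helper asserts that it exists and then loads every Parameter with a matching file name.

import paddle.fluid as fluid
from utils.init import init_pretraining_params

place = fluid.CPUPlace()
exe = fluid.Executor(place)
startup_prog = fluid.Program()
exe.run(startup_prog)

# "./pretrained_params" is a hypothetical directory of saved persistables.
init_pretraining_params(exe, "./pretrained_params",
                        main_program=startup_prog)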