Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
98049ba1
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
98049ba1
编写于
5月 27, 2020
作者:
W
wangwenjin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
e7968881
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
73 addition
and
164 deletion
+73
-164
examples/GaAN/conv.py
examples/GaAN/conv.py
+4
-14
examples/GaAN/model.py
examples/GaAN/model.py
+13
-0
examples/GaAN/preprocess.py
examples/GaAN/preprocess.py
+28
-31
examples/GaAN/train.py
examples/GaAN/train.py
+15
-41
examples/GaAN/train_tool.py
examples/GaAN/train_tool.py
+13
-78
未找到文件。
examples/GaAN/conv.py
浏览文件 @
98049ba1
...
@@ -264,8 +264,8 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
...
@@ -264,8 +264,8 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
"""Implementation of GaAN"""
"""Implementation of GaAN"""
def
send_func
(
src_feat
,
dst_feat
,
edge_feat
):
def
send_func
(
src_feat
,
dst_feat
,
edge_feat
):
#
计算每条边上的注意力分数
#
attention score of each edge
# E * (M * D1)
, 每个 dst 点都查询它的全部邻边的 src 点
# E * (M * D1)
feat_query
,
feat_key
=
dst_feat
[
'feat_query'
],
src_feat
[
'feat_key'
]
feat_query
,
feat_key
=
dst_feat
[
'feat_query'
],
src_feat
[
'feat_key'
]
# E * M * D1
# E * M * D1
old
=
feat_query
old
=
feat_query
...
@@ -281,16 +281,11 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
...
@@ -281,16 +281,11 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
'feat_gate'
:
src_feat
[
'feat_gate'
]}
'feat_gate'
:
src_feat
[
'feat_gate'
]}
def
recv_func
(
message
):
def
recv_func
(
message
):
# 每条边的终点的特征
dst_feat
=
message
[
'dst_node_feat'
]
dst_feat
=
message
[
'dst_node_feat'
]
# 每条边的出发点的特征
src_feat
=
message
[
'src_node_feat'
]
src_feat
=
message
[
'src_node_feat'
]
# 每个中心点自己的特征
x
=
fluid
.
layers
.
sequence_pool
(
dst_feat
,
'average'
)
x
=
fluid
.
layers
.
sequence_pool
(
dst_feat
,
'average'
)
# 每个中心点的邻居的特征的平均值
z
=
fluid
.
layers
.
sequence_pool
(
src_feat
,
'average'
)
z
=
fluid
.
layers
.
sequence_pool
(
src_feat
,
'average'
)
# 计算 gate
feat_gate
=
message
[
'feat_gate'
]
feat_gate
=
message
[
'feat_gate'
]
g_max
=
fluid
.
layers
.
sequence_pool
(
feat_gate
,
'max'
)
g_max
=
fluid
.
layers
.
sequence_pool
(
feat_gate
,
'max'
)
g
=
fluid
.
layers
.
concat
([
x
,
g_max
,
z
],
axis
=
1
)
g
=
fluid
.
layers
.
concat
([
x
,
g_max
,
z
],
axis
=
1
)
...
@@ -318,10 +313,6 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
...
@@ -318,10 +313,6 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
return
output
return
output
# feature N * D
# 计算每个点自己需要发送出去的内容
# 投影后的特征向量
# N * (D1 * M)
# N * (D1 * M)
feat_key
=
fluid
.
layers
.
fc
(
feature
,
hidden_size_a
*
heads
,
bias_attr
=
False
,
feat_key
=
fluid
.
layers
.
fc
(
feature
,
hidden_size_a
*
heads
,
bias_attr
=
False
,
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_key'
))
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_key'
))
...
@@ -335,8 +326,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
...
@@ -335,8 +326,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
feat_gate
=
fluid
.
layers
.
fc
(
feature
,
hidden_size_m
,
bias_attr
=
False
,
feat_gate
=
fluid
.
layers
.
fc
(
feature
,
hidden_size_m
,
bias_attr
=
False
,
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_gate'
))
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_gate'
))
# send 阶段
# send stage
message
=
gw
.
send
(
message
=
gw
.
send
(
send_func
,
send_func
,
nfeat_list
=
[(
'node_feat'
,
feature
),
(
'feat_key'
,
feat_key
),
(
'feat_value'
,
feat_value
),
nfeat_list
=
[(
'node_feat'
,
feature
),
(
'feat_key'
,
feat_key
),
(
'feat_value'
,
feat_value
),
...
@@ -344,7 +334,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
...
@@ -344,7 +334,7 @@ def gaan(gw, feature, hidden_size_a, hidden_size_v, hidden_size_m, hidden_size_o
efeat_list
=
None
,
efeat_list
=
None
,
)
)
#
聚合邻居特征
#
recv stage
output
=
gw
.
recv
(
message
,
recv_func
)
output
=
gw
.
recv
(
message
,
recv_func
)
output
=
fluid
.
layers
.
fc
(
output
,
hidden_size_o
,
bias_attr
=
False
,
output
=
fluid
.
layers
.
fc
(
output
,
hidden_size_o
,
bias_attr
=
False
,
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_output'
))
param_attr
=
fluid
.
ParamAttr
(
name
=
name
+
'_project_output'
))
...
...
examples/GaAN/model.py
浏览文件 @
98049ba1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddle
import
fluid
from
paddle
import
fluid
from
pgl.utils
import
paddle_helper
from
pgl.utils
import
paddle_helper
...
...
examples/GaAN/preprocess.py
浏览文件 @
98049ba1
"""
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
将 ogb_proteins 的数据处理为 PGL 的 graph 数据,并返回 graph, label, train/valid/test 等信息
#
"""
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
ssl
import
ssl
ssl
.
_create_default_https_context
=
ssl
.
_create_unverified_context
ssl
.
_create_default_https_context
=
ssl
.
_create_unverified_context
from
ogb.nodeproppred
import
NodePropPredDataset
,
Evaluator
from
ogb.nodeproppred
import
NodePropPredDataset
,
Evaluator
...
@@ -17,7 +27,7 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
...
@@ -17,7 +27,7 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
d_name: name of dataset
d_name: name of dataset
mini_data: if mini_data==True, only use a small dataset (for test)
mini_data: if mini_data==True, only use a small dataset (for test)
"""
"""
#
导入 ogb 数据
#
import ogb data
dataset
=
NodePropPredDataset
(
name
=
d_name
)
dataset
=
NodePropPredDataset
(
name
=
d_name
)
num_tasks
=
dataset
.
num_tasks
# obtaining the number of prediction tasks in a dataset
num_tasks
=
dataset
.
num_tasks
# obtaining the number of prediction tasks in a dataset
...
@@ -25,10 +35,10 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
...
@@ -25,10 +35,10 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
train_idx
,
valid_idx
,
test_idx
=
split_idx
[
"train"
],
split_idx
[
"valid"
],
split_idx
[
"test"
]
train_idx
,
valid_idx
,
test_idx
=
split_idx
[
"train"
],
split_idx
[
"valid"
],
split_idx
[
"test"
]
graph
,
label
=
dataset
[
0
]
graph
,
label
=
dataset
[
0
]
#
调整维度,符合 PGL 的 Graph 要求
#
reshape
graph
[
"edge_index"
]
=
graph
[
"edge_index"
].
T
graph
[
"edge_index"
]
=
graph
[
"edge_index"
].
T
#
使用小规模数据,500个节点
#
mini dataset
if
mini_data
:
if
mini_data
:
graph
[
'num_nodes'
]
=
500
graph
[
'num_nodes'
]
=
500
mask
=
(
graph
[
'edge_index'
][:,
0
]
<
500
)
*
(
graph
[
'edge_index'
][:,
1
]
<
500
)
mask
=
(
graph
[
'edge_index'
][:,
0
]
<
500
)
*
(
graph
[
'edge_index'
][:,
1
]
<
500
)
...
@@ -39,19 +49,9 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
...
@@ -39,19 +49,9 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
valid_idx
=
np
.
arange
(
400
,
450
)
valid_idx
=
np
.
arange
(
400
,
450
)
test_idx
=
np
.
arange
(
450
,
500
)
test_idx
=
np
.
arange
(
450
,
500
)
# 输出 dataset 的信息
print
(
graph
.
keys
())
print
(
"节点个数 "
,
graph
[
"num_nodes"
])
print
(
"节点最小编号"
,
graph
[
'edge_index'
][
0
].
min
())
print
(
"边个数 "
,
graph
[
"edge_index"
].
shape
[
1
])
print
(
"边索引 shape "
,
graph
[
"edge_index"
].
shape
)
print
(
"边特征 shape "
,
graph
[
"edge_feat"
].
shape
)
print
(
"节点特征是 "
,
graph
[
"node_feat"
])
print
(
"species shape"
,
graph
[
'species'
].
shape
)
print
(
"label shape "
,
label
.
shape
)
# 读取/计算 node feature
# read/compute node feature
# 确定读取文件的路径
if
mini_data
:
if
mini_data
:
node_feat_path
=
'./dataset/ogbn_proteins_node_feat_small.npy'
node_feat_path
=
'./dataset/ogbn_proteins_node_feat_small.npy'
else
:
else
:
...
@@ -59,14 +59,11 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
...
@@ -59,14 +59,11 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
new_node_feat
=
None
new_node_feat
=
None
if
os
.
path
.
exists
(
node_feat_path
):
if
os
.
path
.
exists
(
node_feat_path
):
# 如果文件存在,直接读取
print
(
"Begin: read node feature"
.
center
(
50
,
'='
))
print
(
"读取 node feature 开始"
.
center
(
50
,
'='
))
new_node_feat
=
np
.
load
(
node_feat_path
)
new_node_feat
=
np
.
load
(
node_feat_path
)
print
(
"
读取 node feature 成功
"
.
center
(
50
,
'='
))
print
(
"
End: read node feature
"
.
center
(
50
,
'='
))
else
:
else
:
# 如果文件不存在,则计算
print
(
"Begin: compute node feature"
.
center
(
50
,
'='
))
# 每个节点 i 的特征为其邻边特征的均值
print
(
"计算 node feature 开始"
.
center
(
50
,
'='
))
start
=
time
.
perf_counter
()
start
=
time
.
perf_counter
()
for
i
in
range
(
graph
[
'num_nodes'
]):
for
i
in
range
(
graph
[
'num_nodes'
]):
if
i
%
100
==
0
:
if
i
%
100
==
0
:
...
@@ -74,8 +71,8 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
...
@@ -74,8 +71,8 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
print
(
"{}/{}({}%), times: {:.2f}s"
.
format
(
print
(
"{}/{}({}%), times: {:.2f}s"
.
format
(
i
,
graph
[
'num_nodes'
],
i
/
graph
[
'num_nodes'
]
*
100
,
dur
i
,
graph
[
'num_nodes'
],
i
/
graph
[
'num_nodes'
]
*
100
,
dur
))
))
mask
=
(
graph
[
'edge_index'
][:,
0
]
==
i
)
# 选择 i 的所有邻边
mask
=
(
graph
[
'edge_index'
][:,
0
]
==
i
)
# 计算均值
current_node_feat
=
np
.
mean
(
np
.
compress
(
mask
,
graph
[
'edge_feat'
],
axis
=
0
),
current_node_feat
=
np
.
mean
(
np
.
compress
(
mask
,
graph
[
'edge_feat'
],
axis
=
0
),
axis
=
0
,
keepdims
=
True
)
axis
=
0
,
keepdims
=
True
)
if
i
==
0
:
if
i
==
0
:
...
@@ -84,23 +81,23 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
...
@@ -84,23 +81,23 @@ def get_graph_data(d_name="ogbn-proteins", mini_data=False):
new_node_feat
.
append
(
current_node_feat
)
new_node_feat
.
append
(
current_node_feat
)
new_node_feat
=
np
.
concatenate
(
new_node_feat
,
axis
=
0
)
new_node_feat
=
np
.
concatenate
(
new_node_feat
,
axis
=
0
)
print
(
"
计算 node feature 结束
"
.
center
(
50
,
'='
))
print
(
"
End: compute node feature
"
.
center
(
50
,
'='
))
print
(
"
存储 node feature 中,在
"
+
node_feat_path
.
center
(
50
,
'='
))
print
(
"
Saving node feature in
"
+
node_feat_path
.
center
(
50
,
'='
))
np
.
save
(
node_feat_path
,
new_node_feat
)
np
.
save
(
node_feat_path
,
new_node_feat
)
print
(
"
存储 node feature 结束
"
.
center
(
50
,
'='
))
print
(
"
Saving finish
"
.
center
(
50
,
'='
))
print
(
new_node_feat
)
print
(
new_node_feat
)
#
构造 Graph 对象
#
create graph
g
=
pgl
.
graph
.
Graph
(
g
=
pgl
.
graph
.
Graph
(
num_nodes
=
graph
[
"num_nodes"
],
num_nodes
=
graph
[
"num_nodes"
],
edges
=
graph
[
"edge_index"
],
edges
=
graph
[
"edge_index"
],
node_feat
=
{
'node_feat'
:
new_node_feat
},
node_feat
=
{
'node_feat'
:
new_node_feat
},
edge_feat
=
None
edge_feat
=
None
)
)
print
(
"
创建 Graph 对象成功
"
)
print
(
"
Create graph
"
)
print
(
g
)
print
(
g
)
return
g
,
label
,
train_idx
,
valid_idx
,
test_idx
,
Evaluator
(
d_name
)
return
g
,
label
,
train_idx
,
valid_idx
,
test_idx
,
Evaluator
(
d_name
)
\ No newline at end of file
examples/GaAN/train.py
浏览文件 @
98049ba1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
preprocess
import
get_graph_data
from
preprocess
import
get_graph_data
import
pgl
import
pgl
import
argparse
import
argparse
import
numpy
as
np
import
numpy
as
np
import
time
import
time
from
paddle
import
fluid
from
paddle
import
fluid
from
visualdl
import
LogWriter
import
reader
import
reader
from
train_tool
import
train_epoch
,
valid_epoch
from
train_tool
import
train_epoch
,
valid_epoch
...
@@ -49,10 +61,8 @@ if __name__ == "__main__":
...
@@ -49,10 +61,8 @@ if __name__ == "__main__":
help
=
"the hidden size of each layer in GaAN"
)
help
=
"the hidden size of each layer in GaAN"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
# d_name = "ogbn-proteins"
print
(
"
超参数配置
"
.
center
(
50
,
"="
))
print
(
"
Parameters Setting
"
.
center
(
50
,
"="
))
print
(
"lr = {}, rc = {}, epochs = {}, batch_size = {}"
.
format
(
args
.
lr
,
args
.
rc
,
args
.
epochs
,
print
(
"lr = {}, rc = {}, epochs = {}, batch_size = {}"
.
format
(
args
.
lr
,
args
.
rc
,
args
.
epochs
,
args
.
batch_size
))
args
.
batch_size
))
print
(
"Experiment ID: {}"
.
format
(
args
.
exp_id
).
center
(
50
,
"="
))
print
(
"Experiment ID: {}"
.
format
(
args
.
exp_id
).
center
(
50
,
"="
))
...
@@ -63,20 +73,6 @@ if __name__ == "__main__":
...
@@ -63,20 +73,6 @@ if __name__ == "__main__":
g
,
label
,
train_idx
,
valid_idx
,
test_idx
,
evaluator
=
get_graph_data
(
d_name
=
d_name
,
g
,
label
,
train_idx
,
valid_idx
,
test_idx
,
evaluator
=
get_graph_data
(
d_name
=
d_name
,
mini_data
=
eval
(
args
.
mini_data
))
mini_data
=
eval
(
args
.
mini_data
))
# create log writer
log_writer
=
LogWriter
(
args
.
log_path
+
'/'
+
str
(
args
.
exp_id
),
sync_cycle
=
10
)
with
log_writer
.
mode
(
"train"
)
as
logger
:
log_train_loss_epoch
=
logger
.
scalar
(
"loss"
)
log_train_rocauc_epoch
=
logger
.
scalar
(
"rocauc"
)
with
log_writer
.
mode
(
"valid"
)
as
logger
:
log_valid_loss_epoch
=
logger
.
scalar
(
"loss"
)
log_valid_rocauc_epoch
=
logger
.
scalar
(
"rocauc"
)
log_text
=
log_writer
.
text
(
"text"
)
log_time
=
log_writer
.
scalar
(
"time"
)
log_test_loss
=
log_writer
.
scalar
(
"test_loss"
)
log_test_rocauc
=
log_writer
.
scalar
(
"test_rocauc"
)
if
args
.
model
==
"GaAN"
:
if
args
.
model
==
"GaAN"
:
graph_model
=
GaANModel
(
112
,
3
,
args
.
hidden_size_a
,
args
.
hidden_size_v
,
args
.
hidden_size_m
,
graph_model
=
GaANModel
(
112
,
3
,
args
.
hidden_size_a
,
args
.
hidden_size_v
,
args
.
hidden_size_m
,
args
.
hidden_size_o
,
args
.
heads
)
args
.
hidden_size_o
,
args
.
heads
)
...
@@ -162,17 +158,13 @@ if __name__ == "__main__":
...
@@ -162,17 +158,13 @@ if __name__ == "__main__":
start
=
time
.
time
()
start
=
time
.
time
()
print
(
"Training Begin"
.
center
(
50
,
"="
))
print
(
"Training Begin"
.
center
(
50
,
"="
))
log_text
.
add_record
(
0
,
"Training Begin"
.
center
(
50
,
"="
))
best_valid
=
-
1.0
best_valid
=
-
1.0
for
epoch
in
range
(
args
.
epochs
):
for
epoch
in
range
(
args
.
epochs
):
start_e
=
time
.
time
()
start_e
=
time
.
time
()
# print("Train Epoch {}".format(epoch).center(50, "="))
train_loss
,
train_rocauc
=
train_epoch
(
train_loss
,
train_rocauc
=
train_epoch
(
train_iter
,
program
=
train_program
,
exe
=
exe
,
loss
=
loss
,
score
=
score
,
train_iter
,
program
=
train_program
,
exe
=
exe
,
loss
=
loss
,
score
=
score
,
evaluator
=
evaluator
,
epoch
=
epoch
evaluator
=
evaluator
,
epoch
=
epoch
)
)
print
(
"Valid Epoch {}"
.
format
(
epoch
).
center
(
50
,
"="
))
valid_loss
,
valid_rocauc
=
valid_epoch
(
valid_loss
,
valid_rocauc
=
valid_epoch
(
val_iter
,
program
=
val_program
,
exe
=
exe
,
loss
=
loss
,
score
=
score
,
val_iter
,
program
=
val_program
,
exe
=
exe
,
loss
=
loss
,
score
=
score
,
evaluator
=
evaluator
,
epoch
=
epoch
)
evaluator
=
evaluator
,
epoch
=
epoch
)
...
@@ -180,16 +172,7 @@ if __name__ == "__main__":
...
@@ -180,16 +172,7 @@ if __name__ == "__main__":
print
(
"Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}"
.
format
(
print
(
"Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}"
.
format
(
epoch
,
train_loss
,
valid_loss
,
train_rocauc
,
valid_rocauc
,
end_e
-
start_e
epoch
,
train_loss
,
valid_loss
,
train_rocauc
,
valid_rocauc
,
end_e
-
start_e
))
))
log_text
.
add_record
(
epoch
+
1
,
"Epoch {}: train_loss={:.4},val_loss={:.4}, train_rocauc={:.4}, val_rocauc={:.4}, s/epoch={:.3}"
.
format
(
epoch
,
train_loss
,
valid_loss
,
train_rocauc
,
valid_rocauc
,
end_e
-
start_e
))
log_train_loss_epoch
.
add_record
(
epoch
,
train_loss
)
log_valid_loss_epoch
.
add_record
(
epoch
,
valid_loss
)
log_train_rocauc_epoch
.
add_record
(
epoch
,
train_rocauc
)
log_valid_rocauc_epoch
.
add_record
(
epoch
,
valid_rocauc
)
log_time
.
add_record
(
epoch
,
end_e
-
start_e
)
if
valid_rocauc
>
best_valid
:
if
valid_rocauc
>
best_valid
:
print
(
"Update: new {}, old {}"
.
format
(
valid_rocauc
,
best_valid
))
print
(
"Update: new {}, old {}"
.
format
(
valid_rocauc
,
best_valid
))
best_valid
=
valid_rocauc
best_valid
=
valid_rocauc
...
@@ -198,23 +181,14 @@ if __name__ == "__main__":
...
@@ -198,23 +181,14 @@ if __name__ == "__main__":
print
(
"Test Stage"
.
center
(
50
,
"="
))
print
(
"Test Stage"
.
center
(
50
,
"="
))
log_text
.
add_record
(
args
.
epochs
+
1
,
"Test Stage"
.
center
(
50
,
"="
))
fluid
.
io
.
load_params
(
executor
=
exe
,
dirname
=
'./params/'
+
str
(
args
.
exp_id
),
main_program
=
val_program
)
fluid
.
io
.
load_params
(
executor
=
exe
,
dirname
=
'./params/'
+
str
(
args
.
exp_id
),
main_program
=
val_program
)
test_loss
,
test_rocauc
=
valid_epoch
(
test_loss
,
test_rocauc
=
valid_epoch
(
test_iter
,
program
=
val_program
,
exe
=
exe
,
loss
=
loss
,
score
=
score
,
test_iter
,
program
=
val_program
,
exe
=
exe
,
loss
=
loss
,
score
=
score
,
evaluator
=
evaluator
,
epoch
=
epoch
)
evaluator
=
evaluator
,
epoch
=
epoch
)
log_test_loss
.
add_record
(
0
,
test_loss
)
log_test_rocauc
.
add_record
(
0
,
test_rocauc
)
end
=
time
.
time
()
end
=
time
.
time
()
print
(
"test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}"
.
format
(
print
(
"test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}"
.
format
(
test_loss
,
test_rocauc
,
end
-
start
test_loss
,
test_rocauc
,
end
-
start
))
))
print
(
"End"
.
center
(
50
,
"="
))
print
(
"End"
.
center
(
50
,
"="
))
log_text
.
add_record
(
args
.
epochs
+
2
,
"test_loss={:.4},test_rocauc={:.4}, Total Time={:.3}"
.
format
(
test_loss
,
test_rocauc
,
end
-
start
))
log_text
.
add_record
(
args
.
epochs
+
3
,
"End"
.
center
(
50
,
"="
))
\ No newline at end of file
examples/GaAN/train_tool.py
浏览文件 @
98049ba1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
time
import
time
from
pgl.utils.logger
import
log
from
pgl.utils.logger
import
log
...
@@ -15,50 +28,12 @@ def train_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per
...
@@ -15,50 +28,12 @@ def train_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per
total_sample
+=
num_samples
total_sample
+=
num_samples
input_dict
=
{
input_dict
=
{
"y_true"
:
batch_feed_dict
[
"node_label"
],
"y_true"
:
batch_feed_dict
[
"node_label"
],
# "y_pred": y_pred[batch_feed_dict["node_index"]]
"y_pred"
:
y_pred
"y_pred"
:
y_pred
}
}
result
+=
evaluator
.
eval
(
input_dict
)[
"rocauc"
]
result
+=
evaluator
.
eval
(
input_dict
)[
"rocauc"
]
# if batch % log_per_step == 0:
# print("Batch {}: Loss={}".format(batch, batch_loss))
# log.info("Batch %s %s-Loss %s %s-Acc %s" %
# (batch, prefix, batch_loss, prefix, batch_acc))
# print("Epoch {} Train: Loss={}, rocauc={}, Speed(per batch)={}".format(
# epoch, total_loss/total_sample, result/batch, (end-start)/batch))
return
total_loss
.
item
()
/
total_sample
,
result
/
batch
return
total_loss
.
item
()
/
total_sample
,
result
/
batch
def
inference
(
batch_iter
,
exe
,
program
,
loss
,
score
,
evaluator
,
epoch
,
log_per_step
=
1
):
batch
=
0
total_sample
=
0
total_loss
=
0
result
=
0
start
=
time
.
time
()
for
batch_feed_dict
in
batch_iter
():
batch
+=
1
y_pred
=
exe
.
run
(
program
,
fetch_list
=
[
score
],
feed
=
batch_feed_dict
)[
0
]
input_dict
=
{
"y_true"
:
batch_feed_dict
[
"node_label"
],
"y_pred"
:
y_pred
[
batch_feed_dict
[
"node_index"
]]
}
result
+=
evaluator
.
eval
(
input_dict
)[
"rocauc"
]
if
batch
%
log_per_step
==
0
:
print
(
batch
,
result
/
batch
)
num_samples
=
len
(
batch_feed_dict
[
"node_index"
])
# total_loss += batch_loss * num_samples
# total_acc += batch_acc * num_samples
total_sample
+=
num_samples
end
=
time
.
time
()
print
(
"Epoch {} Valid: Loss={}, Speed(per batch)={}"
.
format
(
epoch
,
total_loss
/
total_sample
,
(
end
-
start
)
/
batch
))
return
total_loss
/
total_sample
,
result
/
batch
def
valid_epoch
(
batch_iter
,
exe
,
program
,
loss
,
score
,
evaluator
,
epoch
,
log_per_step
=
1
):
def
valid_epoch
(
batch_iter
,
exe
,
program
,
loss
,
score
,
evaluator
,
epoch
,
log_per_step
=
1
):
batch
=
0
batch
=
0
total_sample
=
0
total_sample
=
0
...
@@ -69,53 +44,13 @@ def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per
...
@@ -69,53 +44,13 @@ def valid_epoch(batch_iter, exe, program, loss, score, evaluator, epoch, log_per
batch_loss
,
y_pred
=
exe
.
run
(
program
,
fetch_list
=
[
loss
,
score
],
feed
=
batch_feed_dict
)
batch_loss
,
y_pred
=
exe
.
run
(
program
,
fetch_list
=
[
loss
,
score
],
feed
=
batch_feed_dict
)
input_dict
=
{
input_dict
=
{
"y_true"
:
batch_feed_dict
[
"node_label"
],
"y_true"
:
batch_feed_dict
[
"node_label"
],
# "y_pred": y_pred[batch_feed_dict["node_index"]]
"y_pred"
:
y_pred
"y_pred"
:
y_pred
}
}
# print(evaluator.eval(input_dict))
result
+=
evaluator
.
eval
(
input_dict
)[
"rocauc"
]
result
+=
evaluator
.
eval
(
input_dict
)[
"rocauc"
]
# if batch % log_per_step == 0:
# print(batch, result/batch)
num_samples
=
len
(
batch_feed_dict
[
"node_index"
])
num_samples
=
len
(
batch_feed_dict
[
"node_index"
])
total_loss
+=
batch_loss
*
num_samples
total_loss
+=
batch_loss
*
num_samples
# total_acc += batch_acc * num_samples
total_sample
+=
num_samples
total_sample
+=
num_samples
# print("Epoch {} Valid: Loss={}, Speed(per batch)={}".format(epoch, total_loss/total_sample, (end-start)/batch))
return
total_loss
.
item
()
/
total_sample
,
result
/
batch
return
total_loss
.
item
()
/
total_sample
,
result
/
batch
def
run_epoch
(
batch_iter
,
exe
,
program
,
prefix
,
model_loss
,
model_acc
,
epoch
,
log_per_step
=
100
):
"""
已废弃
"""
batch
=
0
total_loss
=
0.
total_acc
=
0.
total_sample
=
0
start
=
time
.
time
()
for
batch_feed_dict
in
batch_iter
():
batch
+=
1
batch_loss
,
batch_acc
=
exe
.
run
(
program
,
fetch_list
=
[
model_loss
,
model_acc
],
feed
=
batch_feed_dict
)
if
batch
%
log_per_step
==
0
:
log
.
info
(
"Batch %s %s-Loss %s %s-Acc %s"
%
(
batch
,
prefix
,
batch_loss
,
prefix
,
batch_acc
))
num_samples
=
len
(
batch_feed_dict
[
"node_index"
])
total_loss
+=
batch_loss
*
num_samples
total_acc
+=
batch_acc
*
num_samples
total_sample
+=
num_samples
end
=
time
.
time
()
log
.
info
(
"%s Epoch %s Loss %.5lf Acc %.5lf Speed(per batch) %.5lf sec"
%
(
prefix
,
epoch
,
total_loss
/
total_sample
,
total_acc
/
total_sample
,
(
end
-
start
)
/
batch
))
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录