Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
ecbe2e32
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ecbe2e32
编写于
9月 10, 2020
作者:
S
sys1874
提交者:
GitHub
9月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update main_protein.py
上级
3c8705e7
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
13 addition
and
12 deletion
+13
-12
ogb_examples/nodeproppred/unimp/main_protein.py
ogb_examples/nodeproppred/unimp/main_protein.py
+13
-12
未找到文件。
ogb_examples/nodeproppred/unimp/main_protein.py
浏览文件 @
ecbe2e32
...
@@ -23,7 +23,7 @@ evaluator = Evaluator(name='ogbn-proteins')
...
@@ -23,7 +23,7 @@ evaluator = Evaluator(name='ogbn-proteins')
def
get_config
():
def
get_config
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
##
基本模型参数
##
model_arg
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
7
,
type
=
int
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
7
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
64
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
64
,
type
=
int
)
...
@@ -31,7 +31,7 @@ def get_config():
...
@@ -31,7 +31,7 @@ def get_config():
model_group
.
add_argument
(
'--dropout'
,
default
=
0.1
,
type
=
float
)
model_group
.
add_argument
(
'--dropout'
,
default
=
0.1
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0
,
type
=
float
)
## label
embedding模型参数
## label
_embed_arg
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.5
,
type
=
float
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.5
,
type
=
float
)
...
@@ -90,15 +90,16 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
...
@@ -90,15 +90,16 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
def
train_loop
(
parser
,
start_program
,
main_program
,
test_program
,
def
train_loop
(
parser
,
start_program
,
main_program
,
test_program
,
model
,
graph
,
label
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
model
,
graph
,
label
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
#
启动上文构建的训练器
#
build up training program
exe
.
run
(
start_program
)
exe
.
run
(
start_program
)
max_acc
=
0
# 最佳test_acc
max_acc
=
0
# best test_acc
max_step
=
0
# 最佳test_acc 对应step
max_step
=
0
# step for best test_acc
max_val_acc
=
0
# 最佳val_acc
max_val_acc
=
0
# best val_acc
max_cor_acc
=
0
# 最佳val_acc对应test_acc
max_cor_acc
=
0
# test_acc for best val_acc
max_cor_step
=
0
# 最佳val_acc对应step
max_cor_step
=
0
# step for best val_acc
#训练循环
#training loop
graph
.
node_feat
[
"label"
]
=
label
graph
.
node_feat
[
"label"
]
=
label
graph
.
node_feat
[
"nid"
]
=
np
.
arange
(
0
,
graph
.
num_nodes
)
graph
.
node_feat
[
"nid"
]
=
np
.
arange
(
0
,
graph
.
num_nodes
)
...
@@ -112,7 +113,7 @@ def train_loop(parser, start_program, main_program, test_program,
...
@@ -112,7 +113,7 @@ def train_loop(parser, start_program, main_program, test_program,
for
epoch_id
in
tqdm
(
range
(
parser
.
epochs
)):
for
epoch_id
in
tqdm
(
range
(
parser
.
epochs
)):
for
subgraph
in
random_partition
(
num_clusters
=
9
,
graph
=
graph
,
shuffle
=
True
):
for
subgraph
in
random_partition
(
num_clusters
=
9
,
graph
=
graph
,
shuffle
=
True
):
#
运行训练器
#
start training
if
parser
.
use_label_e
:
if
parser
.
use_label_e
:
feed_dict
=
model
.
gw
.
to_feed
(
subgraph
)
feed_dict
=
model
.
gw
.
to_feed
(
subgraph
)
sub_idx
=
set
(
subgraph
.
node_feat
[
"nid"
])
sub_idx
=
set
(
subgraph
.
node_feat
[
"nid"
])
...
@@ -139,7 +140,7 @@ def train_loop(parser, start_program, main_program, test_program,
...
@@ -139,7 +140,7 @@ def train_loop(parser, start_program, main_program, test_program,
fetch_list
=
[
model
.
avg_cost
])
fetch_list
=
[
model
.
avg_cost
])
loss
=
loss
[
0
]
loss
=
loss
[
0
]
#
测试结果
#
eval result
if
(
epoch_id
+
1
)
>
parser
.
epochs
*
0.9
:
if
(
epoch_id
+
1
)
>
parser
.
epochs
*
0.9
:
result
=
eval_test
(
parser
,
test_program
,
model
,
exe
,
graph
,
label
,
split_idx
)
result
=
eval_test
(
parser
,
test_program
,
model
,
exe
,
graph
,
label
,
split_idx
)
train_acc
,
valid_acc
,
test_acc
=
result
train_acc
,
valid_acc
,
test_acc
=
result
...
@@ -221,7 +222,7 @@ if __name__ == '__main__':
...
@@ -221,7 +222,7 @@ if __name__ == '__main__':
model
.
train_program
()
model
.
train_program
()
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#
训练优化函数
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#
optimizer
adam_optimizer
.
minimize
(
model
.
avg_cost
)
adam_optimizer
.
minimize
(
model
.
avg_cost
)
exe
=
F
.
Executor
(
place
)
exe
=
F
.
Executor
(
place
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录