Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
62cb9073
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
62cb9073
编写于
9月 10, 2020
作者:
S
sys1874
提交者:
GitHub
9月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update main_arxiv.py
上级
1e962841
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
13 addition
and
17 deletion
+13
-17
ogb_examples/nodeproppred/unimp/main_arxiv.py
ogb_examples/nodeproppred/unimp/main_arxiv.py
+13
-17
未找到文件。
ogb_examples/nodeproppred/unimp/main_arxiv.py
浏览文件 @
62cb9073
...
@@ -20,7 +20,7 @@ evaluator = Evaluator(name='ogbn-arxiv')
...
@@ -20,7 +20,7 @@ evaluator = Evaluator(name='ogbn-arxiv')
def
get_config
():
def
get_config
():
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
##
基本模型参数
##
model_arg
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
3
,
type
=
int
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
3
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
128
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
128
,
type
=
int
)
...
@@ -28,7 +28,7 @@ def get_config():
...
@@ -28,7 +28,7 @@ def get_config():
model_group
.
add_argument
(
'--dropout'
,
default
=
0.3
,
type
=
float
)
model_group
.
add_argument
(
'--dropout'
,
default
=
0.3
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0
,
type
=
float
)
## label
embedding模型参数
## label
_embed_arg
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.625
,
type
=
float
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.625
,
type
=
float
)
...
@@ -81,17 +81,17 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
...
@@ -81,17 +81,17 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
def
train_loop
(
parser
,
start_program
,
main_program
,
test_program
,
def
train_loop
(
parser
,
start_program
,
main_program
,
test_program
,
model
,
graph
,
label
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
model
,
graph
,
label
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
#
启动上文构建的训练器
#
build up training program
exe
.
run
(
start_program
)
exe
.
run
(
start_program
)
max_acc
=
0
#
最佳
test_acc
max_acc
=
0
#
best
test_acc
max_step
=
0
#
最佳test_acc 对应step
max_step
=
0
#
step for best test_acc
max_val_acc
=
0
#
最佳
val_acc
max_val_acc
=
0
#
best
val_acc
max_cor_acc
=
0
#
最佳val_acc对应test
_acc
max_cor_acc
=
0
#
test_acc for best val
_acc
max_cor_step
=
0
#
最佳val_acc对应step
max_cor_step
=
0
#
step for best val_acc
#
训练循环
#
training loop
for
epoch_id
in
tqdm
(
range
(
parser
.
epochs
)):
for
epoch_id
in
tqdm
(
range
(
parser
.
epochs
)):
#
运行训练器
#
start training
if
parser
.
use_label_e
:
if
parser
.
use_label_e
:
feed_dict
=
model
.
gw
.
to_feed
(
graph
)
feed_dict
=
model
.
gw
.
to_feed
(
graph
)
...
@@ -115,7 +115,7 @@ def train_loop(parser, start_program, main_program, test_program,
...
@@ -115,7 +115,7 @@ def train_loop(parser, start_program, main_program, test_program,
# print(loss[1][0])
# print(loss[1][0])
loss
=
loss
[
0
]
loss
=
loss
[
0
]
#
测试结果
#
eval result
result
=
eval_test
(
parser
,
test_program
,
model
,
exe
,
graph
,
label
,
split_idx
)
result
=
eval_test
(
parser
,
test_program
,
model
,
exe
,
graph
,
label
,
split_idx
)
train_acc
,
valid_acc
,
test_acc
=
result
train_acc
,
valid_acc
,
test_acc
=
result
...
@@ -191,11 +191,7 @@ if __name__ == '__main__':
...
@@ -191,11 +191,7 @@ if __name__ == '__main__':
test_prog
=
train_prog
.
clone
(
for_test
=
True
)
test_prog
=
train_prog
.
clone
(
for_test
=
True
)
model
.
train_program
()
model
.
train_program
()
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#optimizer
# ave_loss = train_program(pred_output)#训练程序
# lr, global_step= linear_warmup_decay(parser.lr, parser.epochs*0.1, parser.epochs)
# adam_optimizer = optimizer_func(lr)#训练优化函数
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#训练优化函数
adam_optimizer
.
minimize
(
model
.
avg_cost
)
adam_optimizer
.
minimize
(
model
.
avg_cost
)
exe
=
F
.
Executor
(
place
)
exe
=
F
.
Executor
(
place
)
...
@@ -206,4 +202,4 @@ if __name__ == '__main__':
...
@@ -206,4 +202,4 @@ if __name__ == '__main__':
total_test_acc
+=
train_loop
(
parser
,
startup_prog
,
train_prog
,
test_prog
,
model
,
total_test_acc
+=
train_loop
(
parser
,
startup_prog
,
train_prog
,
test_prog
,
model
,
graph
,
label
,
split_idx
,
exe
,
run_i
,
wf
)
graph
,
label
,
split_idx
,
exe
,
run_i
,
wf
)
wf
.
write
(
f
'average:
{
100
*
(
total_test_acc
/
parser
.
runs
):.
2
f
}
%'
)
wf
.
write
(
f
'average:
{
100
*
(
total_test_acc
/
parser
.
runs
):.
2
f
}
%'
)
wf
.
close
()
wf
.
close
()
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录