Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
e5ec593b
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e5ec593b
编写于
9月 08, 2020
作者:
S
sys1874
提交者:
GitHub
9月 08, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update main_arxiv.py
上级
3b74c110
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
13 addition
and
22 deletion
+13
-22
ogb_examples/nodeproppred/unimp/main_arxiv.py
ogb_examples/nodeproppred/unimp/main_arxiv.py
+13
-22
未找到文件。
ogb_examples/nodeproppred/unimp/main_arxiv.py
浏览文件 @
e5ec593b
...
...
@@ -20,7 +20,7 @@ evaluator = Evaluator(name='ogbn-arxiv')
def
get_config
():
parser
=
argparse
.
ArgumentParser
()
##
基本模型参数
##
model_base_arg
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
3
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
128
,
type
=
int
)
...
...
@@ -28,7 +28,7 @@ def get_config():
model_group
.
add_argument
(
'--dropout'
,
default
=
0.3
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0
,
type
=
float
)
##
label embedding模型参数
##
embed_arg
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.625
,
type
=
float
)
...
...
@@ -42,10 +42,6 @@ def get_config():
train_group
.
add_argument
(
'--log_file'
,
default
=
'result_arxiv.txt'
,
type
=
str
)
return
parser
.
parse_args
()
# def optimizer_func(lr=0.01):
# return F.optimizer.AdamOptimizer(learning_rate=lr, regularization=F.regularizer.L2Decay(
# regularization_coeff=0.001))
def
optimizer_func
(
lr
=
0.01
):
return
F
.
optimizer
.
AdamOptimizer
(
learning_rate
=
lr
,
regularization
=
F
.
regularizer
.
L2Decay
(
regularization_coeff
=
0.0005
))
...
...
@@ -81,17 +77,16 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
def
train_loop
(
parser
,
start_program
,
main_program
,
test_program
,
model
,
graph
,
label
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
#
启动上文构建的训练器
#
start_program
exe
.
run
(
start_program
)
max_acc
=
0
#
最佳
test_acc
max_step
=
0
#
最佳test_acc 对应step
max_val_acc
=
0
#
最佳
val_acc
max_cor_acc
=
0
#
最佳val_acc对应test
_acc
max_cor_step
=
0
#
最佳val_acc对应step
#
训练循环
max_acc
=
0
#
best
test_acc
max_step
=
0
#
step for best_test_acc
max_val_acc
=
0
#
best
val_acc
max_cor_acc
=
0
#
test_acc for best_val
_acc
max_cor_step
=
0
#
step for test_acc
#
training loop
for
epoch_id
in
tqdm
(
range
(
parser
.
epochs
)):
#运行训练器
for
epoch_id
in
tqdm
(
range
(
parser
.
epochs
)):
if
parser
.
use_label_e
:
feed_dict
=
model
.
gw
.
to_feed
(
graph
)
...
...
@@ -115,7 +110,7 @@ def train_loop(parser, start_program, main_program, test_program,
# print(loss[1][0])
loss
=
loss
[
0
]
#
测试结果
#
test result
result
=
eval_test
(
parser
,
test_program
,
model
,
exe
,
graph
,
label
,
split_idx
)
train_acc
,
valid_acc
,
test_acc
=
result
...
...
@@ -191,11 +186,7 @@ if __name__ == '__main__':
test_prog
=
train_prog
.
clone
(
for_test
=
True
)
model
.
train_program
()
# ave_loss = train_program(pred_output)#训练程序
# lr, global_step= linear_warmup_decay(parser.lr, parser.epochs*0.1, parser.epochs)
# adam_optimizer = optimizer_func(lr)#训练优化函数
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#训练优化函数
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#adam_optimizer
adam_optimizer
.
minimize
(
model
.
avg_cost
)
exe
=
F
.
Executor
(
place
)
...
...
@@ -206,4 +197,4 @@ if __name__ == '__main__':
total_test_acc
+=
train_loop
(
parser
,
startup_prog
,
train_prog
,
test_prog
,
model
,
graph
,
label
,
split_idx
,
exe
,
run_i
,
wf
)
wf
.
write
(
f
'average:
{
100
*
(
total_test_acc
/
parser
.
runs
):.
2
f
}
%'
)
wf
.
close
()
\ No newline at end of file
wf
.
close
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录