Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PGL
提交
3c8705e7
P
PGL
项目概览
PaddlePaddle
/
PGL
通知
76
Star
4
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
11
列表
看板
标记
里程碑
合并请求
1
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PGL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
11
Issue
11
列表
看板
标记
里程碑
合并请求
1
合并请求
1
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3c8705e7
编写于
9月 10, 2020
作者:
S
sys1874
提交者:
GitHub
9月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update main_product.py
上级
62cb9073
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
11 addition
and
15 deletion
+11
-15
ogb_examples/nodeproppred/unimp/main_product.py
ogb_examples/nodeproppred/unimp/main_product.py
+11
-15
未找到文件。
ogb_examples/nodeproppred/unimp/main_product.py
浏览文件 @
3c8705e7
...
...
@@ -22,14 +22,14 @@ evaluator = Evaluator(name='ogbn-products')
def
get_config
():
parser
=
argparse
.
ArgumentParser
()
##
采样参数
##
data_sampling_arg
data_group
=
parser
.
add_argument_group
(
'data_arg'
)
data_group
.
add_argument
(
'--batch_size'
,
default
=
1500
,
type
=
int
)
data_group
.
add_argument
(
'--num_workers'
,
default
=
12
,
type
=
int
)
data_group
.
add_argument
(
'--sizes'
,
default
=
[
10
,
10
,
10
],
type
=
int
,
nargs
=
'+'
)
data_group
.
add_argument
(
'--buf_size'
,
default
=
1000
,
type
=
int
)
##
基本模型参数
##
model_arg
model_group
=
parser
.
add_argument_group
(
'model_base_arg'
)
model_group
.
add_argument
(
'--num_layers'
,
default
=
3
,
type
=
int
)
model_group
.
add_argument
(
'--hidden_size'
,
default
=
128
,
type
=
int
)
...
...
@@ -37,7 +37,7 @@ def get_config():
model_group
.
add_argument
(
'--dropout'
,
default
=
0.3
,
type
=
float
)
model_group
.
add_argument
(
'--attn_dropout'
,
default
=
0
,
type
=
float
)
## label
embedding模型参数
## label
_embed_arg
embed_group
=
parser
.
add_argument_group
(
'embed_arg'
)
embed_group
.
add_argument
(
'--use_label_e'
,
action
=
'store_true'
)
embed_group
.
add_argument
(
'--label_rate'
,
default
=
0.625
,
type
=
float
)
...
...
@@ -113,7 +113,7 @@ def eval_test(parser, test_p_list, model, test_exe, dataset, split_idx):
def
train_loop
(
parser
,
start_program
,
main_program
,
test_p_list
,
model
,
feat_init
,
place
,
dataset
,
split_idx
,
exe
,
run_id
,
wf
=
None
):
#
启动上文构建的训练器
#
build up training program
exe
.
run
(
start_program
)
feat_init
(
place
)
...
...
@@ -122,10 +122,10 @@ def train_loop(parser, start_program, main_program, test_p_list,
max_val_acc
=
0
# 最佳val_acc
max_cor_acc
=
0
# 最佳val_acc对应test_acc
max_cor_step
=
0
# 最佳val_acc对应step
#
训练循环
#
training loop
for
epoch_id
in
range
(
parser
.
epochs
):
#
运行训练器
#
start training
if
parser
.
use_label_e
:
train_idx_temp
=
copy
.
deepcopy
(
split_idx
[
'train'
])
...
...
@@ -158,8 +158,7 @@ def train_loop(parser, start_program, main_program, test_p_list,
print
(
'acc: '
,
(
acc_num
/
unlabel_idx
.
shape
[
0
])
*
100
)
#测试结果
# total=0.0
#eval result
if
(
epoch_id
+
1
)
>=
50
and
(
epoch_id
+
1
)
%
10
==
0
:
result
=
eval_test
(
parser
,
test_p_list
,
model
,
exe
,
dataset
,
split_idx
)
train_acc
,
valid_acc
,
test_acc
=
result
...
...
@@ -242,17 +241,14 @@ if __name__ == '__main__':
# test_prog=train_prog.clone(for_test=True)
model
.
train_program
()
# ave_loss = train_program(pred_output)#训练程序
# lr, global_step= linear_warmup_decay(0.01, 50, 500)
# adam_optimizer = optimizer_func(lr)#训练优化函数
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#训练优化函数
adam_optimizer
=
optimizer_func
(
parser
.
lr
)
#optimizer
adam_optimizer
.
minimize
(
model
.
avg_cost
)
test_p_list
=
[]
with
F
.
unique_name
.
guard
():
##
input层
##
build up eval program
test_p
=
F
.
Program
()
with
F
.
program_guard
(
test_p
,
):
gw_test
=
pgl
.
graph_wrapper
.
GraphWrapper
(
...
...
@@ -281,7 +277,7 @@ if __name__ == '__main__':
with
F
.
program_guard
(
test_p
,
):
gw_test
=
pgl
.
graph_wrapper
.
GraphWrapper
(
name
=
"product_"
+
str
(
0
))
# feature_batch=model.get_batch_feature(label_feature, test=True)
# 把图在CPU存起
# feature_batch=model.get_batch_feature(label_feature, test=True)
feature_batch
=
F
.
data
(
'hidden_node_feat'
,
shape
=
[
None
,
model
.
num_heads
*
model
.
hidden_size
],
dtype
=
'float32'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录