未验证 提交 62cb9073 编写于 作者: S sys1874 提交者: GitHub

Update main_arxiv.py

上级 1e962841
...@@ -20,7 +20,7 @@ evaluator = Evaluator(name='ogbn-arxiv') ...@@ -20,7 +20,7 @@ evaluator = Evaluator(name='ogbn-arxiv')
def get_config(): def get_config():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
## 基本模型参数 ## model_arg
model_group=parser.add_argument_group('model_base_arg') model_group=parser.add_argument_group('model_base_arg')
model_group.add_argument('--num_layers', default=3, type=int) model_group.add_argument('--num_layers', default=3, type=int)
model_group.add_argument('--hidden_size', default=128, type=int) model_group.add_argument('--hidden_size', default=128, type=int)
...@@ -28,7 +28,7 @@ def get_config(): ...@@ -28,7 +28,7 @@ def get_config():
model_group.add_argument('--dropout', default=0.3, type=float) model_group.add_argument('--dropout', default=0.3, type=float)
model_group.add_argument('--attn_dropout', default=0, type=float) model_group.add_argument('--attn_dropout', default=0, type=float)
## label embedding模型参数 ## label_embed_arg
embed_group=parser.add_argument_group('embed_arg') embed_group=parser.add_argument_group('embed_arg')
embed_group.add_argument('--use_label_e', action='store_true') embed_group.add_argument('--use_label_e', action='store_true')
embed_group.add_argument('--label_rate', default=0.625, type=float) embed_group.add_argument('--label_rate', default=0.625, type=float)
...@@ -81,17 +81,17 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx): ...@@ -81,17 +81,17 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
def train_loop(parser, start_program, main_program, test_program, def train_loop(parser, start_program, main_program, test_program,
model, graph, label, split_idx, exe, run_id, wf=None): model, graph, label, split_idx, exe, run_id, wf=None):
#启动上文构建的训练器 #build up training program
exe.run(start_program) exe.run(start_program)
max_acc=0 # 最佳test_acc max_acc=0 # best test_acc
max_step=0 # 最佳test_acc 对应step max_step=0 # step for best test_acc
max_val_acc=0 # 最佳val_acc max_val_acc=0 # best val_acc
max_cor_acc=0 # 最佳val_acc对应test_acc max_cor_acc=0 # test_acc for best val_acc
max_cor_step=0 # 最佳val_acc对应step max_cor_step=0 # step for best val_acc
#训练循环 #training loop
for epoch_id in tqdm(range(parser.epochs)): for epoch_id in tqdm(range(parser.epochs)):
#运行训练器 #start training
if parser.use_label_e: if parser.use_label_e:
feed_dict=model.gw.to_feed(graph) feed_dict=model.gw.to_feed(graph)
...@@ -115,7 +115,7 @@ def train_loop(parser, start_program, main_program, test_program, ...@@ -115,7 +115,7 @@ def train_loop(parser, start_program, main_program, test_program,
# print(loss[1][0]) # print(loss[1][0])
loss = loss[0] loss = loss[0]
#测试结果 #eval result
result = eval_test(parser, test_program, model, exe, graph, label, split_idx) result = eval_test(parser, test_program, model, exe, graph, label, split_idx)
train_acc, valid_acc, test_acc = result train_acc, valid_acc, test_acc = result
...@@ -191,11 +191,7 @@ if __name__ == '__main__': ...@@ -191,11 +191,7 @@ if __name__ == '__main__':
test_prog=train_prog.clone(for_test=True) test_prog=train_prog.clone(for_test=True)
model.train_program() model.train_program()
adam_optimizer = optimizer_func(parser.lr)#optimizer
# ave_loss = train_program(pred_output)#训练程序
# lr, global_step= linear_warmup_decay(parser.lr, parser.epochs*0.1, parser.epochs)
# adam_optimizer = optimizer_func(lr)#训练优化函数
adam_optimizer = optimizer_func(parser.lr)#训练优化函数
adam_optimizer.minimize(model.avg_cost) adam_optimizer.minimize(model.avg_cost)
exe = F.Executor(place) exe = F.Executor(place)
...@@ -206,4 +202,4 @@ if __name__ == '__main__': ...@@ -206,4 +202,4 @@ if __name__ == '__main__':
total_test_acc+=train_loop(parser, startup_prog, train_prog, test_prog, model, total_test_acc+=train_loop(parser, startup_prog, train_prog, test_prog, model,
graph, label, split_idx, exe, run_i, wf) graph, label, split_idx, exe, run_i, wf)
wf.write(f'average: {100 * (total_test_acc/parser.runs):.2f}%') wf.write(f'average: {100 * (total_test_acc/parser.runs):.2f}%')
wf.close() wf.close()
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册