From 62cb90733a6bcac4eebab24db6e8d0841c9a2cfc Mon Sep 17 00:00:00 2001 From: sys1874 <578417645@qq.com> Date: Thu, 10 Sep 2020 15:30:38 +0800 Subject: [PATCH] Update main_arxiv.py --- ogb_examples/nodeproppred/unimp/main_arxiv.py | 30 ++++++++----------- 1 file changed, 13 insertions(+), 17 deletions(-) diff --git a/ogb_examples/nodeproppred/unimp/main_arxiv.py b/ogb_examples/nodeproppred/unimp/main_arxiv.py index b82742d..4a13fd6 100644 --- a/ogb_examples/nodeproppred/unimp/main_arxiv.py +++ b/ogb_examples/nodeproppred/unimp/main_arxiv.py @@ -20,7 +20,7 @@ evaluator = Evaluator(name='ogbn-arxiv') def get_config(): parser = argparse.ArgumentParser() - ## 基本模型参数 + ## model_arg model_group=parser.add_argument_group('model_base_arg') model_group.add_argument('--num_layers', default=3, type=int) model_group.add_argument('--hidden_size', default=128, type=int) @@ -28,7 +28,7 @@ def get_config(): model_group.add_argument('--dropout', default=0.3, type=float) model_group.add_argument('--attn_dropout', default=0, type=float) - ## label embedding模型参数 + ## label_embed_arg embed_group=parser.add_argument_group('embed_arg') embed_group.add_argument('--use_label_e', action='store_true') embed_group.add_argument('--label_rate', default=0.625, type=float) @@ -81,17 +81,17 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx): def train_loop(parser, start_program, main_program, test_program, model, graph, label, split_idx, exe, run_id, wf=None): - #启动上文构建的训练器 + #build up training program exe.run(start_program) - max_acc=0 # 最佳test_acc - max_step=0 # 最佳test_acc 对应step - max_val_acc=0 # 最佳val_acc - max_cor_acc=0 # 最佳val_acc对应test_acc - max_cor_step=0 # 最佳val_acc对应step - #训练循环 + max_acc=0 # best test_acc + max_step=0 # step for best test_acc + max_val_acc=0 # best val_acc + max_cor_acc=0 # test_acc for best val_acc + max_cor_step=0 # step for best val_acc + #training loop for epoch_id in tqdm(range(parser.epochs)): - #运行训练器 + #start training if parser.use_label_e: feed_dict=model.gw.to_feed(graph) @@ -115,7 +115,7 @@ def train_loop(parser, start_program, main_program, test_program, # print(loss[1][0]) loss = loss[0] - #测试结果 + #eval result result = eval_test(parser, test_program, model, exe, graph, label, split_idx) train_acc, valid_acc, test_acc = result @@ -191,11 +191,7 @@ if __name__ == '__main__': test_prog=train_prog.clone(for_test=True) model.train_program() - -# ave_loss = train_program(pred_output)#训练程序 -# lr, global_step= linear_warmup_decay(parser.lr, parser.epochs*0.1, parser.epochs) -# adam_optimizer = optimizer_func(lr)#训练优化函数 - adam_optimizer = optimizer_func(parser.lr)#训练优化函数 + adam_optimizer = optimizer_func(parser.lr)#optimizer adam_optimizer.minimize(model.avg_cost) exe = F.Executor(place) @@ -206,4 +202,4 @@ if __name__ == '__main__': total_test_acc+=train_loop(parser, startup_prog, train_prog, test_prog, model, graph, label, split_idx, exe, run_i, wf) wf.write(f'average: {100 * (total_test_acc/parser.runs):.2f}%') - wf.close() \ No newline at end of file + wf.close() -- GitLab