From e5ec593bc7c9cf7ecc9177e6426fdec2213f8350 Mon Sep 17 00:00:00 2001
From: sys1874 <578417645@qq.com>
Date: Tue, 8 Sep 2020 11:30:55 +0800
Subject: [PATCH] Update main_arxiv.py

---
 ogb_examples/nodeproppred/unimp/main_arxiv.py | 35 +++++++------------
 1 file changed, 13 insertions(+), 22 deletions(-)

diff --git a/ogb_examples/nodeproppred/unimp/main_arxiv.py b/ogb_examples/nodeproppred/unimp/main_arxiv.py
index b82742d..f1ce99b 100644
--- a/ogb_examples/nodeproppred/unimp/main_arxiv.py
+++ b/ogb_examples/nodeproppred/unimp/main_arxiv.py
@@ -20,7 +20,7 @@ evaluator = Evaluator(name='ogbn-arxiv')
 
 def get_config():
     parser = argparse.ArgumentParser()
-    ## 基本模型参数
+    ## basic model arguments
     model_group=parser.add_argument_group('model_base_arg')
     model_group.add_argument('--num_layers', default=3, type=int)
     model_group.add_argument('--hidden_size', default=128, type=int)
@@ -28,7 +28,7 @@ def get_config():
     model_group.add_argument('--dropout', default=0.3, type=float)
     model_group.add_argument('--attn_dropout', default=0, type=float)
 
-    ## label embedding模型参数
+    ## label-embedding arguments
     embed_group=parser.add_argument_group('embed_arg')
     embed_group.add_argument('--use_label_e', action='store_true')
     embed_group.add_argument('--label_rate', default=0.625, type=float)
@@ -42,10 +42,6 @@ def get_config():
     train_group.add_argument('--log_file', default='result_arxiv.txt', type=str)
     return parser.parse_args()
 
-# def optimizer_func(lr=0.01):
-#     return F.optimizer.AdamOptimizer(learning_rate=lr, regularization=F.regularizer.L2Decay(
-#         regularization_coeff=0.001))
-
 def optimizer_func(lr=0.01):
     return F.optimizer.AdamOptimizer(learning_rate=lr, regularization=F.regularizer.L2Decay(
         regularization_coeff=0.0005))
@@ -81,17 +77,16 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
 
 def train_loop(parser, start_program, main_program, test_program,
                model, graph, label, split_idx, exe, run_id, wf=None):
-    #启动上文构建的训练器
+    # run the startup program built above
     exe.run(start_program)
 
-    max_acc=0 # 最佳test_acc
-    max_step=0 # 最佳test_acc 对应step
-    max_val_acc=0 # 最佳val_acc
-    max_cor_acc=0 # 最佳val_acc对应test_acc
-    max_cor_step=0 # 最佳val_acc对应step
-    #训练循环
+    max_acc=0  # best test_acc
+    max_step=0  # step of best test_acc
+    max_val_acc=0  # best val_acc
+    max_cor_acc=0  # test_acc at best val_acc
+    max_cor_step=0  # step of best val_acc
+    # training loop
 
-    for epoch_id in tqdm(range(parser.epochs)):
-        #运行训练器
+    for epoch_id in tqdm(range(parser.epochs)):
         if parser.use_label_e:
             feed_dict=model.gw.to_feed(graph)
@@ -115,7 +110,7 @@
             # print(loss[1][0])
             loss = loss[0]
 
-        #测试结果
+        # evaluation results
         result = eval_test(parser, test_program, model, exe, graph, label, split_idx)
         train_acc, valid_acc, test_acc = result
 
@@ -191,11 +186,7 @@ if __name__ == '__main__':
 
     test_prog=train_prog.clone(for_test=True)
     model.train_program()
-
-#     ave_loss = train_program(pred_output)#训练程序
-#     lr, global_step= linear_warmup_decay(parser.lr, parser.epochs*0.1, parser.epochs)
-#     adam_optimizer = optimizer_func(lr)#训练优化函数
-    adam_optimizer = optimizer_func(parser.lr)#训练优化函数
+    adam_optimizer = optimizer_func(parser.lr)  # training optimizer
     adam_optimizer.minimize(model.avg_cost)
 
     exe = F.Executor(place)
@@ -206,4 +197,4 @@
         total_test_acc+=train_loop(parser, startup_prog, train_prog, test_prog,
                                    model, graph, label, split_idx, exe, run_i, wf)
     wf.write(f'average: {100 * (total_test_acc/parser.runs):.2f}%')
-    wf.close()
\ No newline at end of file
+    wf.close()
-- 
GitLab