未验证 提交 ecbe2e32 编写于 作者: S sys1874 提交者: GitHub

Update main_protein.py

上级 3c8705e7
......@@ -23,7 +23,7 @@ evaluator = Evaluator(name='ogbn-proteins')
def get_config():
parser = argparse.ArgumentParser()
## 基本模型参数
## model_arg
model_group=parser.add_argument_group('model_base_arg')
model_group.add_argument('--num_layers', default=7, type=int)
model_group.add_argument('--hidden_size', default=64, type=int)
......@@ -31,7 +31,7 @@ def get_config():
model_group.add_argument('--dropout', default=0.1, type=float)
model_group.add_argument('--attn_dropout', default=0, type=float)
## label embedding模型参数
## label_embed_arg
embed_group=parser.add_argument_group('embed_arg')
embed_group.add_argument('--use_label_e', action='store_true')
embed_group.add_argument('--label_rate', default=0.5, type=float)
......@@ -90,15 +90,16 @@ def eval_test(parser, program, model, test_exe, graph, y_true, split_idx):
def train_loop(parser, start_program, main_program, test_program,
model, graph, label, split_idx, exe, run_id, wf=None):
#启动上文构建的训练器
#build up training program
exe.run(start_program)
max_acc=0 # 最佳test_acc
max_step=0 # 最佳test_acc 对应step
max_val_acc=0 # 最佳val_acc
max_cor_acc=0 # 最佳val_acc对应test_acc
max_cor_step=0 # 最佳val_acc对应step
#训练循环
max_acc=0 # best test_acc
max_step=0 # step for best test_acc
max_val_acc=0 # best val_acc
max_cor_acc=0 # test_acc for best val_acc
max_cor_step=0 # step for best val_acc
#training loop
graph.node_feat["label"] = label
graph.node_feat["nid"] = np.arange(0, graph.num_nodes)
......@@ -112,7 +113,7 @@ def train_loop(parser, start_program, main_program, test_program,
for epoch_id in tqdm(range(parser.epochs)):
for subgraph in random_partition(num_clusters=9, graph=graph, shuffle=True):
#运行训练器
#start training
if parser.use_label_e:
feed_dict = model.gw.to_feed(subgraph)
sub_idx = set(subgraph.node_feat["nid"])
......@@ -139,7 +140,7 @@ def train_loop(parser, start_program, main_program, test_program,
fetch_list=[model.avg_cost])
loss = loss[0]
#测试结果
#eval result
if (epoch_id+1) > parser.epochs*0.9:
result = eval_test(parser, test_program, model, exe, graph, label, split_idx)
train_acc, valid_acc, test_acc = result
......@@ -221,7 +222,7 @@ if __name__ == '__main__':
model.train_program()
adam_optimizer = optimizer_func(parser.lr)#训练优化函数
adam_optimizer = optimizer_func(parser.lr)#optimizer
adam_optimizer.minimize(model.avg_cost)
exe = F.Executor(place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册