未验证 提交 3c8705e7 编写于 作者: S sys1874 提交者: GitHub

Update main_product.py

上级 62cb9073
......@@ -22,14 +22,14 @@ evaluator = Evaluator(name='ogbn-products')
def get_config():
parser = argparse.ArgumentParser()
## 采样参数
## data_sampling_arg
data_group= parser.add_argument_group('data_arg')
data_group.add_argument('--batch_size', default=1500, type=int)
data_group.add_argument('--num_workers', default=12, type=int)
data_group.add_argument('--sizes', default=[10, 10, 10], type=int, nargs='+' )
data_group.add_argument('--buf_size', default=1000, type=int)
## 基本模型参数
## model_arg
model_group=parser.add_argument_group('model_base_arg')
model_group.add_argument('--num_layers', default=3, type=int)
model_group.add_argument('--hidden_size', default=128, type=int)
......@@ -37,7 +37,7 @@ def get_config():
model_group.add_argument('--dropout', default=0.3, type=float)
model_group.add_argument('--attn_dropout', default=0, type=float)
## label embedding模型参数
## label_embed_arg
embed_group=parser.add_argument_group('embed_arg')
embed_group.add_argument('--use_label_e', action='store_true')
embed_group.add_argument('--label_rate', default=0.625, type=float)
......@@ -113,7 +113,7 @@ def eval_test(parser, test_p_list, model, test_exe, dataset, split_idx):
def train_loop(parser, start_program, main_program, test_p_list,
model, feat_init, place, dataset, split_idx, exe, run_id, wf=None):
#启动上文构建的训练器
#build up training program
exe.run(start_program)
feat_init(place)
......@@ -122,10 +122,10 @@ def train_loop(parser, start_program, main_program, test_p_list,
max_val_acc=0 # 最佳val_acc
max_cor_acc=0 # 最佳val_acc对应test_acc
max_cor_step=0 # 最佳val_acc对应step
#训练循环
#training loop
for epoch_id in range(parser.epochs):
#运行训练器
#start training
if parser.use_label_e:
train_idx_temp=copy.deepcopy(split_idx['train'])
......@@ -158,8 +158,7 @@ def train_loop(parser, start_program, main_program, test_p_list,
print('acc: ', (acc_num/unlabel_idx.shape[0])*100)
#测试结果
# total=0.0
#eval result
if (epoch_id+1)>=50 and (epoch_id+1)%10==0:
result = eval_test(parser, test_p_list, model, exe, dataset, split_idx)
train_acc, valid_acc, test_acc = result
......@@ -242,17 +241,14 @@ if __name__ == '__main__':
# test_prog=train_prog.clone(for_test=True)
model.train_program()
# ave_loss = train_program(pred_output)#训练程序
# lr, global_step= linear_warmup_decay(0.01, 50, 500)
# adam_optimizer = optimizer_func(lr)#训练优化函数
adam_optimizer = optimizer_func(parser.lr)#训练优化函数
adam_optimizer = optimizer_func(parser.lr)#optimizer
adam_optimizer.minimize(model.avg_cost)
test_p_list=[]
with F.unique_name.guard():
## input层
## build up eval program
test_p=F.Program()
with F.program_guard(test_p, ):
gw_test=pgl.graph_wrapper.GraphWrapper(
......@@ -281,7 +277,7 @@ if __name__ == '__main__':
with F.program_guard(test_p, ):
gw_test=pgl.graph_wrapper.GraphWrapper(
name="product_"+str(0))
# feature_batch=model.get_batch_feature(label_feature, test=True) # 把图在CPU存起
# feature_batch=model.get_batch_feature(label_feature, test=True)
feature_batch = F.data( 'hidden_node_feat',
shape=[None, model.num_heads*model.hidden_size],
dtype='float32')
......@@ -322,4 +318,4 @@ if __name__ == '__main__':
total_test_acc+=train_loop(parser, startup_prog, train_prog, test_p_list, model, feat_init,
place, dataset, split_idx, exe, run_i, wf)
wf.write(f'average: {100 * (total_test_acc/parser.runs):.2f}%')
wf.close()
\ No newline at end of file
wf.close()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册