diff --git a/model_zoo/gat/README.md b/model_zoo/gat/README.md index 7c30e08851734fdd62fc740bf9fc96063d8bb0a9..0c46aebbaf4a05e9452f75d6c78a8672aee41cc9 100644 --- a/model_zoo/gat/README.md +++ b/model_zoo/gat/README.md @@ -72,9 +72,9 @@ sh run_process_data.sh [SRC_PATH] [DATASET_NAME] >> Launch ``` #Generate dataset in mindrecord format for cora -sh run_process_data.sh cora +./run_process_data.sh ./data cora #Generate dataset in mindrecord format for citeseer -sh run_process_data.sh citeseer +./run_process_data.sh ./data citeseer ``` # Features diff --git a/model_zoo/gat/train.py b/model_zoo/gat/train.py index af1808b995a7123cc2f4fcff381ad204203e6a62..acfbb05b78af07ad6ab9d7c93add4af5267db75c 100644 --- a/model_zoo/gat/train.py +++ b/model_zoo/gat/train.py @@ -96,6 +96,8 @@ def train(): if eval_acc >= val_acc_max and eval_loss < val_loss_min: val_acc_model = eval_acc val_loss_model = eval_loss + if os.path.exists("ckpts/gat.ckpt"): + os.remove("ckpts/gat.ckpt") _exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt") val_acc_max = np.max((val_acc_max, eval_acc)) val_loss_min = np.min((val_loss_min, eval_loss))