From a2aabac7a824616fed7537148bc6ac383c04dec3 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Fri, 17 Mar 2023 18:13:00 +0800 Subject: [PATCH] bug fixed --- README.md | 2 +- train_ppo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 94f50da..511f1fd 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ python train_rm.py --load_model "./out_sft/rwkv-190.pth" --wandb "" --proj_dir ### PPO Model (Reinforcement learning from Human Feedback) ``` -python train_ppo.py --load_sft_model "./out_sft/rwkv-190.pth" --load_rm_model "./out_rm/rm-2.pth" --wandb "" \ +python train_ppo.py --load_model "./out_sft/rwkv-190.pth" --load_rm_model "./out_rm/rm-2.pth" --wandb "" \ --proj_dir "out_rlhf" \ --data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \ --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \ diff --git a/train_ppo.py b/train_ppo.py index 951f964..de0d8a8 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -57,7 +57,7 @@ if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("--load_sft_model", default="", type=str) # full path, with .pth + parser.add_argument("--load_model", default="", type=str) # full path, with .pth parser.add_argument("--load_rm_model", default="", type=str) # full path, with .pth parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb parser.add_argument("--proj_dir", default="out", type=str) -- GitLab