diff --git a/README.md b/README.md index 511f1fdd220243b13c0251737a92602dbfdf9631..94f50daddf79d995c98dfff3498efeff585c78f8 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ python train_rm.py --load_model "./out_sft/rwkv-190.pth" --wandb "" --proj_dir ### PPO Model (Reinforcement learning from Human Feedback) ``` -python train_ppo.py --load_model "./out_sft/rwkv-190.pth" --load_rm_model "./out_rm/rm-2.pth" --wandb "" \ +python train_ppo.py --load_sft_model "./out_sft/rwkv-190.pth" --load_rm_model "./out_rm/rm-2.pth" --wandb "" \ --proj_dir "out_rlhf" \ --data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \ --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \ diff --git a/train_ppo.py b/train_ppo.py index de0d8a83b6d588660976fa60f9c4256940e9408f..951f96422f0d2804708ecf59a056cd6250a6259e 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -57,7 +57,7 @@ if __name__ == "__main__": parser = ArgumentParser() - parser.add_argument("--load_model", default="", type=str) # full path, with .pth + parser.add_argument("--load_sft_model", default="", type=str) # full path, with .pth parser.add_argument("--load_rm_model", default="", type=str) # full path, with .pth parser.add_argument("--wandb", default="", type=str) # wandb project name. if "" then don't use wandb parser.add_argument("--proj_dir", default="out", type=str)