From cf2bc5221ca157fb78968bb22147dbd6028bf5f0 Mon Sep 17 00:00:00 2001
From: u010280923
Date: Mon, 20 Mar 2023 17:42:07 +0800
Subject: [PATCH] opt ppo model

---
 README.md       |  2 +-
 src/rlhf/ppo.py | 11 +++--------
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 5337316..5456dcc 100644
--- a/README.md
+++ b/README.md
@@ -83,7 +83,7 @@ python train_ppo.py --load_sft_model "./out_sft/rwkv-190.pth" --load_rm_model "
 --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
 --micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
 --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
---accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 0 \
+--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
 --my_qa_mask 1
 ```

diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index 4d26e31..0e1dc65 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -71,8 +71,7 @@ class ActorCritic(nn.Module):
         state,
         max_seq_len,
         eos_token = None,
-        return_values = False,
-        **kwargs
+        return_values = False
     ):
         # generate one response, which amounts to taking one action
         actions = self.actor.generate(
@@ -324,10 +323,8 @@ class RLHF(pl.LightningModule):
     def generate(
         self,
         max_seq_len,
-        *args,
         prompt,
-        num_samples = 4, # sample 4 per prompt and select the one with highest reward
-        **kwargs
+        num_samples = 4 # sample 4 per prompt and select the one with highest reward
     ):
         ''' not used during training; inference only '''

@@ -345,10 +342,8 @@
             _
         ) = self.actor_critic.generate(
             prompt,
-            *args,
             max_seq_len = max_seq_len,
-            return_values = False,
-            **kwargs
+            return_values = False
         )

         rewards = self.reward_model(
--
GitLab
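
Context for the patch: `--grad_cp 1` enables gradient checkpointing in the RWKV trainer (slower steps in exchange for lower activation memory), and the `ppo.py` edits drop the `*args`/`**kwargs` pass-through so every argument flowing into `generate` is explicit. The `num_samples = 4` comment describes best-of-n sampling: draw several responses per prompt and keep the one the reward model scores highest. Below is a minimal, hypothetical sketch of that selection loop; `sample_fn` and `reward_fn` are stand-in callables for the patched `actor_critic.generate` and `reward_model`, whose actual signatures in `ppo.py` are not reproduced here.

```python
# Hypothetical sketch of best-of-n sampling; not part of this patch.
# sample_fn / reward_fn are placeholders for actor_critic.generate and
# reward_model in ppo.py, whose real signatures differ.
import torch

@torch.no_grad()
def best_of_n(prompt: torch.Tensor, sample_fn, reward_fn, num_samples: int = 4):
    best_seq, best_reward = None, float("-inf")
    for _ in range(num_samples):
        seq = sample_fn(prompt)         # draw one candidate response
        reward = float(reward_fn(seq))  # scalar score from the reward model
        if reward > best_reward:
            best_seq, best_reward = seq, reward
    return best_seq, best_reward
```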