diff --git a/README.md b/README.md
index 533731674909bb755c83fbc41fc74c662d2c775a..5456dcc58359a9138b0fa6b65bcf5c63688095ce 100644
--- a/README.md
+++ b/README.md
@@ -83,7 +83,7 @@ python train_ppo.py --load_sft_model "./out_sft/rwkv-190.pth" --load_rm_model "
 --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
 --micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
 --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
---accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 0 \
+--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
 --my_qa_mask 1
 ```
 
diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index 4d26e31512b786f3b077b47afb3d85b67bc693cf..0e1dc65646e94da194a32350a20cafd03c8433b2 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -71,8 +71,7 @@ class ActorCritic(nn.Module):
         state,
         max_seq_len,
         eos_token = None,
-        return_values = False,
-        **kwargs
+        return_values = False
     ):
         # 产生一条 response，相当于采取了一次 action
         actions = self.actor.generate(
@@ -324,10 +323,8 @@ class RLHF(pl.LightningModule):
     def generate(
         self,
         max_seq_len,
-        *args,
         prompt,
-        num_samples = 4, # sample 4 per prompt and select the one with highest reward
-        **kwargs
+        num_samples = 4 # sample 4 per prompt and select the one with highest reward
     ):
         ''' 未参与训练，仅推理时使用 '''
 
@@ -345,10 +342,8 @@
             _
         ) = self.actor_critic.generate(
            prompt,
-            *args,
             max_seq_len = max_seq_len,
-            return_values = False,
-            **kwargs
+            return_values = False
         )
 
         rewards = self.reward_model(