Commit cf2bc522 authored by u010280923

opt ppo model

Parent b7f231a9
@@ -83,7 +83,7 @@ python train_ppo.py --load_sft_model "./out_sft/rwkv-190.pth" --load_rm_model "
 --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
 --micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
 --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
---accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 0 \
+--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
 --my_qa_mask 1
 ```
...
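For context on the flag that changes in this hunk: in the RWKV-LM training scripts `--grad_cp 1` enables gradient checkpointing, which trades extra forward recomputation for lower activation memory during PPO training. Below is a minimal, self-contained sketch of that mechanism only; `TinyModel` and its `blocks` are made-up names for illustration, not this repo's classes.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TinyModel(nn.Module):
    """Toy block stack used only to illustrate what --grad_cp 1 switches on."""
    def __init__(self, n_layer=4, n_embd=64, grad_cp=True):
        super().__init__()
        self.grad_cp = grad_cp
        self.blocks = nn.ModuleList(nn.Linear(n_embd, n_embd) for _ in range(n_layer))

    def forward(self, x):
        for block in self.blocks:
            if self.grad_cp and self.training:
                # With checkpointing, this block's activations are not stored;
                # they are recomputed during backward, saving memory.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

model = TinyModel().train()
out = model(torch.randn(2, 64, requires_grad=True))
out.sum().backward()  # backward recomputes each block's forward pass
```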
@@ -71,8 +71,7 @@ class ActorCritic(nn.Module):
         state,
         max_seq_len,
         eos_token = None,
-        return_values = False,
-        **kwargs
+        return_values = False
     ):
         # Generate a response, which amounts to taking one action
         actions = self.actor.generate(
@@ -324,10 +323,8 @@ class RLHF(pl.LightningModule):
     def generate(
         self,
         max_seq_len,
-        *args,
         prompt,
-        num_samples = 4, # sample 4 per prompt and select the one with highest reward
-        **kwargs
+        num_samples = 4  # sample 4 per prompt and select the one with highest reward
     ):
         ''' Not used during training; inference only
         '''
@@ -345,10 +342,8 @@ class RLHF(pl.LightningModule):
             _
         ) = self.actor_critic.generate(
             prompt,
-            *args,
             max_seq_len = max_seq_len,
-            return_values = False,
-            **kwargs
+            return_values = False
         )
         rewards = self.reward_model(
...
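Taken together, these hunks drop the pass-through `*args`/`**kwargs` from the generation path, so only explicitly named arguments remain. A hedged sketch of how the slimmed-down interface would be called at inference time, assuming the signatures shown in the diff; the `rlhf_model` instance and the token ids are placeholders, not runnable repo code.

```python
import torch

# Placeholder prompt: random token ids purely for illustration.
prompt = torch.randint(0, 50277, (64,))

# rlhf_model is assumed to be an already-constructed RLHF LightningModule
# wrapping the actor-critic and the reward model; its construction is omitted.
best_response = rlhf_model.generate(
    max_seq_len=1024,
    prompt=prompt,
    num_samples=4,  # sample 4 candidates per prompt, keep the highest-reward one
)

# Internally, RLHF.generate now forwards only named arguments:
#     actions, *_ = self.actor_critic.generate(
#         prompt,
#         max_seq_len=max_seq_len,
#         return_values=False,
#     )
```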