Commit 0cca2efe authored by u010280923

opt ppo model

Parent e620d171
@@ -83,7 +83,7 @@ python train_ppo.py --load_sft_model "./out_sft/rwkv-190.pth" --load_rm_model "
 --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
 --micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
 --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \
---accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \
+--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 0 \
 --my_qa_mask 1
 ```
......
@@ -20,7 +20,6 @@ from einops import unpack
 from src.rlhf.utils import exists
 from src.rlhf.utils import gumbel_sample
 from src.rlhf.utils import top_k
-from src.rlhf.utils import identity
 from src.rlhf.utils import eval_decorator
 # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
@@ -510,8 +509,7 @@ class RWKV(pl.LightningModule):
 filter_thres = 0.9,
 pad_value = 0.,
 eos_token = None,
-return_seq_without_prompt = True,
-use_tqdm = False
+return_seq_without_prompt = True
 ):
 ''' Generate responses, used for training the PPO model
 '''
@@ -520,10 +518,9 @@
 n, out = prompt.shape[-1], prompt.clone()
-wrapper_fn = identity if not use_tqdm else tqdm
 sample_num_times = max(1, seq_len - prompt.shape[-1])
-for _ in wrapper_fn(range(sample_num_times)):
+for _ in tqdm(range(sample_num_times), desc="gen responses"):
 logits, embeds = self.forward(out, ppo_train=True)
 logits, embeds = logits[:, -1], embeds[:, -1]
......
@@ -78,9 +78,7 @@ class ActorCritic(nn.Module):
 actions = self.actor.generate(
 max_seq_len,
 prompt = state,
-eos_token = eos_token,
-use_tqdm = True,
-**kwargs
+eos_token = eos_token
 )
 # Concatenate the prompt (state) and the response (action)
......