diff --git a/README.md b/README.md index 5456dcc58359a9138b0fa6b65bcf5c63688095ce..533731674909bb755c83fbc41fc74c662d2c775a 100644 --- a/README.md +++ b/README.md @@ -83,7 +83,7 @@ python train_ppo.py --load_sft_model "./out_sft/rwkv-190.pth" --load_rm_model " --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \ --micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ ---accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \ +--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 0 \ --my_qa_mask 1 ``` diff --git a/src/model.py b/src/model.py index ee5169d9f4530b0f6ffd9ca13244616ed8a85384..ca8559b89d32d586620d1d4392ca3fb80c736fcf 100644 --- a/src/model.py +++ b/src/model.py @@ -20,7 +20,6 @@ from einops import unpack from src.rlhf.utils import exists from src.rlhf.utils import gumbel_sample from src.rlhf.utils import top_k -from src.rlhf.utils import identity from src.rlhf.utils import eval_decorator # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam @@ -510,8 +509,7 @@ class RWKV(pl.LightningModule): filter_thres = 0.9, pad_value = 0., eos_token = None, - return_seq_without_prompt = True, - use_tqdm = False + return_seq_without_prompt = True ): ''' 生成 response,用于 ppo 模型的训练 ''' @@ -520,10 +518,9 @@ class RWKV(pl.LightningModule): n, out = prompt.shape[-1], prompt.clone() - wrapper_fn = identity if not use_tqdm else tqdm sample_num_times = max(1, seq_len - prompt.shape[-1]) - for _ in wrapper_fn(range(sample_num_times)): + for _ in tqdm(range(sample_num_times), desc="gen responses"): logits, embeds = self.forward(out, ppo_train=True) logits, embeds = logits[:, -1], embeds[:, -1] diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py index 09e343d801a5ab759d4755c779b8b7eb9f2e46f4..4d26e31512b786f3b077b47afb3d85b67bc693cf 100644 --- a/src/rlhf/ppo.py +++ b/src/rlhf/ppo.py @@ -78,9 +78,7 @@ class ActorCritic(nn.Module): actions = self.actor.generate( max_seq_len, prompt = state, - eos_token = eos_token, - use_tqdm = True, - **kwargs + eos_token = eos_token ) # 将 prompt (state) 和 response (action) 进行拼接