diff --git a/README.md b/README.md index 381a7d7bfd77447f492ebaf8c131f8012076d6fc..a758549cf5fc7a5e59d9bd2e15c112776b0d1c8b 100644 --- a/README.md +++ b/README.md @@ -74,7 +74,18 @@ python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_rm" \ ``` -### 接入RLHF(Reinforcement Learning with Human Feedback) +### PPO Model (Reinforcement learning from Human Feedback) + +``` +python train_rm.py --load_sft_model "rwkv-190.pth" --load_rm_model "rm-6.pth" --wandb "" \ +--proj_dir "out_rlhf" \ +--data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \ +--ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \ +--micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ +--lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ +--accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2_offload --grad_cp 1 \ +--my_qa_mask 1 +``` diff --git a/src/dataset.py b/src/dataset.py index 326a6e770fa6418e2ddb82004ceb6d7777662c80..585e47614ab88fd7914f01902371d991197628c1 100644 --- a/src/dataset.py +++ b/src/dataset.py @@ -323,4 +323,47 @@ class RMDataset(Dataset): m_p = torch.tensor(prompt_prefer_mask, dtype=torch.long) m_a = torch.tensor(prompt_alter_mask, dtype=torch.long) - return x_p, x_a, m_p, m_a \ No newline at end of file + return x_p, x_a, m_p, m_a + + +class PPODataset(Dataset): + def __init__(self, memory): + self.data = memory + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + # todo(luxin) 是否需要 padding ??? + + sequence, \ + prompt_mask, \ + mask, \ + action_prob, \ + action_log_prob, \ + reward, \ + value = self.data[index] + + return sequence, prompt_mask, mask, action_prob, action_log_prob, reward, value + + +def load_prompt_data_4_ppo(args): + prompt_token_ids = [] + + WORD_NAME = [ + "20B_tokenizer.json", + "20B_tokenizer.json", + ] # [vocab, vocab] for Pile model + + tokenizer = TOKENIZER(WORD_NAME) + + pf = pd.read_csv(args.data_file) + for index, row in pf.iterrows(): + prompt = row["prompt"] + prompt_token_ids.append(tokenizer.tokenizer.encode(prompt)) + + prompt_token_ids = torch.tensor(prompt_token_ids, dtype=torch.long) + + return prompt_token_ids + + diff --git a/src/model.py b/src/model.py index 9a785ce7512c512b819491854da9b71ef5b1f7e7..c8ad6b6fa71811b7d9430871a38b7ff3ff409231 100644 --- a/src/model.py +++ b/src/model.py @@ -12,6 +12,15 @@ from pytorch_lightning.strategies import DeepSpeedStrategy import deepspeed from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam +from tqdm import tqdm +from einops import pack +from einops import unpack + +from src.rlhf.utils import exists +from src.rlhf.utils import gumbel_sample +from src.rlhf.utils import top_k +from src.rlhf.utils import identity + # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam try: @@ -429,7 +438,7 @@ class RWKV(pl.LightningModule): return cfg.get("offload_optimizer") or cfg.get("offload_param") return False - def forward(self, idx, extra_embed=None, rm_train=False): + def forward(self, idx, extra_embed=None, rm_train=False, ppo_train=False): args = self.args B, T = idx.size() assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted." 
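Before the model changes continue, one note on the new `load_prompt_data_4_ppo` helper above: `torch.tensor` cannot stack prompts of different lengths, which is exactly what the `todo(luxin)` padding question is asking about. Below is a minimal sketch of that padding step, assuming right-padding and a hypothetical `build_prompt_tensor` helper; the pad id is an assumption and should stay in sync with the `--pad_value` argument used later when the padding is stripped again.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def build_prompt_tensor(token_id_lists, pad_id=0):
    # right-pad every prompt to the longest one so they stack into a single
    # (num_prompts, max_len) LongTensor; pad_id is an assumption here, keep it
    # consistent with --pad_value so the padding can be removed downstream
    rows = [torch.tensor(ids, dtype=torch.long) for ids in token_id_lists]
    return pad_sequence(rows, batch_first=True, padding_value=pad_id)

# usage sketch: prompt_token_ids = build_prompt_tensor(encoded_prompts, pad_id=int(args.pad_value))
```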
@@ -454,15 +463,15 @@ class RWKV(pl.LightningModule):
             else:
                 x = block(x)

-        x = self.ln_out(x)
+        embeds = self.ln_out(x)  # final LayerNorm output, reused as the RM encoding

         if rm_train is True:
-            return x
+            return embeds

         if args.head_qk > 0:
-            q = self.head_q(x)[:, :T, :]
-            k = self.head_k(x)[:, :T, :]
+            q = self.head_q(embeds)[:, :T, :]
+            k = self.head_k(embeds)[:, :T, :]
             c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
             c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

@@ -473,11 +482,66 @@
             elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                 c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

-            x = self.head(x) + c
+            logits = self.head(embeds) + c
         else:
-            x = self.head(x)
-
-        return x
+            logits = self.head(embeds)
+
+        # used by the PPO stage: return both logits and embeddings
+        if ppo_train is True:
+            return logits, embeds
+
+        return logits
+
+    @torch.no_grad()
+    def generate(
+        self,
+        seq_len,
+        prompt = None,
+        temperature = 1.,
+        filter_logits_fn = top_k,
+        filter_thres = 0.9,
+        pad_value = 0.,
+        eos_token = None,
+        return_seq_without_prompt = True,
+        use_tqdm = False,
+        **kwargs
+    ):
+        ''' Autoregressively sample up to seq_len tokens after the prompt. '''
+
+        prompt, leading_dims = pack([prompt], '* n')
+
+        n, out = prompt.shape[-1], prompt.clone()
+
+        wrapper_fn = identity if not use_tqdm else tqdm
+        sample_num_times = max(1, seq_len - prompt.shape[-1])
+
+        for _ in wrapper_fn(range(sample_num_times)):
+            logits, embeds = self.forward(out, ppo_train = True, **kwargs)  # forward(..., ppo_train=True) returns (logits, embeds)
+            logits, embeds = logits[:, -1], embeds[:, -1]
+
+            if exists(filter_logits_fn):
+                logits = filter_logits_fn(logits, thres = filter_thres)
+
+            sample = gumbel_sample(logits, temperature = temperature, dim = -1)
+            out, _ = pack([out, sample], 'b *')
+
+            if exists(eos_token):
+                is_eos_tokens = (out == eos_token)
+
+                if is_eos_tokens.any(dim = -1).all():
+                    # mask out everything after the eos tokens
+                    shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
+                    mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
+                    out = out.masked_fill(mask, pad_value)
+                    break
+
+        out, = unpack(out, leading_dims, '* n')
+
+        if not return_seq_without_prompt:
+            return out
+
+        return out[..., n:]

     def training_step(self, batch, batch_idx):
         args = self.args
diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index e420015cbe0460c67e9c4d85cc062d48de94bf6f..859d17e82dfbdb25334e193383ca509018dc0a4f 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -17,6 +17,8 @@ from torch.optim import Adam
 from torch.utils.data import Dataset, DataLoader
 from torch.nn.utils.rnn import pad_sequence

+from pytorch_lightning.utilities import rank_zero_info
+
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange

@@ -27,7 +29,7 @@ from src.rlhf.utils import masked_mean, eval_decorator

 from accelerate import Accelerator

-# actor critic - PaLM with lora
+# actor critic - rwkv with lora

 PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
     'actions',
     'sequence',
     'mask',
     'prompt_mask',
     'action_logits',
     'values'
 ])

 @beartype
 class ActorCritic(nn.Module):
     def __init__(
         self,
         rwkv: RWKV,
-        critic_palm: Optional[RWKV] = None,
-        pooled_values = False,
-        actor_lora = True,
-        critic_lora = True,
-        actor_lora_r = 8,
-        critic_lora_r = 8,
-        actor_lora_scope = 'actor',
-        critic_lora_scope = 'critic',
-        actor_dropout = 0.,
-        critic_dropout = 0.
+ critic: Optional[RWKV] = None, + pooled_values = False ): super().__init__() - self.actor_palm = rwkv - - self.critic_palm = critic_palm - - if not exists(self.critic_palm): - self.critic_palm = copy.deepcopy(rwkv) - - self.actor_palm.set_dropout(actor_dropout) - self.critic_palm.set_dropout(critic_dropout) - - self.actor_lora = actor_lora - self.critic_lora = critic_lora + self.actor = rwkv - self.actor_lora_scope = actor_lora_scope if actor_lora else None - self.critic_lora_scope = critic_lora_scope if critic_lora else None + self.critic = critic - if self.actor_lora: - self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r) - - if self.critic_lora: - self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r) + if not exists(self.critic): + self.critic = copy.deepcopy(rwkv) self.pooled_values = pooled_values self.value_head = nn.Sequential( @@ -86,23 +65,6 @@ class ActorCritic(nn.Module): nn.init.zeros_(self.value_head[0].bias) nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2)) - def actor_parameters(self): - if not self.actor_lora: - return self.actor_palm.parameters() - - return [ - *self.actor_palm.finetune_parameters(self.actor_lora_scope) - ] - - def critic_parameters(self): - if not self.actor_lora: - return [*self.critic_palm.parameters(), *self.value_head.parameters()] - - return [ - *self.critic_palm.finetune_parameters(self.critic_lora_scope), - *self.value_head.parameters() - ] - @torch.no_grad() @eval_decorator def generate( @@ -113,7 +75,8 @@ class ActorCritic(nn.Module): return_values = False, **kwargs ): - actions = self.actor_palm.generate( + # 产生一条 response,相当于采取了一次 action + actions = self.actor.generate( max_seq_len, prompt = state, eos_token = eos_token, @@ -122,21 +85,26 @@ class ActorCritic(nn.Module): **kwargs ) + # 将 prompt (state) 和 response (action) 进行拼接 sequence = torch.cat((state, actions), dim = -1) action_len = actions.shape[-1] state_len = state.shape[-1] + # 构建 prompt_mask (state_mask) 和 response_mask (action_mask) prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0]) action_mask = ~prompt_mask + # 考虑 eos token mask = None if exists(eos_token): mask = ((sequence == eos_token).cumsum(dim = -1) == 0) mask = F.pad(mask, (1, -1), value = True) # include eos token action_mask &= mask + # 将生成的 sequence 输入到 actor 中,得到 action_logits + # 将生成的 sequence 输入到 critic 中,得到 value action_logits, value = self.forward( sequence, mask = action_mask, @@ -158,7 +126,7 @@ class ActorCritic(nn.Module): mask = None, return_values = True ): - action_logits = self.actor_palm( + action_logits = self.actor( x, finetune_scope = self.actor_lora_scope ) @@ -166,7 +134,7 @@ class ActorCritic(nn.Module): if not return_values: return action_logits, None - critic_embeds = self.critic_palm( + critic_embeds = self.critic( x, return_only_embedding = True, finetune_scope = self.critic_lora_scope @@ -278,119 +246,63 @@ def clipped_value_loss(values, rewards, old_values, clip): value_loss_2 = (values.flatten() - rewards) ** 2 return torch.mean(torch.max(value_loss_1, value_loss_2)) -# rlhf trainer +# rlhf @beartype -class RLHFTrainer(nn.Module): +class RLHF(nn.Module): def __init__( self, - *, - prompts: Optional[List[str]] = None, - prompts_path: Optional[str] = None, - prompt_token_ids: Optional[torch.Tensor] = None, - tokenizer: Callable = None, - rwkv: RWKV, - reward_model: RewardModel, - actor_critic: Optional[ActorCritic] = None, - 
actor_lr = 1e-4,
-        critic_lr = 1e-4,
-        actor_wd = 0.,
-        critic_wd = 0.,
-        actor_adam_eps = 1e-7,
-        critic_adam_eps = 1e-7,
-        actor_lora = True,
-        critic_lora = True,
-        actor_lora_r = 8,
-        critic_lora_r = 8,
-        critic_pooled_values = True,
-        actor_dropout = 0.,
-        critic_dropout = 0.,
-        betas = (0.9, 0.999),
-        max_norm = None,
-        eps_clip = 0.2,
-        value_clip = 0.4,
-        beta_s = .01,
-        pad_value = 0.,
-        minibatch_size = 16,
-        epochs = 1,
-        kl_div_loss_weight = 0.1, # between old action probs and new action probs - not sure what the right value is
-        accelerate_kwargs: dict = {},
-        use_lion = False
+        args,
+        accelerate_kwargs: dict = {}
     ):
         super().__init__()

-        self.accelerate = Accelerator(**accelerate_kwargs)
+        self.args = args

-        # take care of prompts -> token ids
+        self.accelerate = Accelerator(**accelerate_kwargs)

-        assert (exists(prompts) + exists(prompts_path) + exists(prompt_token_ids)) == 1
+        # load the RWKV (SFT) model
+        rwkv = RWKV(args)

-        if exists(prompts_path):
-            path = Path(prompts_path)
-            prompts = path.read_text().split('\n')
+        if len(args.load_sft_model) == 0:
+            rank_zero_info("RLHF must load an SFT model, please set --load_sft_model")
+            exit(1)

-        if exists(prompts):
-            assert len(prompts) > 0, 'no prompts'
-            assert exists(tokenizer), 'tokenizer must be passed in if raw text prompts are given'
-            prompt_token_ids = tokenizer(prompts)
+        rank_zero_info(f"########## Loading {args.load_sft_model}... ##########")
+        try:
+            load_dict = torch.load(args.load_sft_model, map_location="cpu")
+        except:
+            rank_zero_info(f"Bad checkpoint {args.load_sft_model}")
+            exit(1)

-        self.pad_value = pad_value # token pad value
-        self.num_prompts = prompt_token_ids.shape[0]
-        self.register_buffer('prompt_token_ids', prompt_token_ids)
+        if args.load_partial == 1:
+            load_keys = load_dict.keys()
+            for k in rwkv.state_dict():
+                if k not in load_keys:
+                    load_dict[k] = rwkv.state_dict()[k]
+        rwkv.load_state_dict(load_dict)

-        # models

         self.rwkv = rwkv

-        if not exists(actor_critic):
-            actor_critic = ActorCritic(
-                rwkv = rwkv,
-                actor_lora = actor_lora,
-                critic_lora = critic_lora,
-                actor_lora_r = actor_lora_r,
-                critic_lora_r = critic_lora_r,
-                pooled_values = critic_pooled_values,
-                actor_dropout = actor_dropout,
-                critic_dropout = critic_dropout
-            ).to(rwkv.device)
+        # initialise the actor-critic from the RWKV model
+        # (the new ActorCritic only accepts rwkv, an optional critic, and pooled_values)
+        actor_critic = ActorCritic(
+            rwkv = self.rwkv,
+            pooled_values = args.critic_pooled_values
+        ).to(self.rwkv.device)

         self.actor_critic = actor_critic

+        # load the reward model and put it in evaluation mode
+        reward_model = RewardModel(args)
+        reward_model.load(args.load_rm_model)
         self.reward_model = reward_model.eval()

-        # train hyperparameters
-
-        self.epochs = epochs
-        self.minibatch_size = minibatch_size
-        self.max_norm = max_norm
-
-        self.kl_div_loss_weight = kl_div_loss_weight
-
-        # optimizers
-
-        self.actor_optim = get_optimizer(actor_critic.actor_parameters(), lr = actor_lr, wd = actor_wd, betas = betas, eps = actor_adam_eps, use_lion = use_lion)
-        self.critic_optim = get_optimizer(actor_critic.critic_parameters(), lr = critic_lr, wd = critic_wd, betas = betas, eps = critic_adam_eps, use_lion = use_lion)
-
-        # ppo hyperparams
-
-        self.eps_clip = eps_clip
-        self.value_clip = value_clip
-        self.beta_s = beta_s
-
-        # prepare with accelerator
-
-        (
-            self.actor_critic,
-            self.reward_model,
-            self.actor_optim,
-            self.critic_optim
-        ) = 
self.accelerate.prepare( - self.actor_critic, - self.reward_model, - self.actor_optim, - self.critic_optim - ) - - def print(self, msg): return self.accelerate.print(msg) @@ -414,6 +326,9 @@ class RLHFTrainer(nn.Module): num_samples = 4, # sample 4 per prompt and select the one with highest reward **kwargs ): + ''' 未参与训练,仅推理时使用 + ''' + assert prompt.ndim == 1, 'only one prompt allowed at a time for now' prompt = repeat(prompt, 'n -> b n', b = num_samples) @@ -451,209 +366,147 @@ class RLHFTrainer(nn.Module): return best_sequence - def learn( - self, - memories: Deque[Memory] - ): - # stack all data stored in the memories - - all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories))) - - # prepare dataloader for policy phase training - - dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device) - - self.actor_critic.train() + def training_step(self, batch, batch_idx): + sequences, \ + prompt_masks, \ + masks, \ + old_action_probs, \ + old_log_probs, \ + rewards, \ + old_values = batch # PPO training + action_masks = ~prompt_masks & masks - for _ in range(self.epochs): - for ( - sequences, - prompt_masks, - masks, - old_action_probs, - old_log_probs, - rewards, - old_values - ) in dl: - action_masks = ~prompt_masks & masks - - action_logits, values = self.actor_critic( - sequences, - mask = action_masks - ) - - action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token - action_len = old_log_probs.shape[-1] - - action_probs = action_logits.softmax(dim = -1) - action_log_probs = log_prob(action_probs, sequences) - action_log_probs = action_log_probs[:, -action_len:] - - # calculate entropies, taking into account which part of the sequence is actually an action - - entropies = masked_entropy(action_probs, mask = action_masks) - - # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not - - kl_div_loss = 0. - - if self.kl_div_loss_weight > 0: - kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.kl_div_loss_weight - - # handle non-pooled values - - normalize_kwargs = dict() + action_logits, values = self.actor_critic( + sequences, + mask = action_masks + ) - if old_values.ndim == 2: - old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values)) + action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token + action_len = old_log_probs.shape[-1] - old_values = old_values[:, -action_len:] - values = values[:, -action_len:] - rewards = rearrange(rewards, 'b -> b 1') - normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:]) + action_probs = action_logits.softmax(dim = -1) + action_log_probs = log_prob(action_probs, sequences) + action_log_probs = action_log_probs[:, -action_len:] - if values.ndim < rewards.ndim: - values = rearrange(values, '... -> ... 
1') + # calculate entropies, taking into account which part of the sequence is actually an action - # calculate clipped surrogate objective, classic PPO loss + entropies = masked_entropy(action_probs, mask = action_masks) - ratios = (action_log_probs - old_log_probs).exp() - advantages = masked_normalize(rewards - old_values, **normalize_kwargs) + # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not - if advantages.ndim == 1: - advantages = rearrange(advantages, 'b -> b 1') + kl_div_loss = 0. - surr1 = ratios * advantages - surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages - policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropies + if self.args.kl_div_loss_weight > 0: + kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight - # combine losses + # handle non-pooled values - loss = policy_loss.mean() + kl_div_loss + normalize_kwargs = dict() - # update actor + if old_values.ndim == 2: + old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values)) - self.accelerate.backward(loss) + old_values = old_values[:, -action_len:] + values = values[:, -action_len:] + rewards = rearrange(rewards, 'b -> b 1') + normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:]) - self.print(f'policy_loss: {loss.item():.3f}') + if values.ndim < rewards.ndim: + values = rearrange(values, '... -> ... 1') - if exists(self.max_norm): - self.accelerator.clip_grad_norm_(self.actor_critic.actor_parameters(), self.max_norm) + # calculate clipped surrogate objective, classic PPO loss - self.actor_optim.step() - self.actor_optim.zero_grad() + ratios = (action_log_probs - old_log_probs).exp() + advantages = masked_normalize(rewards - old_values, **normalize_kwargs) - # calculate value loss and update value network separate from policy network + if advantages.ndim == 1: + advantages = rearrange(advantages, 'b -> b 1') - value_loss = clipped_value_loss(values, rewards, old_values, self.value_clip) - value_loss = value_loss.mean() + surr1 = ratios * advantages + surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages + policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies - self.print(f'critic_loss: {value_loss.item():.3f}') + # actor loss (也称为 policy loss, 是最终要使用模型的 loss) + actor_loss = policy_loss.mean() + kl_div_loss - self.accelerate.backward(value_loss) + # critic loss (也称为 value loss) + # update value network separate from policy network + critic_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip) + critic_loss = critic_loss.mean() - if exists(self.max_norm): - self.accelerator.clip_grad_norm_(self.actor_critic.critic_parameters(), self.max_norm) + return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()} - self.critic_optim.step() - self.critic_optim.zero_grad() - def train( - self, - num_episodes = 50000, - max_timesteps = 500, - update_timesteps = 5000, - max_batch_size = 16, - max_seq_len = 2048, - eos_token = None, - temperature = 1. 
- ): + def make_experience(self, prompts, eos_token=None, temperature=1): + ''' 通过与 environment 交互产生训练数据 + ''' + device = self.device - time = 0 - memories = deque([]) - - for eps in tqdm(range(num_episodes), desc = 'episodes'): - for timestep in range(max_timesteps): - time += 1 - - # select a bunch of random states (prompts) - # and get the action (sampled sequence from palm as well as the action probs) - # also calculate the reward using reward model and store - - rand_prompt_index = randrange(0, self.num_prompts) - - state = self.prompt_token_ids[rand_prompt_index] - - # remove padding from state - - state_mask = state != self.pad_value - state = state[state_mask] - - # get predicted sequence - - ( - actions, - sequence, - mask, - prompt_mask, - action_logits, - value - ) = self.actor_critic.generate( - rearrange(state, 'n -> 1 n'), - max_seq_len = max_seq_len, - eos_token = eos_token, - temperature = temperature, - return_values = True - ) - action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token - - action_prob = action_logits.softmax(dim = -1) - - action_len = actions.shape[-1] - action_log_prob = log_prob(action_prob, sequence) - action_log_prob = action_log_prob[:, -action_len:] - - actions = rearrange(actions, '1 ... -> ...') - - # get reward as given by supervised trained reward model - - sequence = torch.cat((state, actions), dim = 0) - - prompt_length = len(state) - prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length + # select a bunch of random states (prompts) + # and get the action (sampled sequence from rwkv as well as the action probs) + # also calculate the reward using reward model and store + # 随机挑选一条 prompt + rand_prompt_index = randrange(0, self.num_prompts) + state = self.prompt_token_ids[rand_prompt_index] + + # remove padding from state + state_mask = state != self.args.pad_value + state = state[state_mask] + + # get predicted sequence + # 与 environment 进行交互,其中返回的: + # action 是 response, + # sequence 是 prompt + response, + ( + actions, + sequence, + mask, + prompt_mask, + action_logits, + value + ) = self.actor_critic.generate( + rearrange(state, 'n -> 1 n'), + max_seq_len = self.args.ctx_len, + eos_token = eos_token, + temperature = temperature, + return_values = True + ) + action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token - sequence = rearrange(sequence, 'n -> 1 n') - prompt_mask = rearrange(prompt_mask, 'n -> 1 n') - mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device) + action_prob = action_logits.softmax(dim = -1) - reward = self.reward_model( - sequence, - prompt_mask = prompt_mask, - mask = mask, - sample = True - ) + action_len = actions.shape[-1] + action_log_prob = log_prob(action_prob, sequence) + action_log_prob = action_log_prob[:, -action_len:] - detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...') + actions = rearrange(actions, '1 ... 
-> ...') - # store memory for learning + # get reward as given by supervised trained reward model + sequence = torch.cat((state, actions), dim = 0) - memories.append(Memory(*map(detach_to_cpu_, ( - sequence, - prompt_mask, - mask, - action_prob, - action_log_prob, - reward, - value - )))) + prompt_length = len(state) + prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length - # learn from the stored memories + sequence = rearrange(sequence, 'n -> 1 n') + prompt_mask = rearrange(prompt_mask, 'n -> 1 n') + mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device) - if time % update_timesteps == 0: - self.learn(memories) - memories.clear() + reward = self.reward_model( + sequence, + prompt_mask = prompt_mask, + mask = mask, + sample = True + ) - print('rlhf training complete') + return ( + sequence, + prompt_mask, + mask, + action_prob, + action_log_prob, + reward, + value + ) diff --git a/src/rlhf/ppo_old.py b/src/rlhf/ppo_old.py new file mode 100644 index 0000000000000000000000000000000000000000..2a8fbda986140bdbfbf6d7b02f8d40da398256ab --- /dev/null +++ b/src/rlhf/ppo_old.py @@ -0,0 +1,623 @@ +import math +from pathlib import Path +import copy +from tqdm import tqdm +from functools import partial +from collections import deque, namedtuple +from random import randrange + +from beartype import beartype +from beartype.typing import List, Optional, Callable, Deque + +import torch +from torch import nn +import torch.nn.functional as F + +from torch.optim import Adam +from torch.utils.data import Dataset, DataLoader +from torch.nn.utils.rnn import pad_sequence + +from pytorch_lightning.utilities import rank_zero_info + +from einops import rearrange, repeat +from einops.layers.torch import Rearrange + +from src.model import RWKV +from src.rlhf.reward import RewardModel +from src.rlhf.optimizer import get_optimizer +from src.rlhf.utils import masked_mean, eval_decorator + +from accelerate import Accelerator + +# actor critic - PaLM with lora + +PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [ + 'actions', + 'sequence', + 'mask', + 'prompt_mask', + 'action_logits', + 'values' +]) + +@beartype +class ActorCritic(nn.Module): + def __init__( + self, + rwkv: RWKV, + critic_palm: Optional[RWKV] = None, + pooled_values = False, + actor_lora = True, + critic_lora = True, + actor_lora_r = 8, + critic_lora_r = 8, + actor_lora_scope = 'actor', + critic_lora_scope = 'critic', + actor_dropout = 0., + critic_dropout = 0. + ): + super().__init__() + self.actor_palm = rwkv + + self.critic_palm = critic_palm + + if not exists(self.critic_palm): + self.critic_palm = copy.deepcopy(rwkv) + + self.actor_palm.set_dropout(actor_dropout) + self.critic_palm.set_dropout(critic_dropout) + + self.actor_lora = actor_lora + self.critic_lora = critic_lora + + self.actor_lora_scope = actor_lora_scope if actor_lora else None + self.critic_lora_scope = critic_lora_scope if critic_lora else None + + if self.actor_lora: + self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r) + + if self.critic_lora: + self.critic_palm.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r) + + self.pooled_values = pooled_values + self.value_head = nn.Sequential( + nn.Linear(rwkv.dim, 1), + Rearrange('... 
1 -> ...') + ) + + nn.init.zeros_(self.value_head[0].bias) + nn.init.orthogonal_(self.value_head[0].weight, gain = math.sqrt(2)) + + def actor_parameters(self): + if not self.actor_lora: + return self.actor_palm.parameters() + + return [ + *self.actor_palm.finetune_parameters(self.actor_lora_scope) + ] + + def critic_parameters(self): + if not self.actor_lora: + return [*self.critic_palm.parameters(), *self.value_head.parameters()] + + return [ + *self.critic_palm.finetune_parameters(self.critic_lora_scope), + *self.value_head.parameters() + ] + + @torch.no_grad() + @eval_decorator + def generate( + self, + state, + max_seq_len, + eos_token = None, + return_values = False, + **kwargs + ): + actions = self.actor_palm.generate( + max_seq_len, + prompt = state, + eos_token = eos_token, + finetune_scope = self.actor_lora_scope, + use_tqdm = True, + **kwargs + ) + + sequence = torch.cat((state, actions), dim = -1) + action_len = actions.shape[-1] + state_len = state.shape[-1] + + prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len + prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0]) + + action_mask = ~prompt_mask + + mask = None + if exists(eos_token): + mask = ((sequence == eos_token).cumsum(dim = -1) == 0) + mask = F.pad(mask, (1, -1), value = True) # include eos token + action_mask &= mask + + action_logits, value = self.forward( + sequence, + mask = action_mask, + return_values = return_values + ) + + return PPOActionCriticReturn( + actions, + sequence, + mask, + prompt_mask, + action_logits, + value + ) + + def forward( + self, + x, + mask = None, + return_values = True + ): + action_logits = self.actor_palm( + x, + finetune_scope = self.actor_lora_scope + ) + + if not return_values: + return action_logits, None + + critic_embeds = self.critic_palm( + x, + return_only_embedding = True, + finetune_scope = self.critic_lora_scope + ) + + if self.pooled_values: + critic_embeds = shift(critic_embeds, shift = 1, dim = -2) + critic_embeds = masked_mean(critic_embeds, mask, dim = 1) + + values = self.value_head(critic_embeds) + + return action_logits, values + +# data + +Memory = namedtuple('Memory', [ + 'sequence', + 'prompt_mask', + 'mask', + 'action_prob', + 'action_log_prob', + 'reward', + 'value' +]) + +@beartype +class ExperienceDataset(Dataset): + def __init__( + self, + data: List[torch.Tensor], + device = None + ): + super().__init__() + self.data = data + self.device = device + + def __len__(self): + return self.data[0].shape[0] + + def __getitem__(self, ind): + return tuple(map(lambda t: t[ind].to(self.device), self.data)) + +def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs): + ds = ExperienceDataset(data, device = device) + return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs) + +# helper functions + +def exists(val): + return val is not None + +def default(val, d): + return val if exists(val) else d + +def masked_normalize(t, eps = 1e-5, mask = None, dim = None): + dim = default(dim, tuple(range(t.ndim))) + kwargs = dict(dim = dim, keepdim = True) + + mean = masked_mean(t, mask = mask, **kwargs) + mean_centered = t - mean + var = masked_mean(mean_centered ** 2, mask = mask, **kwargs) + + return mean_centered * var.clamp(min = eps).rsqrt() + +def pad_sequence_fixed(sequences, *args, **kwargs): + first_el = sequences[0] + has_no_dimension = first_el.ndim == 0 + + # if no dimensions, add a single dimension + if has_no_dimension: + sequences = tuple(map(lambda t: t[None], sequences)) + + 
out = pad_sequence(sequences, *args, **kwargs) + + if has_no_dimension: + out = rearrange(out, '... 1 -> ...') + + return out + +def log(t, eps = 1e-20): + return torch.log(t.clamp(min = eps)) + +def log_prob(prob, indices): + assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match' + return log(prob.gather(-1, indices[..., None])).squeeze(-1) + +def shift(t, value = 0, shift = 1, dim = -1): + zeros = (0, 0) * (-dim - 1) + return F.pad(t, (*zeros, shift, -shift), value = value) + +def masked_entropy(prob, dim = -1, mask = None): + entropies = (prob * log(prob)).sum(dim = -1) + return masked_mean(entropies, mask = mask).mean() + +def masked_kl_div(prob1, prob2, mask = None): + """ + need to account for variable sequence lengths, therefore not using the built-in functional version + """ + kl_divs = (prob1 * (log(prob2) - log(prob1))).sum(dim = -1) + + if not exists(mask): + return kl_divs.mean() + + return masked_mean(kl_divs, mask).mean() + +def clipped_value_loss(values, rewards, old_values, clip): + value_clipped = old_values + (values - old_values).clamp(-clip, clip) + value_loss_1 = (value_clipped.flatten() - rewards) ** 2 + value_loss_2 = (values.flatten() - rewards) ** 2 + return torch.mean(torch.max(value_loss_1, value_loss_2)) + +# rlhf trainer + +@beartype +class RLHFTrainer(nn.Module): + def __init__( + self, + args, + accelerate_kwargs: dict = {} + ): + super().__init__() + + self.args = args + + self.accelerate = Accelerator(**accelerate_kwargs) + + # 加载 RWKV 模型 + rwkv = RWKV(args) + + if len(args.load_sft_model) == 0: + rank_zero_info(f"SFT must load model, please input ") + exit(1) + + rank_zero_info(f"########## Loading {args.load_sft_model}... ##########") + try: + load_dict = torch.load(args.load_sft_model, map_location="cpu") + except: + rank_zero_info(f"Bad checkpoint {args.load_sft_model}") + exit(1) + + if args.load_partial == 1: + load_keys = load_dict.keys() + for k in rwkv.state_dict(): + if k not in load_keys: + load_dict[k] = rwkv.state_dict()[k] + rwkv.load_state_dict(load_dict) + + self.rwkv = rwkv + + # 使用 RWKV 初始化 actor_critic + actor_critic = ActorCritic( + rwkv = self.rwkv, + actor_lora = args.actor_lora, + critic_lora = args.critic_lora, + actor_lora_r = args.actor_lora_r, + critic_lora_r = args.critic_lora_r, + pooled_values = args.critic_pooled_values, + actor_dropout = args.actor_dropout, + critic_dropout = args.critic_dropout + ).to(self.rwkv.device) + + self.actor_critic = actor_critic + + # 加载 reward_model,并将 reward_model 设置为 evaluation 模式 + reward_model = RewardModel(args) + reward_model.load(args.load_rm_model) + self.reward_model = reward_model.eval() + + # optimizers + self.actor_optim = get_optimizer(actor_critic.actor_parameters(), lr = self.args.actor_lr, wd = self.args.actor_wd, betas = self.args.betas, eps = self.args.actor_adam_eps, use_lion = self.args.use_lion) + self.critic_optim = get_optimizer(actor_critic.critic_parameters(), lr = self.args.critic_lr, wd = self.args.critic_wd, betas = self.args.betas, eps = self.args.critic_adam_eps, use_lion = self.args.use_lion) + + # prepare with accelerator + + ( + self.actor_critic, + self.reward_model, + self.actor_optim, + self.critic_optim + ) = self.accelerate.prepare( + self.actor_critic, + self.reward_model, + self.actor_optim, + self.critic_optim + ) + + + def print(self, msg): + return self.accelerate.print(msg) + + def save(self, filepath = './checkpoint.pt'): + torch.save(self.actor_critic.state_dict(), filepath) + + def 
load(self, filepath = './checkpoint.pt'): + state_dict = torch.load(filepath) + self.actor_critic.load_state_dict(state_dict) + + @property + def device(self): + return self.accelerate.device + + @torch.no_grad() + def generate( + self, + max_seq_len, + *args, + prompt, + num_samples = 4, # sample 4 per prompt and select the one with highest reward + **kwargs + ): + assert prompt.ndim == 1, 'only one prompt allowed at a time for now' + prompt = repeat(prompt, 'n -> b n', b = num_samples) + + actor_critic = self.accelerate.unwrap_model(self.actor_critic) + reward_model = self.accelerate.unwrap_model(self.reward_model) + + actor_critic.eval() + + ( + actions, + sequences, + mask, + prompt_mask, + action_logits, + _ + ) = actor_critic.generate( + prompt, + *args, + max_seq_len = max_seq_len, + return_values = False, + **kwargs + ) + + rewards = reward_model( + sequences, + prompt_mask = prompt_mask, + mask = mask, + sample = True + ) + + best_sequence_index = rewards.topk(1, dim = -1).indices + + best_sequence = sequences[best_sequence_index] + best_sequence = rearrange(best_sequence, '1 ... -> ...') + + return best_sequence + + def learn( + self, + memories: Deque[Memory] + ): + # stack all data stored in the memories + + all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories))) + + # prepare dataloader for policy phase training + + dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device) + + self.actor_critic.train() + + # PPO training + + for _ in range(self.epochs): + for ( + sequences, + prompt_masks, + masks, + old_action_probs, + old_log_probs, + rewards, + old_values + ) in dl: + action_masks = ~prompt_masks & masks + + action_logits, values = self.actor_critic( + sequences, + mask = action_masks + ) + + action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token + action_len = old_log_probs.shape[-1] + + action_probs = action_logits.softmax(dim = -1) + action_log_probs = log_prob(action_probs, sequences) + action_log_probs = action_log_probs[:, -action_len:] + + # calculate entropies, taking into account which part of the sequence is actually an action + + entropies = masked_entropy(action_probs, mask = action_masks) + + # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not + + kl_div_loss = 0. + + if self.args.kl_div_loss_weight > 0: + kl_div_loss = masked_kl_div(action_probs, old_action_probs, mask = action_masks) * self.args.kl_div_loss_weight + + # handle non-pooled values + + normalize_kwargs = dict() + + if old_values.ndim == 2: + old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values)) + + old_values = old_values[:, -action_len:] + values = values[:, -action_len:] + rewards = rearrange(rewards, 'b -> b 1') + normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:]) + + if values.ndim < rewards.ndim: + values = rearrange(values, '... -> ... 
1') + + # calculate clipped surrogate objective, classic PPO loss + + ratios = (action_log_probs - old_log_probs).exp() + advantages = masked_normalize(rewards - old_values, **normalize_kwargs) + + if advantages.ndim == 1: + advantages = rearrange(advantages, 'b -> b 1') + + surr1 = ratios * advantages + surr2 = ratios.clamp(1 - self.args.eps_clip, 1 + self.args.eps_clip) * advantages + policy_loss = - torch.min(surr1, surr2) - self.args.beta_s * entropies + + # combine losses + + loss = policy_loss.mean() + kl_div_loss + + # update actor + + self.accelerate.backward(loss) + + self.print(f'policy_loss: {loss.item():.3f}') + + if exists(self.args.max_norm): + self.accelerator.clip_grad_norm_(self.actor_critic.actor_parameters(), self.args.max_norm) + + self.actor_optim.step() + self.actor_optim.zero_grad() + + # calculate value loss and update value network separate from policy network + + value_loss = clipped_value_loss(values, rewards, old_values, self.args.value_clip) + value_loss = value_loss.mean() + + self.print(f'critic_loss: {value_loss.item():.3f}') + + self.accelerate.backward(value_loss) + + if exists(self.args.max_norm): + self.accelerator.clip_grad_norm_(self.actor_critic.critic_parameters(), self.args.max_norm) + + self.critic_optim.step() + self.critic_optim.zero_grad() + + def train( + self, + num_episodes = 50000, + max_timesteps = 500, + update_timesteps = 5000, + max_batch_size = 16, + eos_token = None, + temperature = 1. + ): + device = self.device + + time = 0 + memories = deque([]) + + for eps in tqdm(range(num_episodes), desc = 'episodes'): + for timestep in range(max_timesteps): + time += 1 + + # select a bunch of random states (prompts) + # and get the action (sampled sequence from palm as well as the action probs) + # also calculate the reward using reward model and store + + rand_prompt_index = randrange(0, self.num_prompts) + + state = self.prompt_token_ids[rand_prompt_index] + + # remove padding from state + + state_mask = state != self.args.pad_value + state = state[state_mask] + + # get predicted sequence + + ( + actions, + sequence, + mask, + prompt_mask, + action_logits, + value + ) = self.actor_critic.generate( + rearrange(state, 'n -> 1 n'), + max_seq_len = self.args.ctx_len, + eos_token = eos_token, + temperature = temperature, + return_values = True + ) + action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token + + action_prob = action_logits.softmax(dim = -1) + + action_len = actions.shape[-1] + action_log_prob = log_prob(action_prob, sequence) + action_log_prob = action_log_prob[:, -action_len:] + + actions = rearrange(actions, '1 ... -> ...') + + # get reward as given by supervised trained reward model + + sequence = torch.cat((state, actions), dim = 0) + + prompt_length = len(state) + prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length + + sequence = rearrange(sequence, 'n -> 1 n') + prompt_mask = rearrange(prompt_mask, 'n -> 1 n') + mask = rearrange(mask, 'n -> 1 n') if exists(mask) else torch.ones(sequence.shape, dtype = torch.bool, device = device) + + reward = self.reward_model( + sequence, + prompt_mask = prompt_mask, + mask = mask, + sample = True + ) + + detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... 
-> ...') + + # store memory for learning + + memories.append(Memory(*map(detach_to_cpu_, ( + sequence, + prompt_mask, + mask, + action_prob, + action_log_prob, + reward, + value + )))) + + # learn from the stored memories + + if time % update_timesteps == 0: + self.learn(memories) + memories.clear() + + print('rlhf training complete') diff --git a/src/rlhf/utils.py b/src/rlhf/utils.py index 8f632bdcafcbce3b01ceaf123486eb80f7f99b49..25e3305a4497a201b1748412e4f65f74404b8d85 100644 --- a/src/rlhf/utils.py +++ b/src/rlhf/utils.py @@ -8,6 +8,9 @@ from einops import rearrange def exists(val): return val is not None +def identity(t, *args, **kwargs): + return t + # decorators def eval_decorator(fn): diff --git a/src/trainer.py b/src/trainer.py index 257d2a5ae2648e4262fe7146df6adadabb1ecd09..3ff07afdc3d520b97196a4c397b5c999130b1465 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -287,6 +287,142 @@ class rm_train_callback(pl.Callback): trainer.my_loss_count = 0 +class rlhf_train_callback(pl.Callback): + def __init__(self, args): + super().__init__() + self.args = args + + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + args = self.args + # if args.cuda_cleanup > 0: + # torch.cuda.empty_cache() + real_step = trainer.global_step + args.epoch_begin * args.epoch_steps + + # LR schedule + w_step = args.warmup_steps + if args.lr_final == args.lr_init or args.epoch_count == 0: + lr = args.lr_init + else: + decay_step = real_step - args.my_pile_edecay * args.epoch_steps + decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps + progress = (decay_step - w_step + 1) / (decay_total - w_step) + progress = min(1, max(0, progress)) + + if args.lr_final == 0 or args.lr_init == 0: # linear decay + lr = args.lr_init + (args.lr_final - args.lr_init) * progress + else: # exp decay + lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1)) + + if trainer.global_step < w_step: + lr = lr * (0.2 + 0.8 * trainer.global_step / w_step) + # if trainer.is_global_zero: + # print(trainer.global_step, decay_step, decay_total, w_step, progress, lr) + + for param_group in trainer.optimizers[0].param_groups: + if args.layerwise_lr > 0: + param_group["lr"] = lr * param_group["my_lr_scale"] + # print(param_group["lr"], param_group["my_lr_scale"]) + else: + param_group["lr"] = lr + + trainer.my_lr = lr + # rank_zero_info(f"{real_step} {lr}") + + if trainer.global_step == 0: + if trainer.is_global_zero: # logging + trainer.my_loss_sum = 0 + trainer.my_loss_count = 0 + trainer.my_log = open(args.proj_dir + "/train_log.txt", "a") + trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n") + try: + print(f"\n{trainer.strategy.config}\n") + trainer.my_log.write(f"{trainer.strategy.config}\n") + except: + pass + trainer.my_log.flush() + if len(args.wandb) > 0: + print("Login to wandb...") + import wandb + wandb.init( + project=args.wandb, + name=args.run_name + " " + args.my_timestamp, + config=args, + save_code=False, + ) + trainer.my_wandb = wandb + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + args = self.args + if trainer.is_global_zero: # logging + t_now = time.time_ns() + token_per_step = args.ctx_len * args.real_bsz + real_step = trainer.global_step + args.epoch_begin * args.epoch_steps + kt_s = 0 + try: + t_cost = (t_now - trainer.my_time_ns) / 1e9 + kt_s = token_per_step / t_cost / 1000 + self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True) + self.log("Kt/s", kt_s, prog_bar=True, on_step=True) 
+ except: + pass + trainer.my_time_ns = t_now + trainer.my_loss = trainer.my_loss_all.float().mean().item() + trainer.my_loss_sum += trainer.my_loss + trainer.my_loss_count += 1 + trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count + self.log("lr", trainer.my_lr, prog_bar=True, on_step=True) + self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True) + # self.log("s", real_step, prog_bar=True, on_step=True) + + if len(args.wandb) > 0: + lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9} + if kt_s > 0: + lll["kt/s"] = kt_s + trainer.my_wandb.log(lll, step=int(real_step)) + if args.magic_prime > 0: + if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1: + to_save_dict = pl_module.state_dict() + my_save( + to_save_dict, + f"{args.proj_dir}/rlhf-final.pth", + ) + + + def on_train_epoch_start(self, trainer, pl_module): + args = self.args + dataset = trainer.train_dataloader.dataset.datasets + assert "RMDataset" in str(dataset) + dataset.global_rank = trainer.global_rank + dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch) + dataset.world_size = trainer.world_size + # print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########') + + def on_train_epoch_end(self, trainer, pl_module): + args = self.args + if trainer.is_global_zero: # logging & save state_dict + if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1: + if args.data_type == 'wds_img': + raw_dict = pl_module.state_dict() + to_save_dict = {} + for k in raw_dict: + if k.startswith('encoder.') or k.startswith('decoder.'): + to_save_dict[k] = raw_dict[k] + else: + to_save_dict = pl_module.state_dict() + try: + my_save( + to_save_dict, + f"{args.proj_dir}/rlhf-{args.epoch_begin + trainer.current_epoch}.pth", + ) + except Exception as e: + print('Error\n\n', e, '\n\n') + trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n") + trainer.my_log.flush() + + trainer.my_loss_sum = 0 + trainer.my_loss_count = 0 + + @rank_zero_only def generate_init_weight(model, init_weight_name): mm = model.generate_init_weight() diff --git a/train_ppo.py b/train_ppo.py index 025d7d95fc5c1f29dfa89ba237d480b54ac6b964..5e9c8922f84ecbb240d229a25a22b072693ff40f 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -6,38 +6,293 @@ ''' # here put the import lib -import torch -from src.model import RWKV -from src.rlhf.reward import RewardModel -from src.rlhf.ppo import RLHFTrainer - -# load your pretrained RWKV -# todo(luxin) 加载 SFT 之后的预训练模型 -rwkv_model = RWKV() -# palm.load('./path/to/pretrained/palm.pt') - -# load your pretrained reward model -# todo(luxin) 加载训练好的 reward Model -reward_model = RewardModel( - rwkv_model, - num_binned_output = 5 -) -# reward_model.load('./path/to/pretrained/reward_model.pt') - -# ready your list of prompts for reinforcement learning -# todo(luxin) 读入 Prompts 数据集(此处的 Prompt 与 SFT、RM 阶段的 Prompt 要不一样) -prompts = torch.randint(0, 256, (50000, 512)) # 50k prompts - -# pass it all to the trainer and train -# 训练 PPO 模型 -trainer = RLHFTrainer( - palm = palm, - reward_model = reward_model, - prompt_token_ids = prompts -) -trainer.train(num_episodes = 100) - -# then, if it succeeded... 
-# generate say 10 samples and use the reward model to return the best one -answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,) -print(answer) \ No newline at end of file +######################################################################################################## +# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM +######################################################################################################## + +if __name__ == "__main__": + from argparse import ArgumentParser + from pytorch_lightning import Trainer + from pytorch_lightning.utilities import rank_zero_info, rank_zero_only + + rank_zero_info("########## work in progress ##########") + + ######################################################################################################## + # + # example: train a simple L12-D768 RWKV on dummy data + # + # python train.py --load_model "" --wandb "" --proj_dir "out" \ + # --data_file "" --data_type "dummy" --vocab_size 0 \ + # --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \ + # --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \ + # --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0 + + # example: train a simple L6-D512 RWKV from scratch on enwik8 + # + # python train.py --load_model "" --wandb "" --proj_dir "out" \ + # --data_file "../data/enwik8" --data_type "utf-8" --vocab_size 0 \ + # --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 \ + # --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 \ + # --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0 + + # example: fine-tune RWKV 1.5B using 8xA100 40G = 1.76it/s = 115k token/s, VRAM 37477M + # + # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \ + # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \ + # --ctx_len 1024 --epoch_steps 1000 --epoch_count 1000 --epoch_begin 0 --epoch_save 5 \ + # --micro_bsz 8 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ + # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ + # --accelerator gpu --devices 8 --precision bf16 --strategy deepspeed_stage_2 --grad_cp 0 + + # example: fine-tune RWKV 1.5B using 1 GPU fp16 (VRAM 16G) NOTE: fp16 might overflow + # + # python train.py --load_model "/fsx/BlinkDL/CODE/FP16/out_1b2/all-8040.pth" --wandb "" --proj_dir "out" \ + # --data_file "../data/train.npy" --data_type "numpy" --vocab_size 50277 \ + # --ctx_len 1024 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 1 \ + # --micro_bsz 11 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \ + # --lr_init 1e-5 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.999 --adam_eps 1e-8 \ + # --accelerator gpu --devices 1 --precision fp16 --strategy deepspeed_stage_2_offload --grad_cp 1 + + parser = ArgumentParser() + + parser.add_argument("--load_sft_model", default="", type=str) # full path, with .pth + parser.add_argument("--load_rm_model", default="", type=str) # full path, with .pth + parser.add_argument("--wandb", default="", type=str) # wandb project name. 
if "" then don't use wandb + parser.add_argument("--proj_dir", default="out", type=str) + parser.add_argument("--random_seed", default="-1", type=int) + + parser.add_argument("--data_file", default="", type=str) + parser.add_argument("--data_type", default="utf-8", type=str) + parser.add_argument("--vocab_size", default=0, type=int) # vocab_size = 0 means auto (for char-level LM and .txt data) + + parser.add_argument("--ctx_len", default=1024, type=int) + parser.add_argument("--epoch_steps", default=1000, type=int) # a mini "epoch" has [epoch_steps] steps + parser.add_argument("--epoch_count", default=500, type=int) # train for this many "epochs". will continue afterwards with lr = lr_final + parser.add_argument("--epoch_begin", default=0, type=int) # if you load a model trained for x "epochs", set epoch_begin = x + parser.add_argument("--epoch_save", default=5, type=int) # save the model every [epoch_save] "epochs" + + parser.add_argument("--micro_bsz", default=12, type=int) # micro batch size (batch size per GPU) + parser.add_argument("--n_layer", default=6, type=int) + parser.add_argument("--n_embd", default=512, type=int) + parser.add_argument("--dim_att", default=0, type=int) + parser.add_argument("--dim_ffn", default=0, type=int) + parser.add_argument("--pre_ffn", default=0, type=int) # replace first att layer by ffn (sometimes better) + parser.add_argument("--head_qk", default=0, type=int) # my headQK trick + parser.add_argument("--tiny_att_dim", default=0, type=int) # tiny attention dim + parser.add_argument("--tiny_att_layer", default=-999, type=int) # tiny attention @ which layer + + parser.add_argument("--lr_init", default=6e-4, type=float) # 6e-4 for L12-D768, 4e-4 for L24-D1024, 3e-4 for L24-D2048 + parser.add_argument("--lr_final", default=1e-5, type=float) + parser.add_argument("--warmup_steps", default=0, type=int) # try 50 if you load a model + parser.add_argument("--beta1", default=0.9, type=float) + parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence + parser.add_argument("--adam_eps", default=1e-8, type=float) + + parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower + parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode + parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift + parser.add_argument("--my_pile_edecay", default=0, type=int) + parser.add_argument("--layerwise_lr", default=1, type=int) # layerwise lr for faster convergence (but slower it/s) + parser.add_argument("--ds_bucket_mb", default=200, type=int) # deepspeed bucket size in MB. 
200 seems enough + # parser.add_argument("--cuda_cleanup", default=0, type=int) # extra cuda cleanup (sometimes helpful) + + parser.add_argument("--my_img_version", default=0, type=str) + parser.add_argument("--my_img_size", default=0, type=int) + parser.add_argument("--my_img_bit", default=0, type=int) + parser.add_argument("--my_img_clip", default='x', type=str) + parser.add_argument("--my_img_clip_scale", default=1, type=float) + parser.add_argument("--my_img_l1_scale", default=0, type=float) + parser.add_argument("--my_img_encoder", default='x', type=str) + # parser.add_argument("--my_img_noise_scale", default=0, type=float) + parser.add_argument("--my_sample_len", default=0, type=int) + parser.add_argument("--my_ffn_shift", default=1, type=int) + parser.add_argument("--my_att_shift", default=1, type=int) + parser.add_argument("--my_pos_emb", default=0, type=int) + parser.add_argument("--load_partial", default=0, type=int) + parser.add_argument("--magic_prime", default=0, type=int) + parser.add_argument("--my_qa_mask", default=0, type=int) + parser.add_argument("--my_testing", default='', type=str) + + # PPO model parameters + parser.add_argument("--critic_pooled_values", default=True, type=bool) + + parser.add_argument("--max_norm", default=None, type=float) + parser.add_argument("--kl_div_loss_weight", default=0.1, type=float) # between old action probs and new action probs - not sure what the right value is + + parser.add_argument("--eps_clip", default=0.2, type=float) + parser.add_argument("--value_clip", default=0.4, type=float) + parser.add_argument("--beta_s", default=0.01, type=float) + + parser.add_argument("--actor_lr", default=1e-4, type=float) + parser.add_argument("--critic_lr", default=1e-4, type=float) + parser.add_argument("--actor_wd", default=0., type=float) + parser.add_argument("--critic_wd", default=0., type=float) + parser.add_argument("--actor_adam_eps", default=1e-7, type=float) + parser.add_argument("--critic_adam_eps", default=1e-7, type=float) + parser.add_argument("--pad_value", default=1, type=float) # token pad value + parser.add_argument("--use_lion", default=False, type=bool) + + + parser.add_argument("--num_episodes", default=50000, type=int) + parser.add_argument("--max_timesteps", default=500, type=int) + parser.add_argument("--update_timesteps", default=5000, type=int) + + + parser = Trainer.add_argparse_args(parser) + args = parser.parse_args() + + ######################################################################################################## + + import os, warnings, math, datetime, sys, time + import numpy as np + import torch + from torch.utils.data import DataLoader + import deepspeed + import pytorch_lightning as pl + from pytorch_lightning import seed_everything + + if args.random_seed >= 0: + print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3) + seed_everything(args.random_seed) + + np.set_printoptions(precision=4, suppress=True, linewidth=200) + warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*") + warnings.filterwarnings("ignore", ".*The progress bar already tracks a metric with the*") + # os.environ["WDS_SHOW_SEED"] = "1" + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + args.my_timestamp = datetime.datetime.today().strftime("%Y-%m-%d-%H-%M-%S") + args.enable_checkpointing = False + args.replace_sampler_ddp = False + args.logger = False + args.gradient_clip_val = 1.0 + args.num_sanity_val_steps = 0 + 
args.check_val_every_n_epoch = int(1e20) + args.log_every_n_steps = int(1e20) + args.max_epochs = -1 # continue forever + args.betas = (args.beta1, args.beta2) + args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz + os.environ["RWKV_T_MAX"] = str(args.ctx_len) + os.environ["RWKV_MY_TESTING"] = args.my_testing + if args.dim_att <= 0: + args.dim_att = args.n_embd + if args.dim_ffn <= 0: + args.dim_ffn = args.n_embd * 4 + + args.run_name = f"{args.vocab_size} ctx{args.ctx_len} L{args.n_layer} D{args.n_embd}" + if not os.path.exists(args.proj_dir): + os.makedirs(args.proj_dir) + + samples_per_epoch = args.epoch_steps * args.real_bsz + tokens_per_epoch = samples_per_epoch * args.ctx_len + rank_zero_info( + f""" +############################################################################ +# +# RWKV-4 {args.precision.upper()} on {args.num_nodes}x{args.devices} {args.accelerator.upper()}, bsz {args.num_nodes}x{args.devices}x{args.micro_bsz}={args.real_bsz}, {args.strategy} {'with grad_cp' if args.grad_cp > 0 else ''} +# +# Data = {args.data_file} ({args.data_type}), ProjDir = {args.proj_dir} +# +# Epoch = {args.epoch_begin} to {args.epoch_begin + args.epoch_count - 1} (will continue afterwards), save every {args.epoch_save} epoch +# +# Each "epoch" = {args.epoch_steps} steps, {samples_per_epoch} samples, {tokens_per_epoch} tokens +# +# Model = {args.n_layer} n_layer, {args.n_embd} n_embd, {args.ctx_len} ctx_len +# +# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} +# +# Found torch {torch.__version__}, recommend 1.12.1+cu116 or newer +# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions) +# Found pytorch_lightning {pl.__version__}, recommend 1.7.4 or newer +# +############################################################################ +""" + ) + rank_zero_info(str(vars(args)) + "\n") + + assert args.data_type in ["utf-8", "utf-16le", "numpy", "binidx", "dummy", "wds_img", "uint16"] + + if args.lr_final == 0 or args.lr_init == 0: + rank_zero_info("\n\nNote: lr_final = 0 or lr_init = 0. Using linear LR schedule instead.\n\n") + + assert args.precision in ["fp32", "tf32", "fp16", "bf16"] + os.environ["RWKV_FLOAT_MODE"] = args.precision + if args.precision == "fp32": + rank_zero_info("\n\nNote: you are using fp32 (very slow). Try bf16 / tf32 for faster training.\n\n") + if args.precision == "fp16": + rank_zero_info("\n\nNote: you are using fp16 (might overflow). 
Try bf16 / tf32 for stable training.\n\n") + + os.environ["RWKV_JIT_ON"] = "1" + if "deepspeed_stage_3" in args.strategy: + os.environ["RWKV_JIT_ON"] = "0" + + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.enabled = True + if args.precision == "fp32": + torch.backends.cudnn.allow_tf32 = False + torch.backends.cuda.matmul.allow_tf32 = False + else: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + if "32" in args.precision: + args.precision = 32 + elif args.precision == "fp16": + args.precision = 16 + else: + args.precision = "bf16" + + ######################################################################################################## + from tqdm import tqdm + from collections import deque, namedtuple + from einops import rearrange + + from src.dataset import PPODataset, load_prompt_data_4_ppo + from src.rlhf.ppo import RLHF + from src.trainer import rlhf_train_callback + + # 用于 PPO 训练的数据,需要与 environment 交互获得 + memory = [] + + # 读入训练数据集 + prompts = load_prompt_data_4_ppo(args) + + # PPO 模型 + rlhf_model = RLHF(args) + + # 模型训练 + # trainer + trainer = Trainer.from_argparse_args( + args, + callbacks=[rlhf_train_callback(args)], + ) + + time_cnt = 0 + for eps in tqdm(range(args.num_episodes), desc = 'episodes'): + for timestep in range(args.max_timesteps): + time_cnt += 1 + + # 生成 ppo 模型的训练数据 + experience_data = rlhf_model.make_experience(prompts, eos_token=0) + memory.append(experience_data) + + # learn from the stored memories + if time_cnt % args.update_timesteps == 0: + if trainer.global_rank == 0: + for n in rlhf_model.state_dict(): + shape = rlhf_model.state_dict()[n].shape + shape = [i for i in shape if i != 1] + if len(shape) > 1: + print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}") + else: + print(f"{str(shape[0]).ljust(5)} {n}") + + train_data = PPODataset(memory) + data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True) + + trainer.fit(rlhf_model, data_loader) + + print('rlhf training complete') + diff --git a/train_rm.py b/train_rm.py index bf9961b026cc3bc68d448f647120b7708496879f..9e980be703ae78763aca33e1bcb69a5ca3fd5bd3 100644 --- a/train_rm.py +++ b/train_rm.py @@ -224,8 +224,8 @@ if __name__ == "__main__": import torch from tqdm import tqdm + from src.trainer import rm_train_callback - from src.rlhf.reward import RewardModel from src.dataset import RMDataset
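
For orientation, the outer loop that `train_ppo.py` builds toward can be summarised as below. This is a condensed sketch, not the repository's final code: it assumes `make_experience` samples its state from the `prompts` tensor it is given (the current body still reads `self.prompt_token_ids` / `self.num_prompts`, which the new `RLHF.__init__` never sets) and that `memory` is cleared after each PPO phase, as the old `RLHFTrainer.train()` did with `memories.clear()`.

```python
from torch.utils.data import DataLoader

memory = []
prompts = load_prompt_data_4_ppo(args)   # (num_prompts, max_len) prompt token ids
rlhf_model = RLHF(args)

time_cnt = 0
for eps in range(args.num_episodes):
    for timestep in range(args.max_timesteps):
        time_cnt += 1

        # roll out one prompt -> response episode and score it with the reward model
        memory.append(rlhf_model.make_experience(prompts, eos_token=0))

        if time_cnt % args.update_timesteps == 0:
            # one PPO phase over the collected rollouts
            loader = DataLoader(PPODataset(memory), batch_size=args.micro_bsz,
                                shuffle=False, pin_memory=True, drop_last=True)
            trainer.fit(rlhf_model, loader)
            memory.clear()   # start collecting fresh experience for the next phase
```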