diff --git a/README.md b/README.md
index f37df80572e171be8509cc7881e320e85b76b737..63361e9c44767ac55c8bd3a635d61b2665b1e426 100644
--- a/README.md
+++ b/README.md
@@ -64,7 +64,7 @@ python train_sft.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft"
 ### Reward Model
 
 ```
-python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_sft" \
+python train_rm.py --load_model "rwkv-190.pth" --wandb "" --proj_dir "out_rm" \
 --data_file "data/rm_mock_data.csv" --data_type "utf-8" --vocab_size 50277 \
 --ctx_len 2048 --epoch_steps 200 --epoch_count 1000 --epoch_begin 0 --epoch_save 2 \
 --micro_bsz 2 --n_layer 24 --n_embd 2048 --pre_ffn 0 --head_qk 0 \
diff --git a/src/trainer.py b/src/trainer.py
index 41cacbf01d21a86d164ed4249d7745269d8f1e4c..b43becd86c3f1e46212eb05cc11062953bf88da1 100644
--- a/src/trainer.py
+++ b/src/trainer.py
@@ -151,6 +151,142 @@ class train_callback(pl.Callback):
             trainer.my_loss_count = 0
 
 
+class rm_train_callback(pl.Callback):
+    def __init__(self, args):
+        super().__init__()
+        self.args = args
+
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+        args = self.args
+        # if args.cuda_cleanup > 0:
+        #     torch.cuda.empty_cache()
+        real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
+
+        # LR schedule
+        w_step = args.warmup_steps
+        if args.lr_final == args.lr_init or args.epoch_count == 0:
+            lr = args.lr_init
+        else:
+            decay_step = real_step - args.my_pile_edecay * args.epoch_steps
+            decay_total = (args.epoch_count - args.my_pile_edecay) * args.epoch_steps
+            progress = (decay_step - w_step + 1) / (decay_total - w_step)
+            progress = min(1, max(0, progress))
+
+            if args.lr_final == 0 or args.lr_init == 0:  # linear decay
+                lr = args.lr_init + (args.lr_final - args.lr_init) * progress
+            else:  # exp decay
+                lr = args.lr_init * math.exp(math.log(args.lr_final / args.lr_init) * pow(progress, 1))
+
+        if trainer.global_step < w_step:
+            lr = lr * (0.2 + 0.8 * trainer.global_step / w_step)
+        # if trainer.is_global_zero:
+        #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)
+
+        for param_group in trainer.optimizers[0].param_groups:
+            if args.layerwise_lr > 0:
+                param_group["lr"] = lr * param_group["my_lr_scale"]
+                # print(param_group["lr"], param_group["my_lr_scale"])
+            else:
+                param_group["lr"] = lr
+
+        trainer.my_lr = lr
+        # rank_zero_info(f"{real_step} {lr}")
+
+        if trainer.global_step == 0:
+            if trainer.is_global_zero:  # logging
+                trainer.my_loss_sum = 0
+                trainer.my_loss_count = 0
+                trainer.my_log = open(args.proj_dir + "/train_log.txt", "a")
+                trainer.my_log.write(f"NEW RUN {args.my_timestamp}\n{vars(self.args)}\n")
+                try:
+                    print(f"\n{trainer.strategy.config}\n")
+                    trainer.my_log.write(f"{trainer.strategy.config}\n")
+                except:
+                    pass
+                trainer.my_log.flush()
+                if len(args.wandb) > 0:
+                    print("Login to wandb...")
+                    import wandb
+                    wandb.init(
+                        project=args.wandb,
+                        name=args.run_name + " " + args.my_timestamp,
+                        config=args,
+                        save_code=False,
+                    )
+                    trainer.my_wandb = wandb
+
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        args = self.args
+        if trainer.is_global_zero:  # logging
+            t_now = time.time_ns()
+            token_per_step = args.ctx_len * args.real_bsz
+            real_step = trainer.global_step + args.epoch_begin * args.epoch_steps
+            kt_s = 0
+            try:
+                t_cost = (t_now - trainer.my_time_ns) / 1e9
+                kt_s = token_per_step / t_cost / 1000
+                self.log("REAL it/s", 1.0 / t_cost, prog_bar=True, on_step=True)
+                self.log("Kt/s", kt_s, prog_bar=True, on_step=True)
+            except:
+                pass
+            trainer.my_time_ns = t_now
+            trainer.my_loss = trainer.my_loss_all.float().mean().item()
+            trainer.my_loss_sum += trainer.my_loss
+            trainer.my_loss_count += 1
+            trainer.my_epoch_loss = trainer.my_loss_sum / trainer.my_loss_count
+            self.log("lr", trainer.my_lr, prog_bar=True, on_step=True)
+            self.log("loss", trainer.my_epoch_loss, prog_bar=True, on_step=True)
+            # self.log("s", real_step, prog_bar=True, on_step=True)
+
+            if len(args.wandb) > 0:
+                lll = {"loss": trainer.my_loss, "lr": trainer.my_lr, "Gtokens": real_step * token_per_step / 1e9}
+                if kt_s > 0:
+                    lll["kt/s"] = kt_s
+                trainer.my_wandb.log(lll, step=int(real_step))
+            if args.magic_prime > 0:
+                if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
+                    to_save_dict = pl_module.state_dict()
+                    my_save(
+                        to_save_dict,
+                        f"{args.proj_dir}/rwkv-final.pth",
+                    )
+
+    def on_train_epoch_start(self, trainer, pl_module):
+        args = self.args
+        dataset = trainer.train_dataloader.dataset.datasets
+        assert "RMDataset" in str(dataset)
+        dataset.global_rank = trainer.global_rank
+        dataset.real_epoch = int(args.epoch_begin + trainer.current_epoch)
+        dataset.world_size = trainer.world_size
+        # print(f'########## world_size {dataset.world_size} global_rank {dataset.global_rank} real_epoch {dataset.real_epoch} ##########')
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        args = self.args
+        if trainer.is_global_zero:  # logging & save state_dict
+            if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
+                if args.data_type == 'wds_img':
+                    raw_dict = pl_module.state_dict()
+                    to_save_dict = {}
+                    for k in raw_dict:
+                        if k.startswith('encoder.') or k.startswith('decoder.'):
+                            to_save_dict[k] = raw_dict[k]
+                else:
+                    to_save_dict = pl_module.state_dict()
+                try:
+                    my_save(
+                        to_save_dict,
+                        f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
+                    )
+                except Exception as e:
+                    print('Error\n\n', e, '\n\n')
+            trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
+            trainer.my_log.flush()
+
+            trainer.my_loss_sum = 0
+            trainer.my_loss_count = 0
+
+
 @rank_zero_only
 def generate_init_weight(model, init_weight_name):
     mm = model.generate_init_weight()
diff --git a/train_rm.py b/train_rm.py
index bff30bd5910a6b1c348825a8f43feaaa1d170cb9..bf9961b026cc3bc68d448f647120b7708496879f 100644
--- a/train_rm.py
+++ b/train_rm.py
@@ -224,7 +224,7 @@ if __name__ == "__main__":
     import torch
     from tqdm import tqdm
 
-    from src.trainer import train_callback
+    from src.trainer import rm_train_callback
     from src.rlhf.reward import RewardModel
     from src.dataset import RMDataset
 
@@ -239,7 +239,7 @@ if __name__ == "__main__":
     # 训练
     trainer = Trainer.from_argparse_args(
         args,
-        callbacks=[train_callback(args)],
+        callbacks=[rm_train_callback(args)],
     )
 
     if trainer.global_rank == 0: