diff --git a/src/model.py b/src/model.py
index 8a61e4b0c2c180789f9e10e021e283898c0abc9f..b829466a50d0fec86c028b5e6badd3363bd4958e 100644
--- a/src/model.py
+++ b/src/model.py
@@ -503,10 +503,9 @@ class RWKV(pl.LightningModule):
         pad_value = 0.,
         eos_token = None,
         return_seq_without_prompt = True,
-        use_tqdm = False,
-        **kwargs
+        use_tqdm = False
     ):
-        '''
+        ''' Generate a response; used for training the PPO model '''
         prompt, leading_dims = pack([prompt], '* n')
diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index 50f9d76555e91b5dc993e40bc8f998478f5c7c14..a244bb3357c191424bbbd24d36fe65f771fc4d06 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -453,7 +453,6 @@ class RLHF(pl.LightningModule):
 
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
 
-    @torch.no_grad()
     def make_experience(self, prompts, eos_token=None, temperature=1):
         ''' Generate training data by interacting with the environment '''
diff --git a/train_ppo.py b/train_ppo.py
index 98bc3d55a889523823b6bd679f2919949e99de6b..8ede4d247a26f4a5d4041616e5b421527b7c4bfa 100644
--- a/train_ppo.py
+++ b/train_ppo.py
@@ -296,6 +296,19 @@ if __name__ == "__main__":
         callbacks=[rlhf_train_callback(args)],
     )
 
+    if trainer.global_rank == 0:
+        for n in rlhf_model.state_dict():
+            shape = rlhf_model.state_dict()[n].shape
+            shape = [i for i in shape if i != 1]
+            if len(shape) > 1:
+                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
+            else:
+                print(f"{str(shape[0]).ljust(5)} {n}")
+
+    if "deepspeed" in args.strategy:
+        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+
     time_cnt = 0
     for eps in tqdm(range(args.num_episodes), desc = 'episodes'):
         for timestep in range(args.max_timesteps):
@@ -307,19 +320,6 @@
 
             # learn from the stored memories
             if time_cnt % args.update_timesteps == 0:
-                if trainer.global_rank == 0:
-                    for n in rlhf_model.state_dict():
-                        shape = rlhf_model.state_dict()[n].shape
-                        shape = [i for i in shape if i != 1]
-                        if len(shape) > 1:
-                            print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
-                        else:
-                            print(f"{str(shape[0]).ljust(5)} {n}")
-
-                if "deepspeed" in args.strategy:
-                    trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
-                    trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
-
                 train_data = PPODataset(memory)
                 data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
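
On the src/model.py hunk: with `**kwargs` gone, an unexpected keyword argument now raises a `TypeError` instead of being silently swallowed. A minimal sketch of a call against the tightened signature; the method name `generate`, the variable `rwkv_model`, and the prompt contents are assumptions, since the hunk shows only the trailing parameters:

```python
import torch

prompt = torch.tensor([1, 2, 3])  # token ids; packed to '* n' inside the method
response = rwkv_model.generate(
    prompt,
    pad_value = 0.,
    eos_token = None,
    return_seq_without_prompt = True,
    use_tqdm = False,  # any kwarg beyond the declared ones now raises TypeError
)
```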
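On the src/rlhf/ppo.py hunk: dropping `@torch.no_grad()` means `make_experience` now runs with autograd enabled unless the caller opts out. A sketch of a call-site opt-in, assuming `rlhf_model` is the `RLHF` instance and that the method returns the memory later wrapped by `PPODataset`; both details are assumptions, not shown in the diff:

```python
import torch

# Rollout collection does not need a graph, so the caller can scope no_grad
# itself now that the method no longer forces it.
with torch.no_grad():
    memory = rlhf_model.make_experience(prompts, eos_token=None, temperature=1)
```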
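On the train_ppo.py hunks: the rank-0 parameter-shape printout and the DeepSpeed ZeRO bucket sizing now run once before the episode loop instead of on every update step. A sketch of the same logic folded into a helper; `pre_training_setup` is a hypothetical name, and it makes a single `state_dict()` pass rather than rebuilding the dict for each parameter:

```python
def pre_training_setup(trainer, rlhf_model, args):
    if trainer.global_rank == 0:
        # Print each parameter's non-singleton dims next to its name.
        for name, tensor in rlhf_model.state_dict().items():
            dims = [d for d in tensor.shape if d != 1]
            print(" ".join(str(d).ljust(5) for d in dims), name)

    if "deepspeed" in args.strategy:
        # args.ds_bucket_mb is treated as decimal megabytes (1 MB = 10**6 bytes).
        bucket_bytes = args.ds_bucket_mb * 1000 * 1000
        zero_cfg = trainer.strategy.config["zero_optimization"]
        zero_cfg["allgather_bucket_size"] = bucket_bytes
        zero_cfg["reduce_bucket_size"] = bucket_bytes
```

Hoisting the block also removes redundant work: in the old placement the table was printed and the bucket sizes reassigned on every update cycle.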