From 00e7b66fe08014e9016ee0582bf486343b9d4af2 Mon Sep 17 00:00:00 2001
From: u010280923
Date: Mon, 20 Mar 2023 14:36:45 +0800
Subject: [PATCH] opt ppo model

Hoist the one-time parameter-shape dump and the DeepSpeed ZeRO
bucket-size configuration out of the PPO update loop so they run once
before training starts. Also drop the unused **kwargs from
RWKV.generate, give it a one-line docstring, and remove a stray blank
line in src/rlhf/ppo.py.
---
 src/model.py    |  5 ++---
 src/rlhf/ppo.py |  1 -
 train_ppo.py    | 26 +++++++++++++-------------
 3 files changed, 15 insertions(+), 17 deletions(-)

diff --git a/src/model.py b/src/model.py
index 8a61e4b..b829466 100644
--- a/src/model.py
+++ b/src/model.py
@@ -503,10 +503,9 @@ class RWKV(pl.LightningModule):
         pad_value = 0.,
         eos_token = None,
         return_seq_without_prompt = True,
-        use_tqdm = False,
-        **kwargs
+        use_tqdm = False
     ):
-        '''
+        ''' Generate a response; used for PPO model training '''
 
         prompt, leading_dims = pack([prompt], '* n')
 
diff --git a/src/rlhf/ppo.py b/src/rlhf/ppo.py
index 50f9d76..a244bb3 100644
--- a/src/rlhf/ppo.py
+++ b/src/rlhf/ppo.py
@@ -453,7 +453,6 @@ class RLHF(pl.LightningModule):
 
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
 
-
     @torch.no_grad()
     def make_experience(self, prompts, eos_token=None, temperature=1):
         ''' Generate training data by interacting with the environment '''
diff --git a/train_ppo.py b/train_ppo.py
index 98bc3d5..8ede4d2 100644
--- a/train_ppo.py
+++ b/train_ppo.py
@@ -296,6 +296,19 @@ if __name__ == "__main__":
         callbacks=[rlhf_train_callback(args)],
     )
 
+    if trainer.global_rank == 0:
+        for n in rlhf_model.state_dict():
+            shape = rlhf_model.state_dict()[n].shape
+            shape = [i for i in shape if i != 1]
+            if len(shape) > 1:
+                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
+            else:
+                print(f"{str(shape[0]).ljust(5)} {n}")
+
+    if "deepspeed" in args.strategy:
+        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+
     time_cnt = 0
     for eps in tqdm(range(args.num_episodes), desc = 'episodes'):
         for timestep in range(args.max_timesteps):
@@ -307,19 +320,6 @@ if __name__ == "__main__":
 
             # learn from the stored memories
            if time_cnt % args.update_timesteps == 0:
-                if trainer.global_rank == 0:
-                    for n in rlhf_model.state_dict():
-                        shape = rlhf_model.state_dict()[n].shape
-                        shape = [i for i in shape if i != 1]
-                        if len(shape) > 1:
-                            print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
-                        else:
-                            print(f"{str(shape[0]).ljust(5)} {n}")
-
-                if "deepspeed" in args.strategy:
-                    trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
-                    trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
-
                 train_data = PPODataset(memory)
                 data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
 
-- 
GitLab