From 54b452e904ebb16a3c6ef9ce6a5cd99498ca7a27 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Mon, 20 Mar 2023 14:59:57 +0800 Subject: [PATCH] opt ppo model --- train_ppo.py | 2 +- train_rm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/train_ppo.py b/train_ppo.py index 169e212..028cc03 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -299,8 +299,8 @@ if __name__ == "__main__": if trainer.global_rank == 0: for n in rlhf_model.state_dict(): shape = rlhf_model.state_dict()[n].shape + shape = [i for i in shape] if len(shape) > 1: - shape = [i for i in shape if i != 1] print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}") else: print(f"{str(shape[0]).ljust(5)} {n}") diff --git a/train_rm.py b/train_rm.py index 9109381..59e3cf2 100644 --- a/train_rm.py +++ b/train_rm.py @@ -267,8 +267,8 @@ if __name__ == "__main__": if trainer.global_rank == 0: for n in rm_model.state_dict(): shape = rm_model.state_dict()[n].shape + shape = [i for i in shape if i != 1] if len(shape) > 1: - shape = [i for i in shape if i != 1] print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}") else: print(f"{str(shape[0]).ljust(5)} {n}") -- GitLab