From 7928c1174b20c71cc9d96ac2cf0d3e81e760745e Mon Sep 17 00:00:00 2001 From: u010280923 Date: Mon, 20 Mar 2023 14:57:03 +0800 Subject: [PATCH] opt ppo model --- train_ppo.py | 2 +- train_rm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/train_ppo.py b/train_ppo.py index 8ede4d2..169e212 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -299,8 +299,8 @@ if __name__ == "__main__": if trainer.global_rank == 0: for n in rlhf_model.state_dict(): shape = rlhf_model.state_dict()[n].shape - shape = [i for i in shape if i != 1] if len(shape) > 1: + shape = [i for i in shape if i != 1] print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}") else: print(f"{str(shape[0]).ljust(5)} {n}") diff --git a/train_rm.py b/train_rm.py index 59e3cf2..9109381 100644 --- a/train_rm.py +++ b/train_rm.py @@ -267,8 +267,8 @@ if __name__ == "__main__": if trainer.global_rank == 0: for n in rm_model.state_dict(): shape = rm_model.state_dict()[n].shape - shape = [i for i in shape if i != 1] if len(shape) > 1: + shape = [i for i in shape if i != 1] print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}") else: print(f"{str(shape[0]).ljust(5)} {n}") -- GitLab