diff --git a/train_ppo.py b/train_ppo.py index 8ede4d247a26f4a5d4041616e5b421527b7c4bfa..169e21251e768c08e1db25ad7426fcdceeac741b 100644 --- a/train_ppo.py +++ b/train_ppo.py @@ -299,8 +299,8 @@ if __name__ == "__main__": if trainer.global_rank == 0: for n in rlhf_model.state_dict(): shape = rlhf_model.state_dict()[n].shape - shape = [i for i in shape if i != 1] if len(shape) > 1: + shape = [i for i in shape if i != 1] print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}") else: print(f"{str(shape[0]).ljust(5)} {n}") diff --git a/train_rm.py b/train_rm.py index 59e3cf220cd4630aa9049e6c52e28f28cb7c9d45..91093812d4065e42852a410eb3ee0577ada22755 100644 --- a/train_rm.py +++ b/train_rm.py @@ -267,8 +267,8 @@ if __name__ == "__main__": if trainer.global_rank == 0: for n in rm_model.state_dict(): shape = rm_model.state_dict()[n].shape - shape = [i for i in shape if i != 1] if len(shape) > 1: + shape = [i for i in shape if i != 1] print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}") else: print(f"{str(shape[0]).ljust(5)} {n}")