Review notes (text before the first "diff --git" header is ignored by patch tools):
- Line breaks and leading indentation were reconstructed from a whitespace-collapsed
  copy of this patch. The 4-space nesting under `if __name__ == "__main__":` is an
  assumption; verify it against the working tree (or apply with --ignore-whitespace).
- train_rm.py: size-1 dims are now stripped *before* the rank check, so a shape
  such as (1,) or (1, 1) would filter down to [] and `shape[0]` in the else branch
  would raise IndexError. The added line therefore falls back to [1], which keeps
  the pre-patch output for (1,) tensors. The post-image blob id on the train_rm.py
  `index` line predates this one-line adjustment and is informational only.
- train_ppo.py keeps size-1 dims while train_rm.py strips them, so the two scripts
  print different columns for the same tensor; confirm that divergence is intended.

diff --git a/train_ppo.py b/train_ppo.py
index 169e21251e768c08e1db25ad7426fcdceeac741b..028cc03483e6c4dedc2d8af8f317c7b24afa1bc7 100644
--- a/train_ppo.py
+++ b/train_ppo.py
@@ -299,8 +299,8 @@ if __name__ == "__main__":
     if trainer.global_rank == 0:
         for n in rlhf_model.state_dict():
             shape = rlhf_model.state_dict()[n].shape
+            shape = [i for i in shape]
             if len(shape) > 1:
-                shape = [i for i in shape if i != 1]
                 print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
             else:
                 print(f"{str(shape[0]).ljust(5)} {n}")
diff --git a/train_rm.py b/train_rm.py
index 91093812d4065e42852a410eb3ee0577ada22755..59e3cf220cd4630aa9049e6c52e28f28cb7c9d45 100644
--- a/train_rm.py
+++ b/train_rm.py
@@ -267,8 +267,8 @@ if __name__ == "__main__":
     if trainer.global_rank == 0:
         for n in rm_model.state_dict():
             shape = rm_model.state_dict()[n].shape
+            shape = [i for i in shape if i != 1] or [1]
             if len(shape) > 1:
-                shape = [i for i in shape if i != 1]
                 print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
             else:
                 print(f"{str(shape[0]).ljust(5)} {n}")