提交 54b452e9 编写于 作者: U u010280923

Optimize PPO model

上级 7928c117
......@@ -299,8 +299,8 @@ if __name__ == "__main__":
if trainer.global_rank == 0:
for n in rlhf_model.state_dict():
shape = rlhf_model.state_dict()[n].shape
shape = [i for i in shape]
if len(shape) > 1:
shape = [i for i in shape if i != 1]
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
else:
print(f"{str(shape[0]).ljust(5)} {n}")
......
......@@ -267,8 +267,8 @@ if __name__ == "__main__":
if trainer.global_rank == 0:
for n in rm_model.state_dict():
shape = rm_model.state_dict()[n].shape
if len(shape) > 1:
shape = [i for i in shape if i != 1]
if len(shape) > 1:
print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
else:
print(f"{str(shape[0]).ljust(5)} {n}")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册