Commit 00e7b66f authored by u010280923

opt ppo model

Parent ade381c4
......@@ -503,10 +503,9 @@ class RWKV(pl.LightningModule):
         pad_value = 0.,
         eos_token = None,
         return_seq_without_prompt = True,
-        use_tqdm = False,
-        **kwargs
+        use_tqdm = False
     ):
-        '''
+        ''' Generate a response; used for training the PPO model
         '''
         prompt, leading_dims = pack([prompt], '* n')
......
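Dropping **kwargs from generate() means a mistyped or unsupported keyword argument now fails fast with a TypeError instead of being silently swallowed. A minimal call-site sketch, assuming a leading sequence-length positional as in similar generate() implementations (the model object and the concrete values here are hypothetical):

    import torch

    def sample_response(model, prompt_ids):
        # Hedged sketch of a call site after this commit. The leading
        # sequence-length positional is an assumption; the keyword names
        # (prompt, eos_token, return_seq_without_prompt, use_tqdm) come
        # straight from the diff above.
        return model.generate(
            64,                              # max tokens to sample (assumed positional)
            prompt=prompt_ids,
            eos_token=0,
            return_seq_without_prompt=True,
            use_tqdm=False,
        )

With **kwargs removed, a stray keyword such as top_p=0.9 would now raise TypeError rather than being ignored, which is usually what you want during training.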
......@@ -453,7 +453,6 @@ class RLHF(pl.LightningModule):
         return {'actor_loss': actor_loss.item(), 'critic_loss': critic_loss.item()}
 
     @torch.no_grad()
     def make_experience(self, prompts, eos_token=None, temperature=1):
         ''' Produce training data by interacting with the environment
         '''
......
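make_experience is the rollout half of PPO: the actor samples a response for each prompt, the reward model scores the full sequence, and the old-policy log-probs and value estimates are recorded under no_grad so the later update step can form the PPO ratio against a frozen policy. Only the docstring appears in this diff, so the following is a minimal sketch under assumed attribute names (actor, critic, reward_model) and an assumed memory tuple layout:

    import torch

    @torch.no_grad()
    def make_experience(self, prompts, eos_token=None, temperature=1):
        ''' Produce training data by interacting with the environment '''
        memories = []
        for prompt in prompts:
            # the "action": sample a response from the current (old) policy
            response = self.actor.generate(
                self.args.ctx_len, prompt=prompt,            # assumed length field
                eos_token=eos_token, temperature=temperature, use_tqdm=False)
            sequence = torch.cat((prompt, response), dim=-1)
            # freeze old-policy log-probs and value estimates for the PPO ratio
            logprobs = self.actor(sequence).log_softmax(dim=-1)  # assumed API
            value = self.critic(sequence)                        # assumed API
            # scalar reward from the preference-trained reward model
            reward = self.reward_model(sequence)                 # assumed API
            memories.append((sequence, logprobs, value, reward))
        return memories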
......@@ -296,6 +296,19 @@ if __name__ == "__main__":
         callbacks=[rlhf_train_callback(args)],
     )
 
+    if trainer.global_rank == 0:
+        for n in rlhf_model.state_dict():
+            shape = rlhf_model.state_dict()[n].shape
+            shape = [i for i in shape if i != 1]
+            if len(shape) > 1:
+                print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
+            else:
+                print(f"{str(shape[0]).ljust(5)} {n}")
+
+    if "deepspeed" in args.strategy:
+        trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+        trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
+
+    time_cnt = 0
     for eps in tqdm(range(args.num_episodes), desc = 'episodes'):
         for timestep in range(args.max_timesteps):
......@@ -307,19 +320,6 @@ if __name__ == "__main__":
 
             # learn from the stored memories
             if time_cnt % args.update_timesteps == 0:
-                if trainer.global_rank == 0:
-                    for n in rlhf_model.state_dict():
-                        shape = rlhf_model.state_dict()[n].shape
-                        shape = [i for i in shape if i != 1]
-                        if len(shape) > 1:
-                            print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
-                        else:
-                            print(f"{str(shape[0]).ljust(5)} {n}")
-
-                if "deepspeed" in args.strategy:
-                    trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
-                    trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
-
                 train_data = PPODataset(memory)
                 data_loader = DataLoader(train_data, shuffle=False, pin_memory=True, batch_size=args.micro_bsz, num_workers=1, persistent_workers=False, drop_last=True)
......
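The net effect of this hunk is a hoist: the parameter-shape dump and the DeepSpeed bucket-size override used to sit inside the update branch and re-run on every update_timesteps-th step, doing redundant work; now they run once before the episode loop. The ds_bucket_mb * 1000 * 1000 expression converts the CLI value from megabytes to the byte counts that DeepSpeed's allgather_bucket_size and reduce_bucket_size settings expect. The new time_cnt is a step counter spanning episodes; a sketch of the resulting loop shape, assuming the increment and the rollout call that the diff truncates (PPODataset, rlhf_model, and args are repo-local names from the hunks above):

    from tqdm import tqdm
    from torch.utils.data import DataLoader

    time_cnt = 0
    for eps in tqdm(range(args.num_episodes), desc='episodes'):
        for timestep in range(args.max_timesteps):
            time_cnt += 1                          # assumed: global counter across episodes
            memory = rlhf_model.make_experience(   # assumed rollout call
                prompts, eos_token=None, temperature=1)
            if time_cnt % args.update_timesteps == 0:
                # learn from the stored memories; the one-time setup no longer runs here
                train_data = PPODataset(memory)
                data_loader = DataLoader(train_data, shuffle=False, pin_memory=True,
                                         batch_size=args.micro_bsz, num_workers=1,
                                         persistent_workers=False, drop_last=True)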