提交 765f307b 编写于 作者: U u010280923

add ppo model

上级 da480d58
......@@ -222,7 +222,7 @@ if __name__ == "__main__":
from src.trainer import train_callback, generate_init_weight
-args.vocab_size = 20000
+args.vocab_size = 50277
from src.model import RWKV
model = RWKV(args)
......@@ -263,7 +263,7 @@ if __name__ == "__main__":
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
-seq = torch.randint(0, 20000, (1, 100))
+seq = torch.randint(0, 50277, (1, 100))
model(seq)
import ipdb
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册