提交 765f307b 编写于 作者: U u010280923

add ppo model

上级 da480d58
...@@ -222,7 +222,7 @@ if __name__ == "__main__": ...@@ -222,7 +222,7 @@ if __name__ == "__main__":
from src.trainer import train_callback, generate_init_weight from src.trainer import train_callback, generate_init_weight
args.vocab_size = 20000 args.vocab_size = 50277
from src.model import RWKV from src.model import RWKV
model = RWKV(args) model = RWKV(args)
...@@ -263,7 +263,7 @@ if __name__ == "__main__": ...@@ -263,7 +263,7 @@ if __name__ == "__main__":
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
seq = torch.randint(0, 20000, (1, 100)) seq = torch.randint(0, 50277, (1, 100))
model(seq) model(seq)
import ipdb import ipdb
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册