提交 db2a6b65 编写于 作者: U u010280923

opt ppo model

上级 54b452e9
@@ -20,6 +20,7 @@ from src.rlhf.utils import exists
 from src.rlhf.utils import gumbel_sample
 from src.rlhf.utils import top_k
 from src.rlhf.utils import identity
+from src.rlhf.utils import eval_decorator
 # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam
@@ -493,6 +494,7 @@ class RWKV(pl.LightningModule):
 return logits
 @torch.no_grad()
+@eval_decorator
 def generate(
 self,
 seq_len,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册