From db2a6b6526a54d7445a285e145cc786f4cc62727 Mon Sep 17 00:00:00 2001 From: u010280923 Date: Mon, 20 Mar 2023 15:15:03 +0800 Subject: [PATCH] opt ppo model --- src/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/model.py b/src/model.py index b829466..58aee0d 100644 --- a/src/model.py +++ b/src/model.py @@ -20,6 +20,7 @@ from src.rlhf.utils import exists from src.rlhf.utils import gumbel_sample from src.rlhf.utils import top_k from src.rlhf.utils import identity +from src.rlhf.utils import eval_decorator # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam @@ -493,6 +494,7 @@ class RWKV(pl.LightningModule): return logits @torch.no_grad() + @eval_decorator def generate( self, seq_len, -- GitLab