diff --git a/src/model.py b/src/model.py index b829466a50d0fec86c028b5e6badd3363bd4958e..58aee0de1ea99f1323f12235005c3855174e902a 100644 --- a/src/model.py +++ b/src/model.py @@ -20,6 +20,7 @@ from src.rlhf.utils import exists from src.rlhf.utils import gumbel_sample from src.rlhf.utils import top_k from src.rlhf.utils import identity +from src.rlhf.utils import eval_decorator # from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam @@ -493,6 +494,7 @@ class RWKV(pl.LightningModule): return logits @torch.no_grad() + @eval_decorator def generate( self, seq_len,