From 4b0ff7c7eb3dabba46b2e0c61bf5426668dd9bba Mon Sep 17 00:00:00 2001 From: u010280923 Date: Thu, 9 Mar 2023 10:13:38 +0800 Subject: [PATCH] debug --- .gitignore | 1 + forward_demo.py | 58 +-- src/rlhf/__pycache__/__init__.cpython-38.pyc | Bin 0 -> 157 bytes src/rlhf/__pycache__/rwkv.cpython-38.pyc | Bin 0 -> 13514 bytes src/rlhf/rwkv.py | 457 +++++++++++++++++++ 5 files changed, 490 insertions(+), 26 deletions(-) create mode 100644 src/rlhf/__pycache__/__init__.cpython-38.pyc create mode 100644 src/rlhf/__pycache__/rwkv.cpython-38.pyc create mode 100644 src/rlhf/rwkv.py diff --git a/.gitignore b/.gitignore index 0366758..7397164 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ data +model .DS_Store .idea \ No newline at end of file diff --git a/forward_demo.py b/forward_demo.py index da7df70..07a1240 100644 --- a/forward_demo.py +++ b/forward_demo.py @@ -28,12 +28,15 @@ np.set_printoptions(precision=4, suppress=True, linewidth=200) # set these before import RWKV os.environ['RWKV_JIT_ON'] = '1' os.environ["RWKV_CUDA_ON"] = '0' # if '1' then compile CUDA kernel for seq mode (much faster) -os.environ["RWKV_T_MAX"] = '1024' -from src.model import RWKV # pip install rwkv +# from rwkv.model import RWKV # pip install rwkv +from src.rlhf.rwkv import RWKV +# model = RWKV(model='./model/rwkv-190.pth', strategy='cpu fp32') +model = RWKV(model='./model/RWKV-4-Pile-169M-20220807-8023.pth', strategy='cpu fp32') + # model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cuda fp16') # model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cuda fp16i8') -model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cpu fp32') +# model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cpu fp32') # model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-20220807-8023', strategy='cpu fp32 *3 -> cuda fp16 *6+') # model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cpu fp32') # model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-20220903-8040', strategy='cuda fp16') @@ -45,34 +48,37 @@ model = RWKV(model='/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-169m/RWKV-4-Pile-169M-2022 out, state = model.forward([187, 510, 1563, 310, 247], None) print(out.detach().cpu().numpy()) # get logits -out, state = model.forward([187, 510], None) -out, state = model.forward([1563], state) # RNN has state (use deepcopy to clone states) -out, state = model.forward([310, 247], state) -print(out.detach().cpu().numpy()) # same result as above +# out, state = model.forward([187, 510], None) +# out, state = model.forward([1563], state) # RNN has state (use deepcopy to clone states) +# out, state = model.forward([310, 247], state) +# print(out.detach().cpu().numpy()) # same result as above -print('\n') +import ipdb +ipdb.set_trace() -from src.utils import PIPELINE, PIPELINE_ARGS -pipeline = PIPELINE(model, "20B_tokenizer.json") +# print('\n') -ctx = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese." -print(ctx, end='') +# from src.utils import PIPELINE, PIPELINE_ARGS +# pipeline = PIPELINE(model, "20B_tokenizer.json") -def my_print(s): - print(s, end='', flush=True) +# ctx = "\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese." +# print(ctx, end='') -# For alpha_frequency and alpha_presence, see "Frequency and presence penalties": -# https://platform.openai.com/docs/api-reference/parameter-details +# def my_print(s): +# print(s, end='', flush=True) -args = PIPELINE_ARGS(temperature = 1.0, top_p = 0.7, - alpha_frequency = 0.25, - alpha_presence = 0.25, - token_ban = [0], # ban the generation of some tokens - token_stop = []) # stop generation whenever you see any token here +# # For alpha_frequency and alpha_presence, see "Frequency and presence penalties": +# # https://platform.openai.com/docs/api-reference/parameter-details -######################################################################################################## -# 1. set os.environ["RWKV_CUDA_ON"] = '1' if possible, for faster preprocess of a long ctx. -# 2. Reuse the state (use deepcopy to clone it) when you are running the same ctx multiple times. -pipeline.generate(ctx, token_count=200, args=args, callback=my_print) +# args = PIPELINE_ARGS(temperature = 1.0, top_p = 0.7, +# alpha_frequency = 0.25, +# alpha_presence = 0.25, +# token_ban = [0], # ban the generation of some tokens +# token_stop = []) # stop generation whenever you see any token here + +# ######################################################################################################## +# # 1. set os.environ["RWKV_CUDA_ON"] = '1' if possible, for faster preprocess of a long ctx. +# # 2. Reuse the state (use deepcopy to clone it) when you are running the same ctx multiple times. +# pipeline.generate(ctx, token_count=200, args=args, callback=my_print) -print('\n') \ No newline at end of file +# print('\n') \ No newline at end of file diff --git a/src/rlhf/__pycache__/__init__.cpython-38.pyc b/src/rlhf/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb73c28df764171bc0e2929f9002da59fb025ddf GIT binary patch literal 157 zcmWIL<>g`kf?c=SQf2|^#~=WHa^HWN5Qtd!ye+FU(0G9?PuK)l5 literal 0 HcmV?d00001 diff --git a/src/rlhf/__pycache__/rwkv.cpython-38.pyc b/src/rlhf/__pycache__/rwkv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fecf661aa33afb20ddde83315f9d72ed35be96ff GIT binary patch literal 13514 zcmbVTTW}lKdEP573xEVk@B)ycAz79M%A_uq<(sTXk)_y)Mca~W#TaWaAkKmW2rj_0 z3sFQDPHHA~+leMjCJ&t^ngN?Mt$R(~v`y1C(>Cc#JI&;wPy3MROgqym)9Lu3Pim*_ z_n%#W07b`6z}bKQ{m=csopb(k_*x=iNcjE3$D-`ZS0(9R=%fD^#m9L(_G7XnF^S1F zDeu4Kyez(zydvIeUKMXGui>rK^tDhvEK5u^wQ8gqEf_s47GUveq7XBc>QHA`z@#~{ zBk#!hQGAVcQcT;CnZBy>udonNn08h|UE|e>{G@nK<R~BBVXE_M4rP;z9!`|_lS>+ zIoUJHO{?To>Q=FK-HTnl@u}AfpM3Fp;pIh7KHstf@z@%GD_xV`l^5TY zP?1t!;$hGPxPR{WYqrVl<0YF}$6w*~s#$W5zkcmvotcFjb-rphiY4=SX{G3h95yc< z=e3pcalWy7YoW2}>4kz-Z;Vqz3VPS2}j=RXBscxJ35m8EX zs8A}`ipzG1hS2|0@o^rHJq2J&q^)_yRGG|_&q-S{v{Z$ZG*5**O2P<`kA5*c2|Pjs z+d!D?Dk;g4T!krH$^=w;Rk1UU{I=}sRmD-O8ouYmx7yKnq_*m6;;p&5c06%X2ZLQL0-`Ww}{z+MaH%HJnYJqI@H}^q_W5Y(l|0-zaiMjg*|*o_yVtFM9HZCpSHL z)st^|O0noEOG}>8Xn6AGt_pMVE_@c?03Lf3KvFVtQjP?FGxDt5ia|gNK^u!X-HSkP zp%sOKXA}x+b=IsA9xoKuo5fm?61N?wP-)F{R_aW&(ATOkA<^^*8l0h~$kQ1b|6TK$!5R!9>sKe8XjYe%V7v)(H`9T7Q2s}*SFo8!1 z93k*1fujU^)-71$MXtWKSRhlR#IBVd2VTN6@Kb{?`&WT)z&EQ*gWu9j zo#`wjEFc72cuiPC7Q2B5l6REz zL+?nMlpltwrag`PNvYdE>LziVG>2&RkK@GBEiiQlj?Ym#>JEHfL35?521AlvwV;zp z6hdT%t08`hDHgmn@`RlT)0}7(o;WSRpN9k<7I*~Z!=Od)hPPyxk6P6`G1N@85gup8 zj%@$3t7=kJXU3C~A{BJh7o$`$H>gz3NpyyGr0w_@PY_-mrc~m`9R=+tosrHc^%`v@ z+nTGf;k1N#Dw(EsQmr`$bCpj;xl~yOK7urjjix1E=`I5KFnC)evUGx1DJ? zq!~7j`4N5Hl8K+~uyX({X527KPD)$E@yxcozU0q$u7dniJM!~XZ-lAKvNP*Ow&ax5 zjj9*kK+3^h%9Kc%7AY?vT^f;Rl^eWY&OKFnY!^`rMP&U+= z=~DgnnVOcmX^_H_93$k1>+umdp;UYTGJ5gYpN^|1ozAy`QBP z?Z0=9U_ux}4;~UdcpqGbF0l0H72>fQbs=kaxNTUeAVnHTNl=co^rX~IxCuADs-vBA zM59r-aW^rAG#Z60ZL8}aw<7KkOCt9Wr1Y;YB(MrP{~@TV=;Jr3jF0up0h%!Ew?xwz zO4R(i-Jy&W^s&DuXyYkrLtPj7M1AkOQD3esIVJ7wopq?!?)!et(2S^ShoA32>X3QOSja7A9<7Je87`%n_Q0lPrDDjzi+&Rh~VMlfp zP#^Wj2<)iu0jzT$y_6f9qMCDVOq7ob57L#yu;81l;I~5jnf?71J1+cM$iB)>_!$2*ak?Ew%a{$?6YfNHu008OtUU$zSUXL#gRqkj z{~HeHVH2)hjA8BV$ zH(4aSMTCDF^*!!W;SHa=qjXNQ`J1tEGBk3!rQaO?1V|z;Ad!WP< zAJEeIJ>~G7?nxFcb<6FxbYV{keD9X9{^{rH-OtO1pRsp+f|TUDZcdAQgJ()1Uqk*&@7k$|dr1hlFgfs@yL_e*R-<$pv*x^N?-3C9q zEn6Dwc^-P(!HOo%PGjaZY)9Zv^xZVFi%#R{)qPo+qF0&-_-$ zon#4EM2w&B^=Fb@>RLq08Q@DVWgoumraj1)-Lwb!vYRIOf}Y0xp87ouJX4IXbXTP8 zGCUxy_K23KKGk`eRx<1p6uStyB^7xdIMH*%5yRAis|me82JR z?#$7icBe6;f0BJ_N5T6gHZ9&)*rIs9jFq0!p2qwX>YQb-z@to~$FJ_ly}IVa7_K5E zHmAyv2Pf3#!MiaN+ zbB?grx3KYlfX{LMqg_7Vke-lUm99#kmKdTpxf_dlt=u?y<{s@2Iqlb5evrkF0%bXE zgjz${{=iF?)26dEoV|MK+QkbmT`IhE;f+gIvz~0Uw1o!3Cy|#BmZ@0F*_M`FSXjW; zyU|3nCcai|6l(R72%y|I%(bP34YRVm;&9qsx5k!`wrVe&Sg2Vif|S{ zkYCKx==**+$l}Qh1K}zP9O-zj;jw9Vyo?3U&CxdE7#(Q`n>q$J@+D=Rq89RfIh#{F z-FA4Tfmfqeaqz7yAP9DQ@6OjaTdUYkslL`YC+emRdI+1e)-XjP zRv3{l`h-MO(_*b~%_`KJ&O%FHF^g=WHBxk(1rQpb`HXHQ%VlffRi*!d`lz6YF%BdJLWUlH|ou75nSO-%OZ_L53==g z_UeTz3)xSYP8RVfGkfus*L=!ieIx6vRP5|VrBY%)sHtcY}vG`SVZ zieC^9d;FBAmH{G5L3oUlmV*YB(wh~_d5VvLaF2R+FGA&`*`tk$Wtl98#$TY!db!rL zSMCV|d_e@>p2=EuCtDO2KYtnwVrDIS^mH*>v{?4^l25P&Qz@g}Y&7cJF_Ce&=0oZ?}=F(8tMSfA|;U#qM&YUUz>WFiLaX+^$l=Zz~bKY!_^!s3N1m-53< zPoe$=Ads1t81OTm-T>7>z!@=D)4EmRb&H>+u|%NF_D027;c5B|TlK;+FEWlOjvwIi zl;x5aF4b!_gmTfvm14v6LRJ9^!#w%2X9)X3s7!$Smmy=j^)(~}(-^zDYHICZd1 z7n;F8L*p|Eh!CPf*vA0!5uceAaMsRi7@n1<(P03EN**y#ftAo5Cm(a_gju5JB}y{qUD=Bk3Z+`nwhIN{07ck!3F#;B*cMHOvVp%4{i&n!oEn$o z@{l|VIEpB62+?3gp2Pbo%Av#oBXU-rP*cdI;m;6|@=^ddE)4|_>X9Z8Kc@Uc_=}*X zxST>ADZC9}5!6QSBvL7jzGf9ey(LK#il!gnU>&uW0CX7JZA$t;E3mIS`j6A4c-INTMNg z!XB6oGZ2NvieGk=stP|l>wjNEYDWvw5TRu<*2Boj6xt^VOzmM>57PQfReT@ipY+)zLZ1+=)MLVR&In#2mNoxqolC*;Q5Fi}7o zLKt$$jU-VA0#mX(j8Pk91*6s$`33ZWpCLfgS}x2Rgs&4I>knoJyRy7ihj$9WLt=&C zY22(e-%a-O$-WK$ZxuEKT=MkQ!diuVsCX%>8=i6t565#G99_N5JE-21Ip~-bu+-Qy zv18sYW36SG2C;VtndwL=1;ZNxgaL@9_d<~A0Vb2?2qym?kPl|EV4l)JgdMvJ1E*77b|8fHnr25zvf)mJ;$vxx+%4$K6pkk(4?*j$$X=F?T4y zC&A?jn8TD%blsi8{xJg2TD$f5wNCVE9d~C?PckXB55TMrpyrVeYROHP$yW}@mJCj~&YZEI zY<-DxWb(pTj#So~Yy6j}2>DBYg+IG%fc!U!dW68E1dbB;CV}4~@Y?{Mev6Eeze{-7 zDTwU@A!6&me~0q@E&(zbEEzZW@6*>G5cn2>4uL-;aG1a&0G_ge^Iv+N%yFDV9mmwF=Z<>snPU)L`-kMF!csD9uytW$AhSWnh-YrwV$ew zLZ`YmKJo!J-ql&neNbmL_d%W2+S3n=5V_(OF-+{hDXoOt0o{U4G4({cFbRp;tFfNDs92_4nLyx=tbVKMVcOlK%I5 zGn&d=UFbd52p1qMf1-h~e4pkke_~(d|6|Iu>)6N)_$GlrA+W~|4*pDb!2gu0AS>X1 zM&LUH2DP96Inh2`{dt!1eI5W)+A5~9Tl~93`3nMnN#L&te4s(V%RUQs*r$=J=Zt$M zVZbE(8x+22%ymseta52x*k>T*uqo8k{n1sAo`KMa{YcK6N;}q$J6LZ(Wbt6VA;WFc z%0a{Nu+A{)xe1twy3bUY2-YOKQz5M`{58q$R5%P;7^^_`1$Sinv_>eAzfi%>s?sw1bcgn_d{Aw2slAFZ=n?xg zVUPT?RM|lSA3+KBxVkunJPPEr8r;m2cu5u#{8`GMZSkgU3>K~vjiNW4= zGY$A*;yT3)gde6AV^7USAY8g)L;eSA1Dt2#^jIuKhj)xOwf0y$)gEt8IEoN<44!Ws z-cb0Attiqb+iF|GFr^ON9!bGNCU-QP`%Pgb(DN+*RUM80En4Q^CqTDq_zwVb>Ak`~ zLP?Jj*k^X^(}kbz6|z*pu4)XYfuxqz=Sxs_Qm&Q+Tkhi%TiSZ$FPDGwt~&~Da@rNF3)I@c>T%PXzc7@S**A0}dsyg`zgzW5W(MRa>&SiSrF-xGHYKuo1!e zZ=n20JBeOM)luA%!EK`Uh!w{>CBDY7N{-=9Q4|~=y#tpFEuVA~=+)R=1s4SsVLLu0 z6;Q@{dBy`usz;d!DC6ki6z_?iKEi4)^EMv%UXofBq$!ZQI*Bb)BSXZr91efVr2{*b`Yci8?o z%ic-8se@15-UwaneF?qmeF+`weFYH*u!_CG>!Myh4o!` zgqnRwj8Nl8{SiXzGEBp{pv-aefq_>McNMf43r)jLaR2BOQbOb-gjX zQ^$7Q818xWre{lWV`!f=7(N3}*w2RYTm7dxS>Fj^JL)qgY_XfFDKt>41;~ z-T(O!_#tSR&SKUmzF-CHDAl>65Y3H=yMEY<-@rWRpT&J0Q(oVhTyBxyfK$AjQ~j{f z7CuBgSe(!K2MFXbmw)nw&eLz4%U!;wQjVJho*?ivfoA|(`T`y1;6(>GBGfC+?l{fJ zI07L5>(~4uVV@-MDFQD6Tt^0S&R*QATLp2*hG=}ts97hwr$7aX0cAXjeIkN>R#_d|OoVNC*?1Vm_@ILJQ>;OSV})oorUf>`_CChQjpv*Je&@$nig(K(~V;@A1gO& zH9G3U$zETFU#ALw4FG3r<}J}zMC!46q5!{$UV7@%(o$Z9D&R_{UcG#I zao>H$;QCSAa<*xKKpPWUAoy`%if;Sy^{jjde@Em)N?gq#0!=6tH01vwwvMt`!3hYu`1|kkA&ROGmJ_AA`Td;#Goe!)o))%o zz^pu~%%T+9RJ!50;7YlCIhW*f#Jpn!9wV?sfCS_jS2iyP|60K*qUsB9SDr$=-=_px z*!d!-fRvXAP9;hvgcY1j2yXs5<%<$H0YGa)n^u1aGJHmQITBK2L&sy_N#L2pGm9sN zM~y^9ncO41MCEG)NZmY*%+?QxM*O2>%i^>=7aE0=#bOB>#Gpc0S6etxS~9KD%36`H zdIp`-Y!n=vK>BA$A>8R-Yt}rifq$X!!mPsme~>6H;iL}~oCTIEbRRyCbtC7dh0O%$ zLghUbk#0||Rva&Wt%Ngp=Zd&IpO0MG^f51C3p$Ri#J%}{r0y%|Qpisi{~(f2`-5mY zm6}ce|53pIV!#6MPX+!dV1i0TicP0ZCxc$3C!!G=vsh;nbNeDag7w68FLD+JfeZcT m$a&hd-1d(x=zs<4qS4U@j%5(_i6|OWCXU0IRNPPv?f(Hy7cm')] + plan = [0] * len(s) + stream_i = -1 + stream_count = 0 + to_allocate = args.n_layer + 1 + allocated = 0 + free_slots = 0 + for i in range(len(s)): + si = s[i] + si1 = si[1] + if si1.startswith('fp32'): si[1] = [torch.float] + elif si1.startswith('fp16'): si[1] = [torch.float16] + elif si1.startswith('bf16'): si[1] = [torch.bfloat16] + if si1.endswith('i8'): si[1] += [torch.uint8] + else: si[1] += [si[1][0]] + if len(si) > 2: + ss = si[2] + assert ss.startswith('*') + if ss.endswith('+'): + plan[i] = int(ss[1:-1]) + stream_i = i + else: + plan[i] = int(ss[1:]) + allocated += plan[i] + if allocated >= to_allocate: + plan[i] += to_allocate - allocated + break + else: + free_slots += 1 + if stream_i < 0: + if free_slots > 0 and to_allocate > allocated: + for i in range(len(s)): + if plan[i] == 0: + plan[i] = (to_allocate - allocated) // free_slots + allocated += plan[i] + free_slots -= 1 + if to_allocate > allocated: + plan[len(s)-1] += to_allocate - allocated + else: + if to_allocate > allocated: + stream_count = to_allocate - allocated + plan[stream_i] += stream_count + print(f'Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)') + for i in range(len(s)): + ss = s[i] + if i != stream_i: + print(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers') + else: + print(f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers') + plan[i] += (0 if i == 0 else plan[i-1]) + self.strategy = [None] * (args.n_layer + 1) + strategy = self.strategy + for n in range(args.n_layer + 1): + for i in range(len(s)): + if n < plan[i]: + strategy[n] = types.SimpleNamespace() + strategy[n].device = s[i][0] + strategy[n].atype = s[i][1][0] + strategy[n].wtype = s[i][1][1] + strategy[n].stream = False + if i == stream_i and n >= (plan[i] - stream_count): + strategy[n].stream = True + break + print(f"{n}-{strategy[n].device}-{str(strategy[n].atype).replace('torch.','')}-{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}",end=' ') + print() + + # Load weights + print_need_newline = False + for x in keys: + w[x].requires_grad = False + layer_id = int(x.split('.')[1]) if ('blocks.' in x) else 0 + if ('ln_out.' in x) or ('head.' in x): + layer_id = args.n_layer + dd = strategy[layer_id] + DEVICE = dd.device + ATYPE = dd.atype + WTYPE = dd.wtype + + if self.RESCALE_LAYER > 0: + if 'att.output.weight' in x: + w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) + if 'ffn.value.weight' in x: + w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER)) + + if '.time_' in x: + w[x] = w[x].squeeze() + if 'key.weight' in x or 'value.weight' in x or 'receptance.weight' in x or 'output.weight' in x or 'head.weight' in x: + w[x] = w[x].t() + + if '.time_decay' in x: # need fp32 for this + w[x] = -torch.exp(w[x].float()) + elif '.time_first' in x: # need fp32 for this + w[x] = w[x].float() + else: + if (len(w[x].shape) == 2) and ('emb' not in x): + if WTYPE != torch.uint8: + w[x] = w[x].to(dtype=WTYPE) + else: + w[x] = w[x].float() + if w[x].shape[0] > w[x].shape[1]: + w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1) + w[x] = w[x] - w[x+'_my'] + w[x+'_mx'] = torch.amin(w[x], dim=0) + w[x] = w[x] - w[x+'_mx'] + w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1) + w[x] = w[x] / w[x+'_ry'] + w[x+'_rx'] = torch.amax(w[x], dim=0) + w[x] = w[x] / w[x+'_rx'] + else: + w[x+'_mx'] = torch.amin(w[x], dim=0) + w[x] = w[x] - w[x+'_mx'] + w[x+'_my'] = torch.amin(w[x], dim=1).unsqueeze(1) + w[x] = w[x] - w[x+'_my'] + w[x+'_rx'] = torch.amax(w[x], dim=0) + w[x] = w[x] / w[x+'_rx'] + w[x+'_ry'] = torch.amax(w[x], dim=1).unsqueeze(1) + w[x] = w[x] / w[x+'_ry'] + w[x] = torch.round(w[x] * 255.0).to(dtype=torch.uint8) + w[x+'_mx'] = w[x+'_mx'].to(dtype=ATYPE) + w[x+'_rx'] = w[x+'_rx'].to(dtype=ATYPE) + w[x+'_my'] = w[x+'_my'].to(dtype=ATYPE) + w[x+'_ry'] = w[x+'_ry'].to(dtype=ATYPE) + else: + w[x] = w[x].to(dtype=ATYPE) + + if 'emb.' in x: + pass + elif (dd.stream) and (x.endswith('key.weight') or x.endswith('value.weight') or x.endswith('receptance.weight') or x.endswith('output.weight')): + try: + w[x] = w[x].pin_memory() # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :) + except: + print('Note: You are running out of RAM. Get more CPU RAM. Now this will run much slower.') + elif DEVICE != 'cpu': + w[x] = w[x].to(device=DEVICE) + try: + w[x+'_mx'] = w[x+'_mx'].to(device=DEVICE) + w[x+'_rx'] = w[x+'_rx'].to(device=DEVICE) + w[x+'_my'] = w[x+'_my'].to(device=DEVICE) + w[x+'_ry'] = w[x+'_ry'].to(device=DEVICE) + except: + pass + + if 'ffn.value.weight' in x: + gc.collect() + if 'cuda' in args.strategy_string: + torch.cuda.empty_cache() + + shape = [i for i in w[x].shape if i != 1] + if len(shape) > 1: + shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}" + else: + shape = f" {str(shape[0]).rjust(5)} " + if layer_id == 0 or layer_id >= args.n_layer-1: + if print_need_newline: + print('\n', end = '') + print_need_newline = False + dt = str(w[x].dtype).replace('torch.', '') + dt = dt.replace('float32', 'f32').replace('bfloat16', 'bf16').replace('float16', 'f16').replace('uint8', 'i8') + print(x.ljust(32), dt.rjust(4), str(w[x].device).rjust(8), shape, ' (pinned)' if w[x].is_pinned() else '') + else: + print_need_newline = True + print('.', end = '', flush = True) + assert len(keys) == 4 + (4+9+5) * args.n_layer, 'Error: not a RWKV-4 model (4a and 4b models are not supported as of now)' + gc.collect() + if 'cuda' in args.strategy_string: + torch.cuda.empty_cache() + + def get_w(self, x, dtype): + w = self.w + if w[x].dtype != torch.uint8: + return w[x] + return self.uint8_to_type(w[x].to(dtype=dtype), w[x+'_mx'], w[x+'_my'], w[x+'_rx'], w[x+'_ry']) + + @MyFunction + def uint8_to_type(self, x, mx, my, rx, ry): + return (x * rx * ry / 255.0) + mx + my + + ######################################################################################################## + + @MyFunction + def ffn_one(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw) + vx = torch.square(torch.relu(kx @ kw)) + out = r * (vx @ vw) + return x + out, xx + + @MyFunction + def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw) + vx = torch.square(torch.relu(kx @ kw)) + out = r * (vx @ vw) + return x + out, xx[-1,:] + + ######################################################################################################## + + @MyFunction + def att_one(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw) + k = (kx @ kw).float() + v = (vx @ vw).float() + + ww = t_first + k + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, k) + e1 = torch.exp(ww - p) + e2 = torch.exp(k - p) + + out = (r * wkv) @ ow + return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p + + @MyFunction + def att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow): + xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + + r = torch.sigmoid(rx @ rw) + k = (kx @ kw).float() + v = (vx @ vw).float() + + T = x.shape[0] + for t in range(T): + kk = k[t] + vv = v[t] + ww = t_first + kk + p = torch.maximum(pp, ww) + e1 = torch.exp(pp - p) + e2 = torch.exp(ww - p) + sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype) + ww = t_decay + pp + p = torch.maximum(ww, kk) + e1 = torch.exp(ww - p) + e2 = torch.exp(kk - p) + aa = e1 * aa + e2 * vv + bb = e1 * bb + e2 + pp = p + out = (r * sx) @ ow + return x + out, xx[-1,:], aa, bb, pp + + @MyFunction + def cuda_att_pre(self, x, sx, ln_w, ln_b, k_mix, v_mix, r_mix, kw, vw, rw): + T, C = x.size() + xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b) + sx = torch.cat((sx.unsqueeze(0), xx[:-1,:])) + kx = xx * k_mix + sx * (1 - k_mix) + vx = xx * v_mix + sx * (1 - v_mix) + rx = xx * r_mix + sx * (1 - r_mix) + r = torch.sigmoid(rx @ rw) + k = kx @ kw + v = vx @ vw + return xx[-1,:], r, k, v + @MyFunction + def cuda_att_seq_post(self, x, r, y, ow): + out = (r * y) @ ow + return x + out + def cuda_att_seq(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix, t_decay, t_first, kw, vw, rw, ow): + T, C = x.size() + xx, r, k, v = self.cuda_att_pre(x, sx, ln_w, ln_b, k_mix, v_mix, r_mix, kw, vw, rw) + y, aa, bb, pp = RUN_CUDA(T, C, t_decay, t_first, k, v, aa, bb, pp) + out = self.cuda_att_seq_post(x, r, y, ow) + return out, xx, aa, bb, pp + + ######################################################################################################## + + def forward(self, tokens, state, full_output=False): + with torch.no_grad(): + w = self.w + args = self.args + + if state == None: + state = [None] * args.n_layer * 5 + for i in range(args.n_layer): # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + state[i*5+0] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev) + state[i*5+1] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev) + state[i*5+2] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev) + state[i*5+3] = torch.zeros(args.n_embd, dtype=torch.float, requires_grad=False, device=dev) - 1e30 + state[i*5+4] = torch.zeros(args.n_embd, dtype=atype, requires_grad=False, device=dev) + + seq_mode = len(tokens) > 1 + + # 输入:根据 idx 取每个 token 的 embedding + x = w['emb.weight'][tokens if seq_mode else tokens[0]] + + for i in range(args.n_layer): + bbb = f'blocks.{i}.' + att = f'blocks.{i}.att.' + ffn = f'blocks.{i}.ffn.' + dd = self.strategy[i] + dev = dd.device + atype = dd.atype + wtype = dd.wtype + if seq_mode: + if 'cuda' in str(dev) and os.environ["RWKV_CUDA_ON"] == '1': + ATT = self.cuda_att_seq + else: + ATT = self.att_seq + FFN = self.ffn_seq + else: + ATT = self.att_one + FFN = self.ffn_one + + x = x.to(dtype=atype, device=dev) + + kw = self.get_w(f'{att}key.weight', atype) + vw = self.get_w(f'{att}value.weight', atype) + rw = self.get_w(f'{att}receptance.weight', atype) + ow = self.get_w(f'{att}output.weight', atype) + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + ow = ow.to(device=dev, non_blocking=True) + x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT( + x, sx=state[i*5+0], aa=state[i*5+1], bb=state[i*5+2], pp=state[i*5+3], + ln_w=w[f'{bbb}ln1.weight'], ln_b=w[f'{bbb}ln1.bias'], + k_mix=w[f'{att}time_mix_k'], v_mix=w[f'{att}time_mix_v'], r_mix=w[f'{att}time_mix_r'], + t_decay = w[f'{att}time_decay'], t_first = w[f'{att}time_first'], + kw=kw, vw=vw, rw=rw, ow=ow) + if wtype == torch.uint8 or dd.stream: + del kw, vw, rw, ow + + kw = self.get_w(f'{ffn}key.weight', atype) + vw = self.get_w(f'{ffn}value.weight', atype) + rw = self.get_w(f'{ffn}receptance.weight', atype) + if dd.stream: + kw = kw.to(device=dev, non_blocking=True) + vw = vw.to(device=dev, non_blocking=True) + rw = rw.to(device=dev, non_blocking=True) + x, state[i*5+4] = FFN( + x, sx=state[i*5+4], + ln_w=w[f'{bbb}ln2.weight'], ln_b=w[f'{bbb}ln2.bias'], + k_mix=w[f'{ffn}time_mix_k'], r_mix=w[f'{ffn}time_mix_r'], + kw=kw, vw=vw, rw=rw) + if wtype == torch.uint8 or dd.stream: + del kw, vw, rw + + if self.RESCALE_LAYER > 0: + if (i+1) % self.RESCALE_LAYER == 0: + x = x / 2 + + dd = self.strategy[args.n_layer] + x = x[-1,:] if (seq_mode and (not full_output)) else x + x = x.to(dtype=dd.atype, device=dd.device) + + x = F.layer_norm(x, (args.n_embd,), weight=w['ln_out.weight'], bias=w['ln_out.bias']) + + if w['head.weight'].dtype != torch.uint8: + x = x @ w['head.weight'] + else: + x = x @ self.get_w('head.weight', dd.atype) + + return x.float(), state -- GitLab