From c0fb03a0dc7bbfc09f11315feb1873d1f8b0ab81 Mon Sep 17 00:00:00 2001 From: WeiXin Date: Mon, 18 Jan 2021 21:00:29 +0800 Subject: [PATCH] Supplement PR29988(https://github.com/PaddlePaddle/Paddle/pull/29988) (#30507) --- python/paddle/fluid/io.py | 2 ++ .../tests/unittests/test_static_save_load.py | 19 ++++++++++++++++++ tools/static_mode_white_list.pyc | Bin 0 -> 21803 bytes 3 files changed, 21 insertions(+) create mode 100644 tools/static_mode_white_list.pyc diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index 36088aa803..313855b6c5 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -2180,6 +2180,7 @@ def load_program_state(model_path, var_list=None): with open(parameter_file_name, 'rb') as f: para_dict = pickle.load(f) if six.PY2 else pickle.load( f, encoding='latin1') + para_dict = _pack_loaded_dict(para_dict) opt_file_name = model_prefix + ".pdopt" if os.path.exists(opt_file_name): @@ -2231,6 +2232,7 @@ def set_program_state(program, state_dict): static.set_program_state(prog, program_state) """ + state_dict = _pack_loaded_dict(state_dict) parameter_list = list(filter(is_persistable, program.list_vars())) used_para_list = {} diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py index 257d6e0489..68d0e07e0c 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py @@ -1365,6 +1365,25 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase): base_t = base_map[var.name] self.assertTrue(np.array_equal(new_t, base_t)) + # set var to zero + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + ten = fluid.global_scope().find_var(var.name).get_tensor() + ten.set(np.zeros_like(np.array(ten)), place) + + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + self.assertTrue(np.sum(np.abs(new_t)) == 0) + + program_state = fluid.load_program_state(path) + fluid.set_program_state(prog, program_state) + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + base_t = base_map[var.name] + self.assertTrue(np.array_equal(new_t, base_t)) + class TestProgramStateOldSaveSingleModel(unittest.TestCase): def test_ptb_rnn_cpu_float32(self): diff --git a/tools/static_mode_white_list.pyc b/tools/static_mode_white_list.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9012c233595b6844f54e625972360f5aeeb0d3b GIT binary patch literal 21803 zcmeHPb+{~7k*{-t27=4X9}(-xW3>l!3_jA6x>K~W5G=XHx=AWaC5;e z1h*92N^onzZ3MR!+)i+N!5sv56x>O0XTeq3-$#k1*ZfD zf-40t7Q968Qo+jvFBiN*@Jhj}1g{pnM(|p}>jbYCyg~3r!J7nc7Q999R>9i@Zx_5n z@J_+I1n(BSNAO<3`vmV7d_eF)!G{DN7JNkTQNdRTJ|_4|!B+{sTJSZ3uN8cq;Ohk+ z7kq=@8wKAa_-4Vk2tFbBq~KG6PYXUH_^jYt1>Yw4oZ$01wSJAQNfQ1eq8Vqf}a%pl;EcYKO^{A!OsbPUhoToUljb3 z;FkrzBKTFouL*u#@Ed~P6#SOpw*|i=_+7#834UMj2ZBEo{E^^~1%D#=Q^B7J{#@`E zg1;2}mEf-hew&>UxNP@{Ey&E z2|&PzJmxU@u@Eun1TNQ~|4iD+20( zCZG*i2kZx&3^)~V5O8I{ivwN~@X~;n1-v}q6#=gdcvZlw16~vG+JM&uyguL!0dEX= zQ^1=8-V*TEfVTy_J>VSy?+kcXz`Fz96Y$=E_XWH^-~$044ERvMhXXzm@X>&;2>4jQ zR|b4lz*h%+O~BU%d|klT2Yfu>8v?#D;F|)zIpA9YJ`wQAfKLT{I^Z(_pAGodfNu-< zT)^i8zCGYO0=_fgy8^yD;ClkTH{c5a-xu)x0Y4D%g8@Gj@WTN=67Zt|KNj%g0Y4G& zlL0>!@Y4Z56Y#SEKNs-x0lyINivhnB@XGS@W%mv67Z)1e-`lP0e=zjmjQnj@Yex<6Y#eIe;4rg0sj#2j{#o{ z_@{t>4)~XVe+~GzfPWA8kAVLS_^*Kf4)~vdF9k?Mh=>tqL|h}{ni1EExOT*KBCZ>I z{eNb}^`h$kvuFz<&f?#r+JfuTB1GIM_7>sB5jWvanuds*Mch2%77@3MxK+fhBW@FM z+lbpm+&nh__l!6@;$9K=j<`?6eIw3^I5*_O5f6)ac*G+j9vShdh(|{}CgP%q$3{FZ z;_(p|M?4|oi4m7XJSpPRh$lxpCE}?OPm8!L;_`^6M?53qnGw&5cy`2dBAy%Zyol#V zydWZrcws~yu@fWJ4wyf)%>5wDMUL&O^+-W2iXh_^(%HR5d% zZ;yCK#5*J274hzf_e8um;(ZbCkN7~u2O~Zd@!^P%M0_;PJ|3SVG!XH!vp+^>puGjX zfwpMcET2xxWsyyG^X0ObXVt2`!GM}n%SqlU=#~Shm=}v;*`6$$BFll(;+pl2;%+s} zo0ju=c4bjlO*Sw0isKu|IS-K6dW3SBP3oy4Z#>MLSIFw3DbhJ^=FnPu|q3k7Wwe_+`8Tkw~y3FBu&AjB5#?&x~4_@UQ(^eUZRlBQ$Z)^ZFQd?wEiXZ?BTsZf;CEb3cB?v@70aT|TmFrQ zyq>L5=gV2q`d;-cTTt+N()mNxv!<$(*!jYBogbttW+!Xip9`D1$eV^|-3FDjrGiZK zI7&hL`Fx$G8)vGNE)#9%O`FK%M_x~K4QD$3Zb*Dyd!-yp+#T(D`xS{v4Mpx%8kPCH zyfQJUKg3Qx**lrna3bvFO*zSGvhXrr6fXEXc{|z7mQ|e&bn{`84RE%9e6%#T<;ykB z6gvf(yP<;PLMX^J>DosAF`D-GM`b7?PX{p_m9ykBu!+-qKCFrCJoOzXsV4_G{oI3QxQ6=GtV0f z5p;oy)TguM2((@_>rHh%*-yVs?2xy)zmxu&SIJ0HDbCfL zP%2#;jqrO#!K>HQ3hnI%86Kf2gFe7cjbdj#Qw}zbNV1Z7>7KdkVHQnEzFsBK=?ohu z-qrRxjLdZZ&IXT;Tuf;pk}s-UmhB(Y2Q$%zqgPE`mC3WpnfZ|(H!K#Af z{CoQ`Xyxd|@(PP$QPoKoCD*944pX{xBZC9_usJR1m2?kwn!&Cc5023DD)}$QvM8Df zRdcm0TzjKv^U3b`_dZRls|OVNJDbsWgu7=b^I29bb_%q2vTM+(tjnD>1>lg9Gux(~ zXE%tB4Q074W(Tg<4NHHs8J(%xROF~mE?3pMZf0<9#)?0cu6|urv+Fd~Y#|P)KUcfi zVm)umjL(|<1QSLaEY(YBO{)-)FM8LZX(^cM+$O6tT5_q5#enFU$2ssIq3Lk9$=t8j zZF~2DW`*iv>V?k~)&MT^ps~BJp z!hMfiz4wTZ#?fOw?L3C8r^T(4%V4iRzRh4+j@)onr8wp~#u;?9{9RPi(o@j)-DE>E zQ^jIM4WrSw9?pb7?ZrVx=eu^F%3;S{$cip7_ZK?Fm(}1xV-OV=Q=7#Yv!-v*dnMU~ znxyb(%f|H~O+js8O4E>2g6Tgkdz07R$0#stQy1;J9`GDWR@*OX4Nz3cnsv7zUGHyp z*P)%;YqqXOaP0O9-9Ynii-SZ_d$nm!5|6YjS1_U_oT+DG&KzztO(BlGXu7rA8u}Oc zigEkFs&H|h78;yrf~5Bbat+ViQRwjN<)Wx(B(2@7gRA%9=hFO4zI6WB#8B^OLLkYo z>B5xL;T8Pin@Juev++3xt8YzLxYVVR-7z0~%7}$mNIr|{kshcinzo#+t96s5$aRBz zhd~>ZoEOwt#`jbGNv=?@w@jJY&6{24iEc_w$A7`MtlfX!LS!el*IP5< zWS*pGG;}r@>!z(+(qMFFJ(j{6_tvM33Pv>J*@}}@y+>}MPowtG#c|psuh3tkt&?Uk zW}UA2q=UHoFolv@<_=8POQvwDCH+zLU5<|o_92ve!Ka;n;`H;!&bPZhQ&kW-57mb? zE5>*u5-66(?Gr%>Y1xo8wJzps(uc{cl$6vvWNFx)O|rNKGF7x)+R`&tT-dC%F2+%d z9meWcBo9sZb$J}ny`H7G4rvPKv~{C-Fg2^3YW_M`R`4)*{10X^(5(f`;_YzA1Ir|HQkyr zY_2v&7%+&NOr4}7)O65*F$2gXuY~9myB2@ zsYG++-8OEcRRDJ$+H;RMA&JE_GKKnN$zs89z>KNweBK@$xkbvqXgEw>+P;{1AmwsW zcc(NztYo(QRMtxpX}XO_f~kz@47i+fk7dOehSb|xPxWUrui9qx5;`VTdKOGB&a<*| z3vAh8T6kxpWzrIioRfcVNuY9o1IiUcWa`h_TFijG%hL0 z!D1jioCk1axygWQA*;!LbC6iAn#^583fPJNcKvPS5bB)V1&#Dh<>T9p2sUyH{H$OI9 z+SwslkE~f#Rad^U+4G#Xd%R1^J=y|hCr01}55k~}lgt8+ushRZ7i2!fo@z9N@jhR} zNHk|cYpMA=&0DZ8v&&5FoZRJ8yX?E?^zD`Vy!=K6-cnjbHd?oHQYBT!bDMEY0FCzh zNFjw(V>DCZ`tCXN8Dyo!rg2j>kgv32rNr<E~9f#`!^0PQ$|+lPItk9FI~6WMYv~*oeN>sOsBCSbAWKKp66SS$HcrN(+J*Yd+cz#`;8lJ*V zu_&l+OQxmF2kdixb)mj+acVMOvxJo5Lk;14mo7EfnTHXSv0qjqy|)p3Q8c?l(a~6l zMZkiAP)#rRRB~w296wmOWC%8<;G*o@4O5cTv2+EfjCVI&3E#HT=7BC-Z11qB=*G|j z`bNBFs==7^(-YKUmXw@Jp2t8we=^P6?Xm`|DGhE$r({GTMbcAC=~bt1g4UHrZFH@> z%=D=VLkgX2%18@?l-#<<S(0HHl27(7gCVk4Pyo%o z@k{r+FRrp4tO}%dCC?@Xnv6kge0$1t5)UIR)6H3baubdLww041xN9JBF-{pIqf7vwt}BOYp4%mslY0lwtbl>|m?;6lTtLQz;z_rs+M>ZN_#jM+XqfH3 z`gUW75+a43LtBM=r|O2X?#Cw{l3CE|PRC|;oTEmuTCrZb2X=T!(h$6HqrDv3CZpQo z_~7In9#7w~ca^%C9L?H!s&}{sK0Ix)|FvBt6I~QKF!{*Kq z)m985m`iFVqYM<`rx@xoJL=WL@uUvfRhImn>!Hl2x3wH^I@X9AtQe@H{LK!M?pK&| z$WWZ*g%>;JGDX|&wXO`U-q*a?Ph3WG$mJUb;3#oaF1WOSP5%!jRVDc7S z6-q(N@KjSebVao!Y3x~(yZVsRVT*=xlALF!4LyoUL7kw=#|c4b@6hfRbH0Th%;&H; zp85+EYVm23CGM#Kv#KFekvhuXwlAzJ(DVlFc5*hH)GVV> z^qwPUFPgMLm2lQQ2J@sAxulm!E>k(xmYnpPPJ8c7nwQec3WJ=oO
    uoSy@*$VX) zwLq!AxmE$gPL}RO& zi;g2U=uCW4^}ZV^x1+h}R9Q41vM<5=Wd`0f+LALFfppm*9?MjBvs;%GOs%B%%FAil z*;DF!t%q@v2}W;ewY7NYZob#>SZqUFXsh5xlBUbs_^X%#9!5_q+Bx%Tor-;%qglpL zs)bXw0;fiMGk8wFbeGOxz9WSJe(4RkPO+%v^ob@N=Co0v(PqHGo%aP+ObB$vTdfa{ zl!q>0BH5&2eE{k2mY^KxFS|B%(e3s(Y%1ZXF5Gx;cHhjYVap`W#?A`e&RzL6wvk=bX8Nh6UhLYJjwlk@ zxz|hjHHr2k(6s8F!>GxchbEUtRVHHxqUbC*n{~hYOl(Ef+<45#7JF4rB!}6swZ?KC z@7c}PE7p?w0*Y!`^<6l}s87YL0Z6q0DHHL{l$THTmK6v%cRxp~0E< z3|qO5F>0eA(P;aS*~deuLlpuO)6vi_!nBgzHs;V8rn5&QnNp(Dvm}u{$1*}!6SlDo z`JyZt-y_F$=@*T>*;$`;2eY=9(R!QXQYh=v$WV#>qMD&6+3Ybp8P%wV;KnIUKVm^n zruM`tTO~)}VV(tt_$*Nyp@VgKoa6(JBAd zd+0jpL0|Tybh=Teg3dPOlZpb3`>cnL&~@EtZpFhoE36hvro~d#U@S_GuJw*YnL;+y zv}A`Ltq>hdH?YH3Of1g!I8zQbQETyT77g6T21jC&bvg8_s4_+6xH#~-qyUHQxsv8+#cn4P|R@MBU*Czaf-+2vKw8D5OmCjbToycW20JwzVa9@pW{;mW`0O;%q$!!&tTvu$?kz z=WbV?@-e1-ZH>}; z38H;nNDiya5Yo;?*50f_Xb&o(MK~B2W>ia&EvO7`>cB(!DzqRu=dh5`3+kUp6%C1A@ z)qT+6IHP^-pL;mgm}4-_jxzl+ptg=R*wAj|BPm9sqxKigrPK$h2(PM7jDw(9jHmrCi}(Jb}C^~?HE z!=&q))E~{FS-^g z`8w=VuVx-7rM$fdaiorg_FezA#tFVy2kUBxR*o>W*kO75^znKZTBSf^6c!%thf0jl zM@L37Hf?V@j3bAm&5~p%OHCs;w6y58JGX9QmF|1q)qKsS6tDd1`p$u;bzLk~nHo$; z9h0e4rL>kdRcp#8T`OvOC+D(?Uz#_}H@P=r%l5qF@8Lx6ztwh+yhB-kJqwH{JxS0p z68U>k5?O4f7 zf=*^vcgOnafxrAPd-nf|n05C5IY`$@#(|mu>GS+U_TYR|%HuvBS7{l)O0QD`#GvlF z?2JotQd)fd)BrS^*pD;K+c({1r+6Zx=Ew>$lcp#eKcv?bXjawANt?@zT=VT^`tQ`! zcw447Y5iE^^0CaQith2|ETnICIG0yTC7#Vi``-S3lI>-xJdHM$)Pqy9H(~7Jy^2Xd zi<=kT{FM;)rhPJIJjtlrg$X#XEfNl-JvM|@|H z?s)0hL0L^5mCG)F*{XV; z_9-;ySrOyBr|@;=#d$5W&;)lS4$V7eOcrX!nf2gvR|l!n>_WoXM*1)%Z*$2tMa~FU U|8q43uBO1%6u6oK|6dCHFWA)*QUCw| literal 0 HcmV?d00001 -- GitLab