From 18ecd433f558b16f069e6fb3a2028fda34716eb9 Mon Sep 17 00:00:00 2001 From: WeiXin Date: Mon, 18 Jan 2021 10:50:57 +0800 Subject: [PATCH] Avoid bug on 'MAC python3.5/6'. (#30485) * Avoid bug on 'MAC python3.5/6'. * Choose the saving method according to the OS. * smaller length of '_unpack_saved_dict' for MAC OS. * add version information of Python. * Edit comment. --- python/paddle/fluid/io.py | 17 +++++++-- .../tests/unittests/test_paddle_save_load.py | 7 ++-- .../tests/unittests/test_static_save_load.py | 36 ++++++++++++++---- python/paddle/framework/io.py | 14 ++++++- .../static_mode_white_list.cpython-35.pyc | Bin 0 -> 19792 bytes 5 files changed, 58 insertions(+), 16 deletions(-) create mode 100644 tools/__pycache__/static_mode_white_list.cpython-35.pyc diff --git a/python/paddle/fluid/io.py b/python/paddle/fluid/io.py index d5963675a8..36088aa803 100644 --- a/python/paddle/fluid/io.py +++ b/python/paddle/fluid/io.py @@ -22,6 +22,7 @@ import logging import pickle import contextlib from functools import reduce +import sys import numpy as np import math @@ -1715,7 +1716,7 @@ def _unpack_saved_dict(saved_obj): unpack_infor = {} for key, value in saved_obj.items(): if isinstance(value, np.ndarray): - MAX_NUMBER_OF_ELEMENT = 2**22 + MAX_NUMBER_OF_ELEMENT = int((2**30 - 1) / value.dtype.itemsize) num_element = np.prod(value.shape) if num_element > MAX_NUMBER_OF_ELEMENT: unpack_infor[key] = {} @@ -1809,8 +1810,18 @@ def save(program, model_path): parameter_list = list(filter(is_parameter, program.list_vars())) param_dict = {p.name: get_tensor(p) for p in parameter_list} param_dict = _unpack_saved_dict(param_dict) - with open(model_path + ".pdparams", 'wb') as f: - pickle.dump(param_dict, f, protocol=2) + + # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6' + if sys.platform == 'darwin' and sys.version_info.major == 3 and ( + sys.version_info.minor == 5 or sys.version_info.minor == 6): + pickle_bytes = pickle.dumps(param_dict, protocol=2) + with open(model_path + ".pdparams", 'wb') as f: + max_bytes = 2**30 + for i in range(0, len(pickle_bytes), max_bytes): + f.write(pickle_bytes[i:i + max_bytes]) + else: + with open(model_path + ".pdparams", 'wb') as f: + pickle.dump(param_dict, f, protocol=2) optimizer_var_list = list( filter(is_belong_to_optimizer, program.list_vars())) diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py index 3d5c8dfb48..3a8531db6f 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py @@ -16,6 +16,7 @@ from __future__ import print_function import unittest import numpy as np +import os import paddle import paddle.nn as nn import paddle.optimizer as opt @@ -90,13 +91,13 @@ class TestSaveLoadLargeParameters(unittest.TestCase): layer = LayerWithLargeParameters() save_dict = layer.state_dict() - path = "test_paddle_save_load_large_param_save/layer" + ".pdparams" + path = os.path.join("test_paddle_save_load_large_param_save", + "layer.pdparams") paddle.save(layer.state_dict(), path) dict_load = paddle.load(path) # compare results before and after saving for key, value in save_dict.items(): - self.assertTrue( - np.sum(np.abs(dict_load[key] - value.numpy())) < 1e-15) + self.assertTrue(np.array_equal(dict_load[key], value.numpy())) class TestSaveLoad(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load.py b/python/paddle/fluid/tests/unittests/test_static_save_load.py index 0f4fca6d7f..257d6e0489 100644 --- a/python/paddle/fluid/tests/unittests/test_static_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_static_save_load.py @@ -1324,7 +1324,7 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase): name="static_save_load_large_x", shape=[None, 10], dtype='float32') - z = paddle.static.nn.fc(x, LARGE_PARAM) + z = paddle.static.nn.fc(x, LARGE_PARAM, bias_attr=False) place = paddle.CPUPlace() exe = paddle.static.Executor(place) exe.run(paddle.static.default_startup_program()) @@ -1334,16 +1334,36 @@ class TestStaticSaveLoadLargeParameters(unittest.TestCase): result_z = exe.run(program=prog, feed={"static_save_load_large_x": inputs}, fetch_list=[z.name]) - path = "test_static_save_load_large_param/static_save" + base_map = {} + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + # make sure all the paramerter or optimizer var have been update + self.assertTrue(np.sum(np.abs(t)) != 0) + base_map[var.name] = t + + path = os.path.join("test_static_save_load_large_param", + "static_save") paddle.fluid.save(prog, path) + # set var to zero + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + ten = fluid.global_scope().find_var(var.name).get_tensor() + ten.set(np.zeros_like(np.array(ten)), place) + + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + self.assertTrue(np.sum(np.abs(new_t)) == 0) paddle.fluid.load(prog, path) - result_load = exe.run(program=prog, - feed={"static_save_load_large_x": inputs}, - fetch_list=[z.name]) - # compare results before and after saving - self.assertTrue( - np.sum(np.abs(result_z[0] - result_load[0])) < 1e-15) + + for var in prog.list_vars(): + if isinstance(var, framework.Parameter) or var.persistable: + new_t = np.array(fluid.global_scope().find_var(var.name) + .get_tensor()) + base_t = base_map[var.name] + self.assertTrue(np.array_equal(new_t, base_t)) class TestProgramStateOldSaveSingleModel(unittest.TestCase): diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py index 66f843dc05..2dfad8dc10 100644 --- a/python/paddle/framework/io.py +++ b/python/paddle/framework/io.py @@ -19,6 +19,7 @@ import collections import pickle import six import warnings +import sys import paddle @@ -262,8 +263,17 @@ def save(obj, path): saved_obj = _build_saved_state_dict(obj) saved_obj = _unpack_saved_dict(saved_obj) - with open(path, 'wb') as f: - pickle.dump(saved_obj, f, protocol=2) + # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3.5/6' + if sys.platform == 'darwin' and sys.version_info.major == 3 and ( + sys.version_info.minor == 5 or sys.version_info.minor == 6): + pickle_bytes = pickle.dumps(saved_obj, protocol=2) + with open(path, 'wb') as f: + max_bytes = 2**30 + for i in range(0, len(pickle_bytes), max_bytes): + f.write(pickle_bytes[i:i + max_bytes]) + else: + with open(path, 'wb') as f: + pickle.dump(saved_obj, f, protocol=2) def load(path, **configs): diff --git a/tools/__pycache__/static_mode_white_list.cpython-35.pyc b/tools/__pycache__/static_mode_white_list.cpython-35.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dae6374903af069c516bb436c2fcea7f761dcd1 GIT binary patch literal 19792 zcmeHPb(}m$m48)O2n2$KV8I~)lJNWeLU1QQAOuOUMVj9E_RKarGd<~^+5K&BcXvC^ zIqu-x{ovdMxyvbb?tYwb-&gOo%N_;K7222p%eUnBd`p zM+hD%c$DBG!J`F_5j<9KvEXrn#|thIJV9`&;E94K37#x?ir_NA<$|XQo+fy@;2DBv z3bq8#5#M%EqIOKwSw0PUN3lq;EjSe z3EnJti{Pz-w+Y@Zc!%Jff_DktEqIUMy@K}%-Y@uo;Ddq>3BE?~VZql5zE1GZwh`(@F~Hk1)mXoR`A<`-w|9T_+7#834UMj2ZBEo{E^^ug3k-SAo!x-OM*WZ z{E6UC1%D>^bHQH-{!;K)g1;90jo@zue<%2R!9NK8QSeWKe-`|U;9mv*Cit@8-v$36 z_)o!q3BE%44+sG<;93FK4!BOhbpx&!a7w`S15ORNLBI_IP7AnEz>Nb=54cIdO#^Ng zaPxp$1e_6Y%YZWjZWVCrfZGJzHsE#vX9e6o;0^(I47gLkodfO?aMysl1>8O0?0|Cu z?h$a$fO`d;8*pC0`2qJ1I2Ld`;6%U$0rv^GFyOuc_Y1gxzyksv81SHg2M0VP;GqEz z3wU_IBLW^7@Th=`0v;Xkn1IIyTpaMYfX4@167Yn8O9P%5@T7nz2RtRZ0-hU?1w1bx57-VE1t6dZ7za!Ob^^+P=LhTtOao>C^MEQ~ z5pYF79nb_U16Bch0s8?b0}cYN40u7n3jVVe- zyf)x<0k02uL%q7g4fwHu9}oD6fS(Nbsen%ed@|st1AZppX9IpN z;O7H=A>bDSektIW1AZmoR|9@6;MW6wBj7g!ek8Q*KMMF{6niFj+o+alf`@s5ahM!YNH-4XAJcyGk}BHkbIfrt-Ad?-#n61Rj7B3^d(hZziP zov|#MWtNY}<-Ev7JNbNGOtWgSb)y6yRr684)bTCTxR@5RV!qrjnt0tS4yT$Q-X38(-wSKdlXQO(oGdD?VrWLZfXo~dVn>(o5T`jU@zCA4z z@w5b8>c-Cc6>%6UMiYr#de8h{Tdu-j*&<4=GP1vHUrheIOb#Zc$&ynTL z=C(Gm3FWEu2u9g%lhxeeMwbpgIP}JfW?J&--DSI>qb+jIE~^a1V!5Mex3Sw@vC8WL z+4OjZvRUSOeNLX9CiBr~Rp+Avg&0Tk(X!mjmt{52=JRY)=V(!pqnIvpMK$u~$aP); zH?-Z5&vbOV;b>XTO7c4!FHRiiKQ6qWUp>oDX8EFY+2VEz;Cf@Gj`wwF1mbYPqfyTG zCYf^O7&d2gxzDsm00B*zu0zIGDCZO1^45Mrl`CGWR-zSi5apozI%erIH#1aIh)G?h z`#Nws9Zf5W{xm-*>a9B&=4?S3E=Icuh_ z$9ws7l~%d#Q)lLhV$)`sNZ?0XjdZDt2?q^{$y+T{x;SxryTD#W644@1%=G?cJ}s|I z^5lfw&PTiZc@4kgcHWeutfs8Z^I75Ax1BFXJK4Od)8}uIHeQ2ed&k@D_#N_~ah7fu z6v2jegzHg|;pt*V&fy%-Yq#AlM=So9XQicB=sj*_pIe9TB|~L*-!7*)k@IQRsi4^@ zPM8f-p*9igK)%22W2EZlNN=)~gsiX{4Xo>Sxyb9KvUHXR!owLWwoNomhqtSf*=Sm( zv->#|!mB5_I?Jv@kkO+~0_m*hj=ssFPGnR9m{p|2j*|YDl!mLWpLQALP2sxEKhsHV z`G`3DTw;VncB*O7>?X_N8}6*;yIDbL(~p>rj}nYS6UcOp?Sqc(tEnNE+Y>9in~ z#*4X9jOv|F8{N_PIjUwPT0wmvEfe!id=y?zDXCmTt2Nv58&z|3E9Xd-f(P85R-@ha zP!;cXa!DS{^9j}5XiN{XcZ^mwX}(!n2Q8_>4Ug5Dg&v>2$?#1b72n)dznrDXkV{4NtT1lw;^K}-#-*2#6>kgco+%Ic(%5jncGfg8p^`=ZZ>&&2a zrSwjw)i#N`m)9*x>D`kwx;b_}lbd=Q)7FuWTr+y^GHz%x3L7U61}aJ99&TdjY+sw; zJ}!3&%%O8)Im&3KCi8Z@nl-C|=^O3oz?A^`GIxsWpVKNi5SpB+T46>huKI(?-J;;l zYHH-xZIdDm=BO0?tep4r(ekI{*k%ZbD0GHg^_CGw6={L^+#PTWz7vdY2rVV>q}>zHhNwZrwG}qfi@1<(w8N z1fXDS7)A448)yl?Pz`TfH8v0?^(vcFHB9_BCR3cDvq$dDFgb;7dUNC1#&U|=s;+r) zCF_wN*x#Y)*I6?P1+TOg44a+}lNpqIgWU#M&Uji;k*{vHBR0QvcvFw#j+kJOck~Ko zK1WZik{2~@4iDa1>BApcR4JN1#n?Q{TWYH^Bmy&TTbCoV^u=sJP1j^ik78o07Wg1z zc2#@)&;yRU9u$_rImJ(sM;hpAFig>QGgS=1D|$?K&e@0>U3hrkA$e3?PXrUlYcpS7w1y_RgdF)z)))rJM;Q(KH-n zBv^$EV|bptBU44g$LIQW)bZrb*gprOl5S&^5L|oE~YD9pM{B3 zD$)}h&K*tWxf+?8F{>o?H3;6<{Gx6{l-V}3yDL&OqO@=|Y3JOKQH9#n)rDl$XhB1r z5Eof|$qg0OB6d3|ux2$NMKIhc9i6MX+b9ChknW%b5thAL95YM4GjX&S;k73HaCceX z!lo&>+{?>q=GxmGymZ6tFY_Xroqduocc@>>x)>OmNxot}pY!xIvQ>&+%^7-`L9@ta zml?=*VF{DzUpLz(Qv(bm{UMB$Zdq{DmoWN})-`iNyJG zhKhBypU{k7og3BE&xJ3QUYdNycsFZ>+|_z8l~Uez69&g!DrshIQV(DmiwI`0T`l~7DPw3?+%@Y)f^ON>bk=iF{PCQwAHp}#0oRMm9yhrosU41xpKykixI7+ z*1l4;#PxUp1~Z7!*a>>sYnzS@u+e8B`dD&p)yty^Xc;7 z@a8Eq*PM%jv3V2o&gFbm+sP~hDOr;|nYq+cV%DHZcT~7sj^<tEoSR6p zmX^Eom2B_$=p)-jcg-5z6JdKAe?~X%CYfHFH@WMoI3>tV7B2LkfsV_ZH%XUP8=5>`pa7&|;W!$85-M zO}D|3U155xyRES_S>>#mRh8B5ZP7HRzw5S5wWDsG+1~+d#?LbPC%QS2!PNHn*o9d? zZ)RmM&*YKRwd+q=$(Uz->xPOD{brqRxXlOY)s z9??8Y8fnt#k!_0byE#@Mm3KKaU<$`<;2%^e^MzLI^}2opg{4yXw3ttfiYCmdeU!8I zn7T!7kb1rLRn?+#hHAN?WuCEEAuiY+Miasc5p#i-YSo}0<rfSjD5Fim?=z| zmJrLi?d5hS97icUwCm0CdcrJ`zMU@ICCj9Z(Sdp=##77%{CI>h%hIjQKl#``-7Txj zlPz8Co3nz(tgKcZ?0EgAE;P6+PL8H4meW&CqWOdGY&9OUwk?SHMphy*%m_Funw_?H zVytIhqhM%WGp0M4oO;6cW0Z4d0UcdtrMX|TLbP3UleBO7LuW>kQ+ORQreV>L{PO z8r#5IKx|N{+SYC-*#V`oZKc^iMy*PMYqVmSFZ1J>*tURSN=}${lUi?* z&LEpv*{Dq%Nq+0SX+hD|rVmZbgmz;xBqgX!9Wz#d_T9^92bbHDv`msUTl08wt>deg zOleKqe)SMbV=1;W$fLGS0WEo(VA%JMLSdzMHTT%!@O;uv-mKA9ulB}OQ;(;>a;mj_ zlgLGRl1yxN7;sGsm-2zKmZZU z?4?A~tBWoZY(gI8^&zTP=d|5H3Wf7qnwM7R4 zPcnOA{i^Sbbm%r?wRUGz-8M4xK2~i@hiDnuH9B;6sRGPzE!s{CuQb~(=P6zD zh^wE|=we#zCApwQcdOgaFSM(lJf}wpgQ*q0-QtMvc@=Z2`ezY;zsNI$E*tl}r$=D1 zgCSjOX*`b@R>*pi@QKW-@at9>cNOm`i<21zCcVM4TUCewEa| zvp$OA{PA|4LQH4r){@V4>akHktGU;q+fdQ{L~Ok?_0?jI7A@>ai3>O>o`$=-avWz1}&A z@?_3T-{Y($;+C8DR5rubqK+&2gDxm@@;HypcMS4sA`c_=MmxDYmEX*=S?0PaJoTHc z0TtN-=c}(9D30G=g<9X;wDG1dT`*6y&XRUI2I{LzO;ehL<2i}*A%TeDJX$0*^i!tF zYgW)>@I_)pi|d5;iYy<~wP*0n*dy2-oYcx?eO1E0F?>|eP?(vpnNHe(eJ?(+vCAG( zFuTYSEElt#nZ(uTtRPdJWE$?0IKM2kg~kHVl1Uq7xn*w~ZQDJUtYuW`B~$y7!&GkA zP0}iujon(X?~%A{H`mT{y15VeoG{Ir(N5M`15xKgGDqogOvjp3#a4UMbBq1 z)XIcTL7TM)KkMdM&$iI`@G z_lt6}qs=ePx7t^BCQtk1y*|KW`Ow!IggYC;K{c(~hZ`nF3fVi+X}(#r{-thzD1&h) zJB7Z}s0ApxOqOyHAd5oJhy|d6>1xO5?~(W5d2X6)Rbm`XNiin9X7#- zR+|YNps`LsA3x{2FHfrnbJd7_?6wq;CC79a$EFdsQn{FR&r{#_E=MzDwSPCmA7Bg#W_ z^r9Bso2@+}htp$OQQIrzUlk|I`3`_dUA4d5@w=u5cWX|KkJ|0aSbOGcWa|@}J3G8H zAGjKuv^Bym%3`&3@AVwAjd|IYXi-xM#luNmasU6>igc(!`%{;iDM+Q%=4$hL>$z)= z?JVe;K*gNpj#OyvRG^=$^+trd*KJ>^kIR}BqGnli-oKq;3&|(iP{g0RYvZr=p)In9 zu!q4ZMBV&o>z;WWOE9}*(*EyuZ?ZLEGs~q%NgZkJ=Cl*q0KJe z()Q~Ht?cBaHhvn>Z(un-2p4MejvMNU$#L}&XaG?LdI7_evGnHD=E=QUlUSo;H?cKh@k$Ap7b zbVl9$rxwmVMC$ajE4f<`pao-hwd%wZEBHvmT;|{{d*78YALO5zZJ zUw6cI`o?%;Nw=NOZG6h8sORNezg;ig1ozD75&CtkO!Z*9@ZkMZKV^*B5c$DU6XMPJ zjjJ5qM2q_UM-FeI^BbFVn$WN7M1BL+%-Kf#(l4{)Y{osC^DbtZgc;kYW=qQmkj%IEsH8_%a=(pRPNxGIyo3#}CHB_&G!rC#s z+PU$ywYiNW+O2$De+)Z0?TEc7?k9|>b^qA7pE5SL-mQn5!ng2^@4I(pGS@Hg{E)B> zWD}TPV>a%&IlZyEO@7*+6!ago$kxZnybON2)pOw7&j{?c=|#q(qS12U(y-ZX zF(9uAqB77spI(C^V6!x7;hBkp8DCw-*NrBxiCh1}WN-3+Pu+Cs|1M9|6>j61WqSd( zONQ(cLEuzC=HuC4Z8^(9PO?cTz%Oge{Bu7=SkJ$BJY;atV@rxN;LFMqb` z=(c}nIuo852#2LQUIN?IT;~(Sa=2Ol+#^AaL7v$U(T!$4VGZ5EHtw#Dw-+&SwxWGP zGyl-`fQBIdNp=F6|2pEeY`Vy$=Y*@yyzKIaUw-i;vnM?9Q5R*GUv$~!7hn3A?2?Nw zyZob4Uz`6oJn0`_e#H4TpMjmPEz{=wCy~I@;@ngU+Ab$PK7YxgVspONkF*4pv3jIs vME1STTO3?<%7X~xyRC