From eea5762e26a9a6ae2d9642830031028e5952af45 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Tue, 5 Jun 2018 17:04:17 +0800 Subject: [PATCH] add checkpoint unittest --- .../fluid/tests/unittests/test_checkpoint.py | 72 ++++++++++++++++++ tools/codestyle/docstring_checker.pyc | Bin 0 -> 12561 bytes 2 files changed, 72 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_checkpoint.py create mode 100644 tools/codestyle/docstring_checker.pyc diff --git a/python/paddle/fluid/tests/unittests/test_checkpoint.py b/python/paddle/fluid/tests/unittests/test_checkpoint.py new file mode 100644 index 00000000000..b8d82c59b4e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_checkpoint.py @@ -0,0 +1,72 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle.fluid as fluid +import unittest + + +class TestCheckpoint(unittest.TestCase): + def setUp(self): + self.dirname = "/tmp/ckpt" + self.max_num_checkpoints = 3 + self.epoch_interval = 1 + self.step_interval = 1 + self.trainer_id = 0 + self.chief = self.trainer_id == 0 + self.place = fluid.CPUPlace() + self.epoch_id = 100 + self.step_id = 20 + + def test_checkpoint(self): + self.save_checkpoint() + serial = fluid.io.get_latest_checkpoint_serial(self.dirname) + self.assertTrue(serial >= 0) + trainer_args = ["epoch_id", "step_id"] + epoch_id, step_id = fluid.io.load_trainer_args( + self.dirname, serial, self.trainer_id, trainer_args) + self.assertEqual(self.step_id, step_id) + self.assertEqual(self.epoch_id, epoch_id) + + program = fluid.Program() + with fluid.program_guard(program): + exe = fluid.Executor(self.place) + fluid.io.load_checkpoint(exe, self.dirname, serial, program) + + fluid.io.clean_checkpoint(self.dirname, delete_dir=True) + + def save_checkpoint(self): + config = fluid.CheckpointConfig(self.dirname, self.max_num_checkpoints, + self.epoch_interval, self.step_interval) + + trainer_args = {} + trainer_args["epoch_id"] = self.epoch_id + trainer_args["step_id"] = self.step_id + + program = fluid.Program() + with fluid.program_guard(program): + program.global_block().create_var( + name="scale_0", + psersistable=True, + dtype="float32", + shape=[32, 32]) + + exe = fluid.Executor(self.place) + for i in xrange(10): + fluid.io.save_checkpoint( + exe, config.checkpoint_dir, self.trainer_id, self.chief, + trainer_args, program, config.max_num_checkpoints) + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/codestyle/docstring_checker.pyc b/tools/codestyle/docstring_checker.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a27d3c9a8cccab8552d510578debb2df04eb53bb GIT binary patch literal 12561 zcmdT~%WoUU8J{I7N~9&rPx%oi-uOw^rX)Y&B&yDAw|-TX1?xj-N!7Warm=5bYj{N50^(C+V(2Hze+imT+b{yJug6!5>yc9NV zH>}%c|-bj7cn3o>=@Nnkurg-h%;-x9IB_yO%9{kseL9wF{k5 z(pgP9V*)px+w3Nfp1^NnrV?sO3cZ5b*n1kYj>Wa+;zyo8cGvtrf8pwU&3D_Q^{_Fv z=*MG=;h2uz%QP)JX z%ESef>E`Q%@rRw)@aX#}JkSJ40i_TbNhHf_XClMO0$BUR*40xSEj*&ytKSSK*#X5HqD5N+-*~tOdAmPz0z)XssT2wa8o|d$kuV zY>O+bR{f?QSF6XlDeg>J!&Zs|p5{7xrcf+)2pzON*==Nc8`<7QPEGpqPE~HZQXjap;KBS9)Z5cSyI%a3`ND`oLViZM`P=n*s;#+9nybV24qje+-}7YM7idHdS`*xPDk z%5E0DxGH-a#bGq>SXp^%(AsMit(}=V(MUQWoeEiA!10aneS(p|;WDJJQjZ~5hKS_V zJ_{I~p8K3w27Z@;adh{oIIAA@5&Muf4=okFt5%MyrkPAe@?!EFvkWPe^VE(QB;IM&)FRgwsrajM^2|QyuX|cAw1Z%+JAu%Bh_xau#(?mQCP&H^ot;DkHCbmRhq_mb zERy3*IJ*g;QtO;^!Xf$$_%o>A869#(2w<90a%*v?Tr;#NWuW(v7N|6rkTvXJGk!z-=hkssPg*YN0zxHRA? z^}4()-EKNt#?6Ai4<9Q_haScby(>8T1?8ieS00-E>Hu8#1<+p#2{K?(QJtM$9%Kes9_A2+RRQ4(N zoK*HJw=9(r`jcZs-VGGfp*Ge_5*(lj(^{h|-f`&>kPwSUoOl{1W_%KgkSKlA)@2n>*M3+H92UyyOJAv8rEKH!=kZi$k@k5kJ&9R@Fw6Yfd9!*1yhYKp(I zsh4G>rgey0-(+yQ(q(-DgYJiIZnhn`Ho{2u2Z>N$5x(l1! z>%5GcsP(*^xFWr?qzk_)czkMMF*82sI5E9%iWh0_GLlh(qs^B8;yKdQUlSKJi~aMt zjz|9hg|RuQ%MtlAUMS0qItcmt5U%5Cn&uV4a!Oy42IhPso+>pwd^_or_u;F-!Gi$+ zXElTQ#CKcYhA)}+OT3iZ(VwFD(1%E=F0`BC%^`AW~Tm51wQ0p`mCnuP&qm_3#v-d<`Ufnx6*&oh<9({{~(t+4=N+Wb%5^3MR&`%tKXq0ESt%M|3<1&3*uB}7XLkB<;7=U7LyZa0ic z`_6d{g>h$zC+7;Ac(yd+%$=Bn8#C@#Xm|Xj&gs6{C#2djCzPl)xcApzc1B*(x?5xe zav_mJXcz~BLJ19oh)hqZR?D22G7NBp6MPv;_j^2v3LO^ZIhYx;jzE1Lu?9@)fy#xa&T`}$3PZRc;h(_U zDhl}?67q$$2T&p>5Wyca(46`)Y!yh4jQSDY*rBE|groA3^?A>bK14Ng4sg8(xB&A8 zkSI(tWGspkAyn!@{YMlX0o@@|6BtiyyAtw2_N0Aj{0#k7IBFOzK3dmWdk%X~48e}k z*PY?HsBcqqkJn$P*+B#oxfk(h#&)2#=JqjQ25cH34QzeSS`n$-$-DAR5ta*}CXt3f zky4}M@~;_~Frp0*7gAeewuyT?x#%^$)mC)nuiS?rC(Z2_r5qAaamfE1$pnWUa4$Cs z)jd7f=ntENyIi^#O-Tmmc4H%Va#@Mw$>f5PE{3)SInH(L=}&kP{wpBHOqh;JOaB2y*AVJqyAE$jc%&1@uuqSkw!6IYjNFD$g>loswf_%>6OLgq>M*R) zOJ($?jIdr$O^yiu7tkx-7tUw$RBv;9SM+a2To|!MOy^8v>A!J#ru9@m<@yVF4nn4b zrpQq=qbSo!f%^L>Vlg{b$AthaSQ{s4df@(HdVsh{RvBiiR(3I83^NQy${jSAnb+3| zs343^ryll2N`Q#&e zZo;9w`8H!F_`DEFGO-3=q;s?UVh1+C4pmvb$+f2E-jpwe&tE+{{)gYL@Ilj2=U9FY{SOn#q;{1Uv6bQ6Z51?er5457O9WZD^0oWdh_kY ziR9ZJ99vJ=&PBkxj3uTb!pBG=m6+5y1B)q!i)@$Z8%feyqEyV)*O&ajjl5J6$P@8n!7>8nbH&Hg6BL@@;HxoP0{@ zq6YHT2tKGAMy7lpZES>!Pgn)9i1^(NecpkLQHA`r*%@WQl#Mgaf)<|n8ma2@%S}o0 zAfrw%|IY2ot-H5w&rH2jefLJCGIjGt_13h67r%Ss{=M7IbPw3M&xvSy>9xju7>HTf zLMAZs<2AXDhQ>Fdlp0}6zD#r8MBTZ{Vu}UBpUyQFbl=5?5QER58g(d~9G<+Z)upyY z9uDoy5Rd*2ihMS26`<0o+WDKs=Z!^t(NsjOU&`I6@iDD_YlxZt1Cry<7Yc(Tg^|ME z!r{VU`IJvC#<+X_s!=Y(SjFV0M>{EDsvqTm43wl0l3w|kd;uvzZElvIP}RdANWKP? y@0p!5e3LgJ_hn2!lV0JmwgrNPd?rm|;L1+yD)~TS&>BQOgn5v`Y@u)D&VK>kVxyJ- literal 0 HcmV?d00001 -- GitLab