diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 6366ba8a58558c10706d7b257ab0d55ae3cf13da..104e2405322e96fa50091139d34826cee1cfae7f 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -315,10 +315,21 @@ class DistributeTranspiler: "sync_mode": self.sync_mode }) + serial_var = program.global_block().create_var( + name="SERIAL_NUMBER", + persistable=True, + type=core.VarDesc.VarType.RAW) + + save_vars = [] + for var in self.origin_program.list_vars(): + if self.is_persistable(var): + save_vars.append(var.name) + program.global_block().append_op( type="checkpoint_save", - inputs={"X": send_outputs}, - attrs={"overwrite": True, + inputs={"X": save_vars}, + outputs={"Serial": serial_var}, + attrs={"overwrite": False, "dir": "/workspace/ckpt/"}) # step4: Concat the parameters splits together after recv. @@ -501,6 +512,27 @@ class DistributeTranspiler: pserver_program.sync_with_cpp() return pserver_program + def is_persistable(self, var): + if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ + var.desc.type() == core.VarDesc.VarType.RAW : + return False + return var.persistable + + def get_train_startup_program(self, checkpoint_load_dir=None): + startup_prog = default_startup_program() + + if not checkpoint_load_dir: + return startup_prog + + for var in startup_prog.list_vars(): + if self.is_persistable(var): + print("var: %s" % var.name) + + startup_prog.global_block().append_op( + type="checkpoint_load", attrs={"dir": checkpoint_load_dir}) + return startup_prog + def get_startup_program(self, endpoint, pserver_program): """ Get startup program for current parameter server.