From 821acdb3bffdf0594d4bf94a4cddc47c2c681ca6 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 18 May 2018 11:18:22 +0800 Subject: [PATCH] update op to trianer and pserver --- .../fluid/transpiler/distribute_transpiler.py | 99 ++++++++++++++----- 1 file changed, 72 insertions(+), 27 deletions(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 84cfc6e0117..4e157187711 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -14,6 +14,7 @@ from __future__ import print_function +import os import math import distributed_splitter as splitter @@ -26,6 +27,10 @@ LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR" +# for checkpoint +SUCCESS = "_SUCCESS" +SERIAL_VAR_NAME = "SERIAL_NUMBER" + class VarBlock: def __init__(self, varname, offset, size): @@ -153,7 +158,8 @@ class DistributeTranspiler: pservers="127.0.0.1:6174", trainers=1, split_method=splitter.round_robin, - sync_mode=True): + sync_mode=True, + checkpoint_dir=None): """ Transpile the program to distributed data-parallelism programs. The main_program will be transformed to use a remote parameter server @@ -315,22 +321,22 @@ class DistributeTranspiler: "sync_mode": self.sync_mode }) - serial_var = program.global_block().create_var( - name="SERIAL_NUMBER", - persistable=True, - type=core.VarDesc.VarType.RAW) + if checkpoint_dir and self.is_chief: + program.global_block().create_var( + name=SERIAL_VAR_NAME, + persistable=True, + type=core.VarDesc.VarType.RAW) - save_vars = [] - for var in self.origin_program.list_vars(): - if self.is_persistable(var): - save_vars.append(var.name) + save_vars = [] + for var in self.origin_program.list_vars(): + if self._is_persistable(var): + save_vars.append(var.name) - program.global_block().append_op( - type="checkpoint_save", - inputs={"X": save_vars}, - outputs={"Serial": serial_var}, - attrs={"overwrite": False, - "dir": "/workspace/ckpt/"}) + program.global_block().append_op( + type="checkpoint_save", + inputs={"X": save_vars}, + attrs={"overwrite": True, + "dir": checkpoint_dir}) # step4: Concat the parameters splits together after recv. for varname, splited_var in param_var_mapping.iteritems(): @@ -512,13 +518,6 @@ class DistributeTranspiler: pserver_program.sync_with_cpp() return pserver_program - def is_persistable(self, var): - if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ - var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ - var.desc.type() == core.VarDesc.VarType.RAW : - return False - return var.persistable - def get_train_startup_program(self, checkpoint_load_dir=None): """ Get train startup program. @@ -532,13 +531,16 @@ class DistributeTranspiler: load_vars = [] for var in startup_prog.list_vars(): - if self.is_persistable(var): + if self._is_persistable(var): load_vars.append(var.name) + serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir) + startup_prog.global_block().append_op( type="checkpoint_load", - outputs={"Out": load_vars}, - attrs={"dir": checkpoint_load_dir}) + inputs={"X": load_vars}, + attrs={"dir": checkpoint_load_dir, + "Serial": serial_number}) return startup_prog def get_startup_program(self, @@ -599,16 +601,59 @@ class DistributeTranspiler: attrs=op.attrs) for var in new_outputs.values(): load_vars.append(var.name) - # add checkpoint op + # add checkpoint op if not checkpoint_load_dir: return s_prog + serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir) + s_prog.global_block().append_op( type="checkpoint_load", inputs={"X": load_vars}, - attrs={"dir": checkpoint_load_dir}) + attrs={"dir": checkpoint_load_dir, + "Serial": serial_number}) + return s_prog + def _is_persistable(self, var): + """only save LodTensor variable""" + if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ + var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ + var.desc.type() == core.VarDesc.VarType.RAW : + return False + return var.persistable + + def _get_lastest_checkpoint_dir(self, checkpoint_dir): + """ + get the biggest number in checkpoint_dir, which has _SUCCESS + """ + if not checkpoint_dir.strip(): + return "" + + def has_success(checkpoint_dir, cur_dir): + """ + is _SUCCESS in this dir + """ + if not os.path.isdir(cur_dir): + return -1 + + try: + int(cur_dir) + except ValueError: + return -1 + + success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS) + if os.path.isfile(success_path): + return int(cur_dir) + + current_dir = 0 + dirs = os.listdir(checkpoint_dir) + for cur_dir in dirs: + success_num = has_success(checkpoint_dir, cur_dir) + if success_num > current_dir: + current_dir = success_num + return str(current_dir) + # transpiler function for dis lookup_table def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var, eplist): -- GitLab