From ed2129cc50b794f76574065430577e0303a6703d Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Mon, 21 May 2018 16:57:40 +0800 Subject: [PATCH] revert distribute_transpiler.py --- .../fluid/transpiler/distribute_transpiler.py | 126 +----------------- 1 file changed, 1 insertion(+), 125 deletions(-) diff --git a/python/paddle/fluid/transpiler/distribute_transpiler.py b/python/paddle/fluid/transpiler/distribute_transpiler.py index 1d51ed457..42ff0a9eb 100644 --- a/python/paddle/fluid/transpiler/distribute_transpiler.py +++ b/python/paddle/fluid/transpiler/distribute_transpiler.py @@ -14,7 +14,6 @@ from __future__ import print_function -import os import math import distributed_splitter as splitter @@ -27,10 +26,6 @@ LOOKUP_TABLE_TYPE = "lookup_table" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR" -# for checkpoint -SUCCESS = "_SUCCESS" -SERIAL_VAR_NAME = "SERIAL_NUMBER" - class VarBlock: def __init__(self, varname, offset, size): @@ -161,8 +156,7 @@ class DistributeTranspiler: pservers="127.0.0.1:6174", trainers=1, split_method=splitter.round_robin, - sync_mode=True, - checkpoint_dir=None): + sync_mode=True): """ Transpile the program to distributed data-parallelism programs. The main_program will be transformed to use a remote parameter server @@ -216,12 +210,6 @@ class DistributeTranspiler: self.pserver_endpoints = pserver_endpoints self.optimize_ops, params_grads = self._get_optimize_pass() - # is_chief (no.0 triner) for checkpoint - # the no.0 trainer will save all variables and its own reader offset to checkpoint - # other trianers will save its own reader offset to checkpoint - self._is_chief = trainer_id == 0 - self.checkpoint_dir = checkpoint_dir - # process lookup_table_op # 1. check all lookup_table_op is distributed # 2. check all lookup_table_op share the same table. @@ -327,24 +315,6 @@ class DistributeTranspiler: "epmap": eplist, "sync_mode": self.sync_mode }) - - if self.checkpoint_dir and self._is_chief: - program.global_block().create_var( - name=SERIAL_VAR_NAME, - persistable=True, - type=core.VarDesc.VarType.RAW) - - save_vars = [] - for var in self.origin_program.list_vars(): - if self._is_persistable(var): - save_vars.append(var.name) - - program.global_block().append_op( - type="checkpoint_save", - inputs={"X": save_vars}, - attrs={"overwrite": True, - "dir": self.checkpoint_dir}) - # step4: Concat the parameters splits together after recv. for varname, splited_var in param_var_mapping.iteritems(): if len(splited_var) <= 1: @@ -525,37 +495,6 @@ class DistributeTranspiler: pserver_program.sync_with_cpp() return pserver_program - def get_train_startup_program(self): - """ - Get train startup program. - If self.checkpoint_dir is None, rerurn default startup program. - IF self.checkpoint_dir is Exist, add checkpoint_load op and load Var. - """ - startup_prog = default_startup_program() - - if not self.checkpoint_dir: - return startup_prog - - load_vars = [] - for var in startup_prog.list_vars(): - if self._is_persistable(var): - load_vars.append(var.name) - - serial_number = self._get_lastest_checkpoint_dir(self.checkpoint_dir) - - startup_prog.global_block().create_var( - name=SERIAL_VAR_NAME, - persistable=True, - type=core.VarDesc.VarType.RAW) - - startup_prog.global_block().append_op( - type="checkpoint_load", - inputs={"X": load_vars}, - outputs={"Argv": []}, - attrs={"dir": self.checkpoint_dir, - "Serial": serial_number}) - return startup_prog - def get_startup_program(self, endpoint, pserver_program): """ Get startup program for current parameter server. @@ -581,7 +520,6 @@ class DistributeTranspiler: created_var_map[var.name] = tmpvar # 2. rename op outputs - load_vars = [] for op in orig_s_prog.global_block().ops: new_inputs = dict() new_outputs = dict() @@ -609,70 +547,8 @@ class DistributeTranspiler: inputs=new_inputs, outputs=new_outputs, attrs=op.attrs) - for var in new_outputs.values(): - load_vars.append(var.name) - # add checkpoint op - if not self.checkpoint_dir: - return s_prog - - serial_number = self._get_lastest_checkpoint_dir(self.checkpoint_dir) - - s_prog.global_block().create_var( - name=SERIAL_VAR_NAME, - persistable=True, - type=core.VarDesc.VarType.RAW) - - s_prog.global_block().append_op( - type="checkpoint_load", - inputs={"X": load_vars}, - outputs={"Argv": []}, - attrs={"dir": self.checkpoint_dir, - "Serial": serial_number}) - return s_prog - def _is_persistable(self, var): - """only save LodTensor variable""" - if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ - var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ - var.desc.type() == core.VarDesc.VarType.RAW : - return False - return var.persistable - - def _get_lastest_checkpoint_dir(self, checkpoint_dir): - """ - get the biggest number in checkpoint_dir, which has _SUCCESS - """ - if not checkpoint_dir.strip(): - return "" - - def has_success(checkpoint_dir, cur_dir): - """ - is _SUCCESS in this dir - """ - if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)): - return -1 - - try: - int(cur_dir) - except ValueError: - return -1 - - success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS) - if os.path.isfile(success_path): - return int(cur_dir) - - if not os.path.isdir(checkpoint_dir): - return "-1" - - current_dir = 0 - dirs = os.listdir(checkpoint_dir) - for cur_dir in dirs: - success_num = has_success(checkpoint_dir, cur_dir) - if success_num > current_dir: - current_dir = success_num - return str(current_dir) - # transpiler function for dis lookup_table def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var, eplist): -- GitLab