提交 ed2129cc 编写于 作者: T tangwei12

revert distribute_transpiler.py

上级 01975ec1
......@@ -14,7 +14,6 @@
from __future__ import print_function
import os
import math
import distributed_splitter as splitter
......@@ -27,10 +26,6 @@ LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
# for checkpoint
SUCCESS = "_SUCCESS"
SERIAL_VAR_NAME = "SERIAL_NUMBER"
class VarBlock:
def __init__(self, varname, offset, size):
......@@ -161,8 +156,7 @@ class DistributeTranspiler:
pservers="127.0.0.1:6174",
trainers=1,
split_method=splitter.round_robin,
sync_mode=True,
checkpoint_dir=None):
sync_mode=True):
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
......@@ -216,12 +210,6 @@ class DistributeTranspiler:
self.pserver_endpoints = pserver_endpoints
self.optimize_ops, params_grads = self._get_optimize_pass()
# is_chief (no.0 triner) for checkpoint
# the no.0 trainer will save all variables and its own reader offset to checkpoint
# other trianers will save its own reader offset to checkpoint
self._is_chief = trainer_id == 0
self.checkpoint_dir = checkpoint_dir
# process lookup_table_op
# 1. check all lookup_table_op is distributed
# 2. check all lookup_table_op share the same table.
......@@ -327,24 +315,6 @@ class DistributeTranspiler:
"epmap": eplist,
"sync_mode": self.sync_mode
})
if self.checkpoint_dir and self._is_chief:
program.global_block().create_var(
name=SERIAL_VAR_NAME,
persistable=True,
type=core.VarDesc.VarType.RAW)
save_vars = []
for var in self.origin_program.list_vars():
if self._is_persistable(var):
save_vars.append(var.name)
program.global_block().append_op(
type="checkpoint_save",
inputs={"X": save_vars},
attrs={"overwrite": True,
"dir": self.checkpoint_dir})
# step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1:
......@@ -525,37 +495,6 @@ class DistributeTranspiler:
pserver_program.sync_with_cpp()
return pserver_program
def get_train_startup_program(self):
"""
Get train startup program.
If self.checkpoint_dir is None, rerurn default startup program.
IF self.checkpoint_dir is Exist, add checkpoint_load op and load Var.
"""
startup_prog = default_startup_program()
if not self.checkpoint_dir:
return startup_prog
load_vars = []
for var in startup_prog.list_vars():
if self._is_persistable(var):
load_vars.append(var.name)
serial_number = self._get_lastest_checkpoint_dir(self.checkpoint_dir)
startup_prog.global_block().create_var(
name=SERIAL_VAR_NAME,
persistable=True,
type=core.VarDesc.VarType.RAW)
startup_prog.global_block().append_op(
type="checkpoint_load",
inputs={"X": load_vars},
outputs={"Argv": []},
attrs={"dir": self.checkpoint_dir,
"Serial": serial_number})
return startup_prog
def get_startup_program(self, endpoint, pserver_program):
"""
Get startup program for current parameter server.
......@@ -581,7 +520,6 @@ class DistributeTranspiler:
created_var_map[var.name] = tmpvar
# 2. rename op outputs
load_vars = []
for op in orig_s_prog.global_block().ops:
new_inputs = dict()
new_outputs = dict()
......@@ -609,70 +547,8 @@ class DistributeTranspiler:
inputs=new_inputs,
outputs=new_outputs,
attrs=op.attrs)
for var in new_outputs.values():
load_vars.append(var.name)
# add checkpoint op
if not self.checkpoint_dir:
return s_prog
serial_number = self._get_lastest_checkpoint_dir(self.checkpoint_dir)
s_prog.global_block().create_var(
name=SERIAL_VAR_NAME,
persistable=True,
type=core.VarDesc.VarType.RAW)
s_prog.global_block().append_op(
type="checkpoint_load",
inputs={"X": load_vars},
outputs={"Argv": []},
attrs={"dir": self.checkpoint_dir,
"Serial": serial_number})
return s_prog
def _is_persistable(self, var):
"""only save LodTensor variable"""
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.RAW :
return False
return var.persistable
def _get_lastest_checkpoint_dir(self, checkpoint_dir):
"""
get the biggest number in checkpoint_dir, which has _SUCCESS
"""
if not checkpoint_dir.strip():
return ""
def has_success(checkpoint_dir, cur_dir):
"""
is _SUCCESS in this dir
"""
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1
try:
int(cur_dir)
except ValueError:
return -1
success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
if os.path.isfile(success_path):
return int(cur_dir)
if not os.path.isdir(checkpoint_dir):
return "-1"
current_dir = 0
dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs:
success_num = has_success(checkpoint_dir, cur_dir)
if success_num > current_dir:
current_dir = success_num
return str(current_dir)
# transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
eplist):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册