提交 ed2129cc 编写于 作者: T tangwei12

revert distribute_transpiler.py

上级 01975ec1
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
from __future__ import print_function from __future__ import print_function
import os
import math import math
import distributed_splitter as splitter import distributed_splitter as splitter
...@@ -27,10 +26,6 @@ LOOKUP_TABLE_TYPE = "lookup_table" ...@@ -27,10 +26,6 @@ LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR" RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
# for checkpoint
SUCCESS = "_SUCCESS"
SERIAL_VAR_NAME = "SERIAL_NUMBER"
class VarBlock: class VarBlock:
def __init__(self, varname, offset, size): def __init__(self, varname, offset, size):
...@@ -161,8 +156,7 @@ class DistributeTranspiler: ...@@ -161,8 +156,7 @@ class DistributeTranspiler:
pservers="127.0.0.1:6174", pservers="127.0.0.1:6174",
trainers=1, trainers=1,
split_method=splitter.round_robin, split_method=splitter.round_robin,
sync_mode=True, sync_mode=True):
checkpoint_dir=None):
""" """
Transpile the program to distributed data-parallelism programs. Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server The main_program will be transformed to use a remote parameter server
...@@ -216,12 +210,6 @@ class DistributeTranspiler: ...@@ -216,12 +210,6 @@ class DistributeTranspiler:
self.pserver_endpoints = pserver_endpoints self.pserver_endpoints = pserver_endpoints
self.optimize_ops, params_grads = self._get_optimize_pass() self.optimize_ops, params_grads = self._get_optimize_pass()
# is_chief (no.0 triner) for checkpoint
# the no.0 trainer will save all variables and its own reader offset to checkpoint
# other trianers will save its own reader offset to checkpoint
self._is_chief = trainer_id == 0
self.checkpoint_dir = checkpoint_dir
# process lookup_table_op # process lookup_table_op
# 1. check all lookup_table_op is distributed # 1. check all lookup_table_op is distributed
# 2. check all lookup_table_op share the same table. # 2. check all lookup_table_op share the same table.
...@@ -327,24 +315,6 @@ class DistributeTranspiler: ...@@ -327,24 +315,6 @@ class DistributeTranspiler:
"epmap": eplist, "epmap": eplist,
"sync_mode": self.sync_mode "sync_mode": self.sync_mode
}) })
if self.checkpoint_dir and self._is_chief:
program.global_block().create_var(
name=SERIAL_VAR_NAME,
persistable=True,
type=core.VarDesc.VarType.RAW)
save_vars = []
for var in self.origin_program.list_vars():
if self._is_persistable(var):
save_vars.append(var.name)
program.global_block().append_op(
type="checkpoint_save",
inputs={"X": save_vars},
attrs={"overwrite": True,
"dir": self.checkpoint_dir})
# step4: Concat the parameters splits together after recv. # step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in param_var_mapping.iteritems():
if len(splited_var) <= 1: if len(splited_var) <= 1:
...@@ -525,37 +495,6 @@ class DistributeTranspiler: ...@@ -525,37 +495,6 @@ class DistributeTranspiler:
pserver_program.sync_with_cpp() pserver_program.sync_with_cpp()
return pserver_program return pserver_program
def get_train_startup_program(self):
"""
Get train startup program.
If self.checkpoint_dir is None, rerurn default startup program.
IF self.checkpoint_dir is Exist, add checkpoint_load op and load Var.
"""
startup_prog = default_startup_program()
if not self.checkpoint_dir:
return startup_prog
load_vars = []
for var in startup_prog.list_vars():
if self._is_persistable(var):
load_vars.append(var.name)
serial_number = self._get_lastest_checkpoint_dir(self.checkpoint_dir)
startup_prog.global_block().create_var(
name=SERIAL_VAR_NAME,
persistable=True,
type=core.VarDesc.VarType.RAW)
startup_prog.global_block().append_op(
type="checkpoint_load",
inputs={"X": load_vars},
outputs={"Argv": []},
attrs={"dir": self.checkpoint_dir,
"Serial": serial_number})
return startup_prog
def get_startup_program(self, endpoint, pserver_program): def get_startup_program(self, endpoint, pserver_program):
""" """
Get startup program for current parameter server. Get startup program for current parameter server.
...@@ -581,7 +520,6 @@ class DistributeTranspiler: ...@@ -581,7 +520,6 @@ class DistributeTranspiler:
created_var_map[var.name] = tmpvar created_var_map[var.name] = tmpvar
# 2. rename op outputs # 2. rename op outputs
load_vars = []
for op in orig_s_prog.global_block().ops: for op in orig_s_prog.global_block().ops:
new_inputs = dict() new_inputs = dict()
new_outputs = dict() new_outputs = dict()
...@@ -609,70 +547,8 @@ class DistributeTranspiler: ...@@ -609,70 +547,8 @@ class DistributeTranspiler:
inputs=new_inputs, inputs=new_inputs,
outputs=new_outputs, outputs=new_outputs,
attrs=op.attrs) attrs=op.attrs)
for var in new_outputs.values():
load_vars.append(var.name)
# add checkpoint op
if not self.checkpoint_dir:
return s_prog
serial_number = self._get_lastest_checkpoint_dir(self.checkpoint_dir)
s_prog.global_block().create_var(
name=SERIAL_VAR_NAME,
persistable=True,
type=core.VarDesc.VarType.RAW)
s_prog.global_block().append_op(
type="checkpoint_load",
inputs={"X": load_vars},
outputs={"Argv": []},
attrs={"dir": self.checkpoint_dir,
"Serial": serial_number})
return s_prog return s_prog
def _is_persistable(self, var):
"""only save LodTensor variable"""
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.RAW :
return False
return var.persistable
def _get_lastest_checkpoint_dir(self, checkpoint_dir):
"""
get the biggest number in checkpoint_dir, which has _SUCCESS
"""
if not checkpoint_dir.strip():
return ""
def has_success(checkpoint_dir, cur_dir):
"""
is _SUCCESS in this dir
"""
if not os.path.isdir(os.path.join(checkpoint_dir, cur_dir)):
return -1
try:
int(cur_dir)
except ValueError:
return -1
success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
if os.path.isfile(success_path):
return int(cur_dir)
if not os.path.isdir(checkpoint_dir):
return "-1"
current_dir = 0
dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs:
success_num = has_success(checkpoint_dir, cur_dir)
if success_num > current_dir:
current_dir = success_num
return str(current_dir)
# transpiler function for dis lookup_table # transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var, def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
eplist): eplist):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册