Commit 821acdb3 authored by tangwei12

update op for trainer and pserver

Parent f688652f
@@ -14,6 +14,7 @@
from __future__ import print_function
import os
import math
import distributed_splitter as splitter
@@ -26,6 +27,10 @@ LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
# for checkpoint
SUCCESS = "_SUCCESS"
SERIAL_VAR_NAME = "SERIAL_NUMBER"
class VarBlock:
def __init__(self, varname, offset, size):
@@ -153,7 +158,8 @@ class DistributeTranspiler:
pservers="127.0.0.1:6174",
trainers=1,
split_method=splitter.round_robin,
sync_mode=True):
sync_mode=True,
checkpoint_dir=None):
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
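# Illustrative sketch (not part of this commit): how a trainer script might
# pass the new checkpoint_dir argument. Endpoints, trainer count, and the
# surrounding fluid calls are assumptions based on the fluid API of this era.
import paddle.fluid as fluid

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,
    program=fluid.default_main_program(),
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2,
    sync_mode=True,
    checkpoint_dir="/workspace/ckpt")  # chief trainer appends checkpoint_save
trainer_prog = t.get_trainer_program()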
@@ -315,22 +321,22 @@
"sync_mode": self.sync_mode
})
serial_var = program.global_block().create_var(
name="SERIAL_NUMBER",
if checkpoint_dir and self.is_chief:
program.global_block().create_var(
name=SERIAL_VAR_NAME,
persistable=True,
type=core.VarDesc.VarType.RAW)
save_vars = []
for var in self.origin_program.list_vars():
if self.is_persistable(var):
if self._is_persistable(var):
save_vars.append(var.name)
program.global_block().append_op(
type="checkpoint_save",
inputs={"X": save_vars},
outputs={"Serial": serial_var},
attrs={"overwrite": False,
"dir": "/workspace/ckpt/"})
attrs={"overwrite": True,
"dir": checkpoint_dir})
# step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems():
@@ -512,13 +518,6 @@
pserver_program.sync_with_cpp()
return pserver_program
def is_persistable(self, var):
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.RAW :
return False
return var.persistable
def get_train_startup_program(self, checkpoint_load_dir=None):
"""
Get train startup program.
@@ -532,13 +531,16 @@
load_vars = []
for var in startup_prog.list_vars():
if self.is_persistable(var):
if self._is_persistable(var):
load_vars.append(var.name)
serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir)
startup_prog.global_block().append_op(
type="checkpoint_load",
outputs={"Out": load_vars},
attrs={"dir": checkpoint_load_dir})
inputs={"X": load_vars},
attrs={"dir": checkpoint_load_dir,
"Serial": serial_number})
return startup_prog
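# Illustrative sketch (not part of this commit), continuing the usage sketch
# above: resuming a trainer from the latest checkpoint. Executor and place
# usage are assumptions from the fluid API of this era.
place = fluid.CPUPlace()
exe = fluid.Executor(place)
startup = t.get_train_startup_program(checkpoint_load_dir="/workspace/ckpt")
exe.run(startup)  # prepended checkpoint_load restores persistable variables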
def get_startup_program(self,
@@ -603,12 +605,55 @@
if not checkpoint_load_dir:
return s_prog
serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir)
s_prog.global_block().append_op(
type="checkpoint_load",
inputs={"X": load_vars},
attrs={"dir": checkpoint_load_dir})
attrs={"dir": checkpoint_load_dir,
"Serial": serial_number})
return s_prog
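# Illustrative sketch (not part of this commit): the pserver side. Method
# names follow the fluid API of this era; the endpoint is illustrative. The
# startup program now prepends checkpoint_load when checkpoint_load_dir is set.
endpoint = "127.0.0.1:6174"
pserver_prog = t.get_pserver_program(endpoint)
pserver_startup = t.get_startup_program(
    endpoint, pserver_prog, checkpoint_load_dir="/workspace/ckpt")
exe.run(pserver_startup)  # loads the latest completed checkpoint, if any
exe.run(pserver_prog)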
def _is_persistable(self, var):
"""only save LodTensor variable"""
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.RAW :
return False
return var.persistable
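# Illustration (not part of this commit; variable names are hypothetical) of
# what the filter keeps and skips:
import paddle.fluid as fluid

block = fluid.Program().global_block()
fc_w = block.create_var(
    name="fc_w", shape=[784, 10], dtype="float32", persistable=True)
serial = block.create_var(
    name="SERIAL_NUMBER", persistable=True,
    type=fluid.core.VarDesc.VarType.RAW)
# _is_persistable(fc_w)   -> True : persistable LoDTensor, saved and loaded
# _is_persistable(serial) -> False: RAW (and feed/fetch) variables are runtime
#                            bookkeeping, skipped despite var.persistable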
def _get_lastest_checkpoint_dir(self, checkpoint_dir):
"""
return the largest serial number among checkpoint_dir subdirectories that contain _SUCCESS
"""
if not checkpoint_dir.strip():
return ""
def has_success(checkpoint_dir, cur_dir):
"""
is _SUCCESS in this dir
"""
if not os.path.isdir(cur_dir):
return -1
try:
int(cur_dir)
except ValueError:
return -1
success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
if os.path.isfile(success_path):
return int(cur_dir)
current_dir = -1
dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs:
success_num = has_success(checkpoint_dir, cur_dir)
if success_num > current_dir:
current_dir = success_num
# distinguish "no completed checkpoint" from a valid serial number 0
if current_dir == -1:
return ""
return str(current_dir)
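# Worked example (assumed directory contents, not part of this commit):
#   /workspace/ckpt/0/_SUCCESS   -> has_success returns 0
#   /workspace/ckpt/3/_SUCCESS   -> has_success returns 3
#   /workspace/ckpt/5/           -> no _SUCCESS marker, returns -1
#   /workspace/ckpt/tmp/         -> non-numeric name, returns -1
# so _get_lastest_checkpoint_dir("/workspace/ckpt") returns "3"; with no
# completed checkpoint it returns "".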
# transpiler function for distributed lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
eplist):