提交 821acdb3 编写于 作者: T tangwei12

update op to trainer and pserver

上级 f688652f
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from __future__ import print_function from __future__ import print_function
import os
import math import math
import distributed_splitter as splitter import distributed_splitter as splitter
...@@ -26,6 +27,10 @@ LOOKUP_TABLE_TYPE = "lookup_table" ...@@ -26,6 +27,10 @@ LOOKUP_TABLE_TYPE = "lookup_table"
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad" LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR" RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
# for checkpoint
SUCCESS = "_SUCCESS"
SERIAL_VAR_NAME = "SERIAL_NUMBER"
class VarBlock: class VarBlock:
def __init__(self, varname, offset, size): def __init__(self, varname, offset, size):
...@@ -153,7 +158,8 @@ class DistributeTranspiler: ...@@ -153,7 +158,8 @@ class DistributeTranspiler:
pservers="127.0.0.1:6174", pservers="127.0.0.1:6174",
trainers=1, trainers=1,
split_method=splitter.round_robin, split_method=splitter.round_robin,
sync_mode=True): sync_mode=True,
checkpoint_dir=None):
""" """
Transpile the program to distributed data-parallelism programs. Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server The main_program will be transformed to use a remote parameter server
...@@ -315,22 +321,22 @@ class DistributeTranspiler: ...@@ -315,22 +321,22 @@ class DistributeTranspiler:
"sync_mode": self.sync_mode "sync_mode": self.sync_mode
}) })
serial_var = program.global_block().create_var( if checkpoint_dir and self.is_chief:
name="SERIAL_NUMBER", program.global_block().create_var(
persistable=True, name=SERIAL_VAR_NAME,
type=core.VarDesc.VarType.RAW) persistable=True,
type=core.VarDesc.VarType.RAW)
save_vars = [] save_vars = []
for var in self.origin_program.list_vars(): for var in self.origin_program.list_vars():
if self.is_persistable(var): if self._is_persistable(var):
save_vars.append(var.name) save_vars.append(var.name)
program.global_block().append_op( program.global_block().append_op(
type="checkpoint_save", type="checkpoint_save",
inputs={"X": save_vars}, inputs={"X": save_vars},
outputs={"Serial": serial_var}, attrs={"overwrite": True,
attrs={"overwrite": False, "dir": checkpoint_dir})
"dir": "/workspace/ckpt/"})
# step4: Concat the parameters splits together after recv. # step4: Concat the parameters splits together after recv.
for varname, splited_var in param_var_mapping.iteritems(): for varname, splited_var in param_var_mapping.iteritems():
...@@ -512,13 +518,6 @@ class DistributeTranspiler: ...@@ -512,13 +518,6 @@ class DistributeTranspiler:
pserver_program.sync_with_cpp() pserver_program.sync_with_cpp()
return pserver_program return pserver_program
def is_persistable(self, var):
    """
    Return True when *var* should be saved in a checkpoint.

    FEED_MINIBATCH, FETCH_LIST and RAW variables are transient
    plumbing created by the framework and are never persisted; for
    every other variable the variable's own `persistable` flag decides.
    """
    if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
            var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
            var.desc.type() == core.VarDesc.VarType.RAW :
        return False
    return var.persistable
def get_train_startup_program(self, checkpoint_load_dir=None): def get_train_startup_program(self, checkpoint_load_dir=None):
""" """
Get train startup program. Get train startup program.
...@@ -532,13 +531,16 @@ class DistributeTranspiler: ...@@ -532,13 +531,16 @@ class DistributeTranspiler:
load_vars = [] load_vars = []
for var in startup_prog.list_vars(): for var in startup_prog.list_vars():
if self.is_persistable(var): if self._is_persistable(var):
load_vars.append(var.name) load_vars.append(var.name)
serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir)
startup_prog.global_block().append_op( startup_prog.global_block().append_op(
type="checkpoint_load", type="checkpoint_load",
outputs={"Out": load_vars}, inputs={"X": load_vars},
attrs={"dir": checkpoint_load_dir}) attrs={"dir": checkpoint_load_dir,
"Serial": serial_number})
return startup_prog return startup_prog
def get_startup_program(self, def get_startup_program(self,
...@@ -599,16 +601,59 @@ class DistributeTranspiler: ...@@ -599,16 +601,59 @@ class DistributeTranspiler:
attrs=op.attrs) attrs=op.attrs)
for var in new_outputs.values(): for var in new_outputs.values():
load_vars.append(var.name) load_vars.append(var.name)
# add checkpoint op # add checkpoint op
if not checkpoint_load_dir: if not checkpoint_load_dir:
return s_prog return s_prog
serial_number = self._get_lastest_checkpoint_dir(checkpoint_load_dir)
s_prog.global_block().append_op( s_prog.global_block().append_op(
type="checkpoint_load", type="checkpoint_load",
inputs={"X": load_vars}, inputs={"X": load_vars},
attrs={"dir": checkpoint_load_dir}) attrs={"dir": checkpoint_load_dir,
"Serial": serial_number})
return s_prog return s_prog
def _is_persistable(self, var):
    """
    Decide whether *var* belongs in a checkpoint (only LoDTensor-like
    variables are saved).

    Framework plumbing variables — FEED_MINIBATCH, FETCH_LIST and
    RAW — are always excluded; otherwise the variable's own
    `persistable` flag is authoritative.
    """
    excluded_types = (
        core.VarDesc.VarType.FEED_MINIBATCH,
        core.VarDesc.VarType.FETCH_LIST,
        core.VarDesc.VarType.RAW,
    )
    if var.desc.type() in excluded_types:
        return False
    return var.persistable
def _get_lastest_checkpoint_dir(self, checkpoint_dir):
"""
get the biggest number in checkpoint_dir, which has _SUCCESS
"""
if not checkpoint_dir.strip():
return ""
def has_success(checkpoint_dir, cur_dir):
"""
is _SUCCESS in this dir
"""
if not os.path.isdir(cur_dir):
return -1
try:
int(cur_dir)
except ValueError:
return -1
success_path = os.path.join(checkpoint_dir, cur_dir, SUCCESS)
if os.path.isfile(success_path):
return int(cur_dir)
current_dir = 0
dirs = os.listdir(checkpoint_dir)
for cur_dir in dirs:
success_num = has_success(checkpoint_dir, cur_dir)
if success_num > current_dir:
current_dir = success_num
return str(current_dir)
# transpiler function for dis lookup_table # transpiler function for dis lookup_table
def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var, def _replace_lookup_table_op_with_prefetch(self, program, rpc_client_var,
eplist): eplist):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册