Commit 6588d2e9 authored by: Y yi.wu

complete dist transpiler doc

Parent 4c3eb448
@@ -12,14 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Transpile the program to distributed data-parallelism programs.
-The main_program will be transformed to use a remote parameter server
-to do parameter optimization. And the optimization graph will be put
-into a parameter server program.
-
-Use different methods to split trainable variables to different
-parameter servers.
-
 Steps to transpile trainer:
 1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
 2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
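The trainer-side steps above start by slicing each variable into blocks aligned to its width, i.e. product(dim[1:]); the slice_variable(var_list, slice_count, min_block_size=8192) helper referenced in the next hunk does this on fluid variables. As a rough, self-contained illustration of the idea only (the alignment rule and the meaning of min_block_size below are simplified assumptions, not the actual fluid implementation):

from functools import reduce

def toy_slice_variable(shape, slice_count, min_block_size=8192):
    """Illustrative only: split a variable of `shape` into at most
    `slice_count` blocks whose sizes are multiples of the variable's
    width, i.e. product(dim[1:])."""
    total = reduce(lambda a, b: a * b, shape)         # number of elements
    width = reduce(lambda a, b: a * b, shape[1:], 1)  # product(dim[1:])
    # assumed semantics: never create blocks smaller than min_block_size elements
    max_blocks = max(1, min(slice_count, total // min_block_size))
    # block size rounded up to a multiple of the width
    block_size = (total + max_blocks - 1) // max_blocks
    if block_size % width != 0:
        block_size += width - (block_size % width)
    sizes = []
    remaining = total
    while remaining > 0:                              # last block may be smaller
        sizes.append(min(block_size, remaining))
        remaining -= sizes[-1]
    return sizes

# e.g. a [1000, 64] fc weight split across 3 pservers:
print(toy_slice_variable([1000, 64], 3))  # three width-aligned blocks summing to 64000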
@@ -118,128 +110,40 @@ def slice_variable(var_list, slice_count, min_block_size=8192):
 class DistributeTranspiler:
-    def _has_distributed_lookup_table(self):
-        # process lookup_table_op
-        # 1. check all lookup_table_op is distributed
-        # 2. check all lookup_table_op share the same table.
-        distributed_lookup_table_ops = []
-        # support only one distributed_lookup_table now
-        self.table_name = None
-        for op in self.origin_program.global_block().ops:
-            if op.type == LOOKUP_TABLE_TYPE:
-                if op.attrs['is_distributed'] is True:
-                    if self.table_name is None:
-                        self.table_name = op.input("W")[0]
-                    if self.table_name != op.input("W")[0]:
-                        raise RuntimeError("all distributed lookup_table_ops"
-                                           " should have only one table")
-                    distributed_lookup_table_ops.append(op)
-                else:
-                    if self.table_name is not None:
-                        assert op.input("W")[0] != self.table_name
-
-        return len(distributed_lookup_table_ops) > 0
-
-    def _update_dist_lookup_table_vars(self, param_list, grad_list,
-                                       params_grads):
-        # TODO(wuyi): put find a way to put dist lookup table stuff all together.
-        # update self.table_param_grad and self.trainer_side_table_grad_list
-        program = self.origin_program
-        if self.has_distributed_lookup_table:
-            param_list = [
-                param for param in param_list if param.name != self.table_name
-            ]
-            grad_list = [
-                grad for grad in grad_list
-                if grad.name != grad_var_name(self.table_name)
-            ]
-            self.table_param_grad = [
-                param_grad for param_grad in params_grads
-                if param_grad[0].name == self.table_name
-            ][0]
-            table_grad_var = self.table_param_grad[1]
-            if self.sync_mode:
-                self.trainer_side_table_grad_list = [
-                    program.global_block().create_var(
-                        name="%s.trainer_%d.pserver_%d" %
-                        (table_grad_var.name, self.trainer_id, index),
-                        type=table_grad_var.type,
-                        shape=table_grad_var.shape,
-                        dtype=table_grad_var.dtype)
-                    for index in range(len(self.pserver_endpoints))
-                ]
-            else:
-                self.trainer_side_table_grad_list = [
-                    program.global_block().create_var(
-                        name="%s.pserver_%d" % (table_grad_var.name, index),
-                        type=table_grad_var.type,
-                        shape=table_grad_var.shape,
-                        dtype=table_grad_var.dtype)
-                    for index in range(len(self.pserver_endpoints))
-                ]
-        return param_list, grad_list
-
-    def _init_splited_vars(self, slice_var_up):
-        # update these mappings for further transpile:
-        # 1. param_var_mapping: param var name -> [splited params vars]
-        # 2. grad_var_mapping: grad var name -> [splited grads vars]
-        # 3. grad_param_mapping: grad.blockx -> param.blockx
-        # 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}
-        param_list = []
-        grad_list = []
-        param_grad_set = set()
-        for p, g in self.params_grads:
-            # skip parameter marked not trainable
-            if type(p) == Parameter and p.trainable == False:
-                continue
-            if p.name not in param_grad_set:
-                param_list.append(p)
-                param_grad_set.add(p.name)
-            if g.name not in param_grad_set:
-                grad_list.append(g)
-                param_grad_set.add(g.name)
-
-        param_list, grad_list = self._update_dist_lookup_table_vars(
-            param_list, grad_list, self.params_grads)
-
-        if slice_var_up:
-            # when we slice var up into blocks, we will slice the var according to
-            # pserver services' count. A pserver may have two or more listening ports.
-            grad_blocks = slice_variable(grad_list, len(self.pserver_endpoints))
-            param_blocks = slice_variable(param_list,
-                                          len(self.pserver_endpoints))
-        else:
-            # when we do NOT slice var up into blocks, we will always slice params
-            # grads into one block.
-            grad_blocks = slice_variable(grad_list, 1)
-            param_blocks = slice_variable(param_list, 1)
-        assert (len(grad_blocks) == len(param_blocks))
-
-        # origin_varname -> [splited_var]
-        self.param_var_mapping = self._create_vars_from_blocklist(
-            self.origin_program, param_blocks)
-        self.grad_var_mapping = self._create_vars_from_blocklist(
-            self.origin_program,
-            grad_blocks,
-            add_trainer_suffix=self.trainer_num > 1)
-        self.grad_param_mapping = dict()
-        for g, p in zip(grad_blocks, param_blocks):
-            g_name, g_bid, _ = g.split(":")
-            p_name, p_bid, _ = p.split(":")
-            self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
-                self.param_var_mapping[p_name][int(p_bid)]
-
-        # create mapping of endpoint -> split var to create pserver side program
-        self.param_grad_ep_mapping = dict()
-        [
-            self.param_grad_ep_mapping.update({
-                ep: {
-                    "params": [],
-                    "grads": []
-                }
-            }) for ep in self.pserver_endpoints
-        ]
-
+    """
+    **DistributeTranspiler**
+
+    Convert the fluid program to distributed data-parallelism programs.
+
+    The main_program will be transformed to use a remote parameter server
+    to do parameter optimization. And the optimization graph will be put
+    into a parameter server program.
+
+    Examples:
+        .. code-block:: python
+
+            # Define your model before these codes.
+            port = os.getenv("PADDLE_PSERVER_PORT", "6174")
+            pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
+            eplist = []
+            for ip in pserver_ips.split(","):
+                eplist.append(':'.join([ip, port]))
+            pserver_endpoints = ",".join(eplist)
+            trainers = int(os.getenv("PADDLE_TRAINERS"))
+            current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
+            trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+            role = os.getenv("PADDLE_TRAINING_ROLE")
+
+            t = distribute_transpiler.DistributeTranspiler()
+            t.transpile(
+                trainer_id, pservers=pserver_endpoints, trainers=trainers)
+            if role == "PSERVER":
+                pserver_program = t.get_pserver_program(current_endpoint)
+                pserver_startup_program = t.get_startup_program(current_endpoint,
+                                                                pserver_program)
+            elif role == "TRAINER":
+                trainer_program = t.get_trainer_program()
+    """
+
     def transpile(self,
                   trainer_id,
@@ -250,20 +154,20 @@
                   split_method=RoundRobin,
                   sync_mode=True):
         """
-        :param trainer_id: one unique id for each trainer in a job.
-        :type trainer_id: int
-        :param program: program to transpile, default is default_main_program
-        :type program: Program
-        :param pservers: parameter server endpoints like "m1:6174,m2:6174"
-        :type pservers: string
-        :param trainers: total number of workers/trainers in the job
-        :type trainers: int
-        :param split_method: A function to determin how to split variables
-            to different servers equally.
-        :type split_method: function
-        :type sync_mode: boolean default True
-        :param sync_mode: if sync_mode is set True, it means that dist transpiler
-            will transpile the program into sync_mode pserver and trainer program.
+        Run the transpiler.
+
+        Args:
+            trainer_id (int): id of the current trainer worker; if there are
+                n workers, the id ranges from 0 to n-1.
+            program (Program|None): program to transpile,
+                default is fluid.default_main_program().
+            pservers (str): comma-separated ip:port string for the pserver
+                list.
+            trainers (int): number of trainers in the distributed job.
+            slice_var_up (bool): whether to split variables into blocks for
+                the pservers, default is True.
+            split_method (PSDispatcher): RoundRobin or HashName can be used;
+                try to choose the best method to balance loads for pservers.
+            sync_mode (bool): Do sync training or not, default is True.
         """
         assert (split_method.__bases__[0] == PSDispatcher)
         if program is None:
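To make the arguments documented above concrete, here is a minimal sketch of a transpile call for a hypothetical job with two pservers and two trainers. The endpoints are made up, and exposing the class as fluid.DistributeTranspiler is an assumption (fall back to the module import used in the class docstring example if that attribute is not available):

import paddle.fluid as fluid

# hypothetical cluster layout: 2 pservers, 2 trainers
pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
trainers = 2
trainer_id = 0  # this worker's id, in [0, trainers)

t = fluid.DistributeTranspiler()  # assumed alias of DistributeTranspiler
t.transpile(
    trainer_id,
    pservers=pserver_endpoints,
    trainers=trainers,
    slice_var_up=True,   # split large variables into blocks across pservers
    sync_mode=True)      # synchronous training

Leaving split_method at its default keeps the RoundRobin dispatcher described later in this diff.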
@@ -390,6 +294,12 @@
         self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
 
     def get_trainer_program(self):
+        """
+        Get transpiled trainer side program.
+
+        Returns:
+            Program: trainer side program.
+        """
         # remove optimize ops and add a send op to main_program
         delete_ops(self.origin_program.global_block(), self.optimize_ops)
         # FIXME(typhoonzero): serialize once will fix error occurs when clone.
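Building on the docstring just added, a hedged sketch of how the trainer-side program is usually consumed. The executor calls are standard fluid usage rather than part of this diff, and train_reader, feeder and avg_cost are assumed to come from your own model definition:

import paddle.fluid as fluid

# after t.transpile(...) on a TRAINER node
trainer_program = t.get_trainer_program()

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())      # initialize local parameters

for batch in train_reader():                  # train_reader is assumed to exist
    loss, = exe.run(trainer_program,
                    feed=feeder.feed(batch),  # feeder is assumed to exist
                    fetch_list=[avg_cost])    # avg_cost from your model definition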
@@ -398,12 +308,19 @@
     def get_pserver_program(self, endpoint):
         """
-        Get pserver side program using the endpoint.
-        TODO(panyx0718): Revisit this assumption. what if #blocks > #pservers.
-        NOTE: assume blocks of the same variable is not distributed
-        on the same pserver, only change param/grad varnames for
-        trainers to fetch.
+        Get parameter server side program.
+
+        Args:
+            endpoint (str): current parameter server endpoint.
+
+        Returns:
+            Program: the program for current parameter server to run.
         """
+        # TODO(panyx0718): Revisit this assumption. what if #blocks > #pservers.
+        # NOTE: assume blocks of the same variable is not distributed
+        # on the same pserver, only change param/grad varnames for
+        # trainers to fetch.
+
         # step1
         pserver_program = Program()
         # step2: Create vars to receive vars at parameter servers.
@@ -556,6 +473,14 @@
         Get startup program for current parameter server.
         Modify operator input variables if there are variables that
         were split to several blocks.
+
+        Args:
+            endpoint (str): current pserver endpoint.
+            pserver_program (Program): call get_pserver_program first and
+                pass the result here.
+
+        Returns:
+            Program: parameter server side startup program.
         """
         s_prog = Program()
         orig_s_prog = default_startup_program()
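get_pserver_program and get_startup_program, documented in the two hunks above, are used together on each parameter server node. A sketch of that flow, with the executor calls following the usual fluid pattern (not introduced by this diff) and an illustrative endpoint:

import paddle.fluid as fluid

# after t.transpile(...) on a PSERVER node
current_endpoint = "192.168.0.1:6174"  # this pserver's own ip:port (illustrative)

pserver_program = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_program)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)   # initialize the parameter blocks owned by this pserver
exe.run(pserver_program)   # blocks here, serving parameter updates for trainers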
@@ -607,6 +532,129 @@
 
     # ====================== private transpiler functions =====================
 
+    def _has_distributed_lookup_table(self):
+        # process lookup_table_op
+        # 1. check all lookup_table_op is distributed
+        # 2. check all lookup_table_op share the same table.
+        distributed_lookup_table_ops = []
+        # support only one distributed_lookup_table now
+        self.table_name = None
+        for op in self.origin_program.global_block().ops:
+            if op.type == LOOKUP_TABLE_TYPE:
+                if op.attrs['is_distributed'] is True:
+                    if self.table_name is None:
+                        self.table_name = op.input("W")[0]
+                    if self.table_name != op.input("W")[0]:
+                        raise RuntimeError("all distributed lookup_table_ops"
+                                           " should have only one table")
+                    distributed_lookup_table_ops.append(op)
+                else:
+                    if self.table_name is not None:
+                        assert op.input("W")[0] != self.table_name
+
+        return len(distributed_lookup_table_ops) > 0
+
+    def _update_dist_lookup_table_vars(self, param_list, grad_list,
+                                       params_grads):
+        # TODO(wuyi): put find a way to put dist lookup table stuff all together.
+        # update self.table_param_grad and self.trainer_side_table_grad_list
+        program = self.origin_program
+        if self.has_distributed_lookup_table:
+            param_list = [
+                param for param in param_list if param.name != self.table_name
+            ]
+            grad_list = [
+                grad for grad in grad_list
+                if grad.name != grad_var_name(self.table_name)
+            ]
+            self.table_param_grad = [
+                param_grad for param_grad in params_grads
+                if param_grad[0].name == self.table_name
+            ][0]
+            table_grad_var = self.table_param_grad[1]
+            if self.sync_mode:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.trainer_%d.pserver_%d" %
+                        (table_grad_var.name, self.trainer_id, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
+            else:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.pserver_%d" % (table_grad_var.name, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
+        return param_list, grad_list
+
+    def _init_splited_vars(self, slice_var_up):
+        # update these mappings for further transpile:
+        # 1. param_var_mapping: param var name -> [splited params vars]
+        # 2. grad_var_mapping: grad var name -> [splited grads vars]
+        # 3. grad_param_mapping: grad.blockx -> param.blockx
+        # 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}
+        param_list = []
+        grad_list = []
+        param_grad_set = set()
+        for p, g in self.params_grads:
+            # skip parameter marked not trainable
+            if type(p) == Parameter and p.trainable == False:
+                continue
+            if p.name not in param_grad_set:
+                param_list.append(p)
+                param_grad_set.add(p.name)
+            if g.name not in param_grad_set:
+                grad_list.append(g)
+                param_grad_set.add(g.name)
+
+        param_list, grad_list = self._update_dist_lookup_table_vars(
+            param_list, grad_list, self.params_grads)
+
+        if slice_var_up:
+            # when we slice var up into blocks, we will slice the var according to
+            # pserver services' count. A pserver may have two or more listening ports.
+            grad_blocks = slice_variable(grad_list, len(self.pserver_endpoints))
+            param_blocks = slice_variable(param_list,
+                                          len(self.pserver_endpoints))
+        else:
+            # when we do NOT slice var up into blocks, we will always slice params
+            # grads into one block.
+            grad_blocks = slice_variable(grad_list, 1)
+            param_blocks = slice_variable(param_list, 1)
+        assert (len(grad_blocks) == len(param_blocks))
+
+        # origin_varname -> [splited_var]
+        self.param_var_mapping = self._create_vars_from_blocklist(
+            self.origin_program, param_blocks)
+        self.grad_var_mapping = self._create_vars_from_blocklist(
+            self.origin_program,
+            grad_blocks,
+            add_trainer_suffix=self.trainer_num > 1)
+        self.grad_param_mapping = dict()
+        for g, p in zip(grad_blocks, param_blocks):
+            g_name, g_bid, _ = g.split(":")
+            p_name, p_bid, _ = p.split(":")
+            self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
+                self.param_var_mapping[p_name][int(p_bid)]
+
+        # create mapping of endpoint -> split var to create pserver side program
+        self.param_grad_ep_mapping = dict()
+        [
+            self.param_grad_ep_mapping.update({
+                ep: {
+                    "params": [],
+                    "grads": []
+                }
+            }) for ep in self.pserver_endpoints
+        ]
+
     # transpiler function for dis lookup_table
     def _replace_lookup_table_op_with_prefetch(self, program,
                                                pserver_endpoints):
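Inside _init_splited_vars above, grad_blocks and param_blocks are plain strings of the form "varname:block_id:size", and grad_param_mapping pairs the i-th grad block with the i-th param block. A toy, self-contained illustration of that bookkeeping, with made-up names and stand-in strings in place of fluid Variable objects:

# toy block descriptors in the "varname:block_id:size" format (sizes are made up)
param_blocks = ["fc_0.w_0:0:21376", "fc_0.w_0:1:21376", "fc_0.b_0:0:64"]
grad_blocks = ["fc_0.w_0@GRAD:0:21376", "fc_0.w_0@GRAD:1:21376", "fc_0.b_0@GRAD:0:64"]

# stand-ins for the splited Variables, keyed by original var name
param_var_mapping = {"fc_0.w_0": ["fc_0.w_0.block0", "fc_0.w_0.block1"],
                     "fc_0.b_0": ["fc_0.b_0"]}
grad_var_mapping = {"fc_0.w_0@GRAD": ["fc_0.w_0@GRAD.block0", "fc_0.w_0@GRAD.block1"],
                    "fc_0.b_0@GRAD": ["fc_0.b_0@GRAD"]}

grad_param_mapping = {}
for g, p in zip(grad_blocks, param_blocks):
    g_name, g_bid, _ = g.split(":")
    p_name, p_bid, _ = p.split(":")
    grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
        param_var_mapping[p_name][int(p_bid)]

print(grad_param_mapping)
# {'fc_0.w_0@GRAD.block0': 'fc_0.w_0.block0',
#  'fc_0.w_0@GRAD.block1': 'fc_0.w_0.block1',
#  'fc_0.b_0@GRAD': 'fc_0.b_0'}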
...
@@ -41,7 +41,11 @@ class PSDispatcher(object):
 class HashName(PSDispatcher):
     """
-    Hash variable names to several endpoints
+    Hash variable names to several endpoints using Python's
+    built-in "hash()" function.
+
+    Args:
+        pserver_endpoints (list): list of endpoint(ip:port).
     """
 
     def __init__(self, pserver_endpoints):
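A minimal sketch of the hash-based dispatch idea described in the HashName docstring above; this illustrates the strategy only and is not the fluid implementation, whose internals are not shown in this diff:

def hash_name_dispatch(varnames, pserver_endpoints):
    """Illustrative: map each variable name to an endpoint via hash()."""
    return {
        name: pserver_endpoints[hash(name) % len(pserver_endpoints)]
        for name in varnames
    }

eps = ["192.168.0.1:6174", "192.168.0.2:6174"]
print(hash_name_dispatch(["fc_0.w_0.block0", "fc_0.w_0.block1", "fc_0.b_0"], eps))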
@@ -61,7 +65,11 @@ class HashName(PSDispatcher):
 class RoundRobin(PSDispatcher):
     """
-    Distribute variables to serveral endpoints.
+    Distribute variables to several endpoints using the
+    RoundRobin<https://en.wikipedia.org/wiki/Round-robin_scheduling> method.
+
+    Args:
+        pserver_endpoints (list): list of endpoint(ip:port).
     """
 
     def __init__(self, pserver_endpoints):
...
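And a sketch of the round-robin counterpart described in the RoundRobin docstring above; ToyRoundRobin and its dispatch method are illustrative names, not the actual fluid API:

class ToyRoundRobin(object):
    """Illustrative: hand out endpoints in cyclic order."""

    def __init__(self, pserver_endpoints):
        self._eps = pserver_endpoints
        self._i = 0

    def dispatch(self, varnames):
        result = {}
        for name in varnames:
            result[name] = self._eps[self._i % len(self._eps)]
            self._i += 1
        return result

d = ToyRoundRobin(["192.168.0.1:6174", "192.168.0.2:6174"])
print(d.dispatch(["fc_0.w_0.block0", "fc_0.w_0.block1", "fc_0.b_0"]))
# block0 -> 192.168.0.1, block1 -> 192.168.0.2, b_0 -> 192.168.0.1 (cyclic)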