Commit b33ea7be authored by minqiyang

1. Change the variable name from align_var_to_block to slice_var_up

2. Replace split_method with slice_var_up in func _init_splited_vars
Parent 9d92dcea
@@ -72,7 +72,7 @@ class TestSimpleDistTranspiler(TranspilerTest):
             program=main,
             pservers=self.pserver_eps,
             trainers=self.trainers,
-            align_var_to_block=False)
+            slice_var_up=False)
         return t
......
@@ -71,7 +71,7 @@ def same_or_split_var(p_name, var_name):
     return p_name == var_name or p_name.startswith(var_name + ".block")
 
 
-def split_variable(var_list, service_count, min_block_size=8192):
+def slice_variable(var_list, slice_count, min_block_size=8192):
     """
     We may need to split dense tensor to one or more blocks and put
     them equally onto parameter server. One block is a sub-tensor
@@ -83,8 +83,8 @@ def split_variable(var_list, service_count, min_block_size=8192):
     Args:
         var_list (list): List of variables.
-        service_count (int): Numel of pserver services. A pserver may have two
-            or more listening ports.
+        slice_count (int): Numel of count that variables will be sliced, which
+            could be the pserver services' count.
         min_block_size (int): Minimum splitted block size.
     Returns:
         blocks (list[(varname, block_id, current_block_size)]): A list
@@ -92,12 +92,12 @@ def split_variable(var_list, service_count, min_block_size=8192):
     """
     blocks = []
     for var in var_list:
-        split_count = service_count
+        split_count = slice_count
         var_numel = reduce(lambda x, y: x * y, var.shape)
        max_pserver_count = int(math.floor(var_numel / float(min_block_size)))
         if max_pserver_count == 0:
             max_pserver_count = 1
-        if max_pserver_count < service_count:
+        if max_pserver_count < slice_count:
             split_count = max_pserver_count
         block_size = int(math.ceil(var_numel / float(split_count)))
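
For readers skimming the hunk above, here is a small standalone sketch of the arithmetic `slice_variable` performs, with a made-up shape and slice count; the `min()`/`max()` shorthand and the concrete numbers are illustrative only, not part of the commit:

```python
import math
from functools import reduce  # the module itself uses the Python 2 builtin reduce

# Hypothetical variable: shape [1024, 64] -> 65536 elements, sliced for 4 pservers.
shape = [1024, 64]
slice_count = 4
min_block_size = 8192

var_numel = reduce(lambda x, y: x * y, shape)                      # 65536
max_pserver_count = max(int(math.floor(var_numel / float(min_block_size))), 1)
split_count = min(max_pserver_count, slice_count)                  # min(8, 4) = 4
block_size = int(math.ceil(var_numel / float(split_count)))        # 16384

print(split_count, block_size)  # -> 4 blocks of 16384 elements each
```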
@@ -178,7 +178,7 @@ class DistributeTranspiler:
             for index in range(len(self.pserver_endpoints))
         ]
 
-    def _init_splited_vars(self, split_method, align_var_to_block=True):
+    def _init_splited_vars(self, slice_var_up):
         # update these mappings for further transpile:
         # 1. param_var_mapping: param var name -> [splited params vars]
         # 2. grad_var_mapping: grad var name -> [splited grads vars]
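
As a rough illustration of the mappings named in these comments: each entry pairs an original variable name with its splited pieces. The names below are invented, the real values are Variable objects rather than strings, and the `.block<N>` suffix follows the `same_or_split_var` convention earlier in the diff:

```python
# Invented example of the mapping shapes (values are really Variable objects):
param_var_mapping = {
    "fc_w": ["fc_w.block0", "fc_w.block1", "fc_w.block2"],  # sliced across pservers
    "fc_b": ["fc_b.block0"],                                # small var, single block
}
grad_var_mapping = {
    "fc_w@GRAD": ["fc_w@GRAD.block0", "fc_w@GRAD.block1", "fc_w@GRAD.block2"],
    "fc_b@GRAD": ["fc_b@GRAD.block0"],
}
```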
@@ -197,16 +197,19 @@ class DistributeTranspiler:
         self._update_dist_lookup_table_vars(param_list, grad_list,
                                             self.params_grads)
 
-        if align_var_to_block:
-            grad_blocks = split_variable(grad_list, len(self.pserver_endpoints))
-            param_blocks = split_variable(param_list,
+        if slice_var_up:
+            # when we slice var up into blocks, we will slice the var according to
+            # pserver services' count. A pserver may have two or more listening ports.
+            grad_blocks = slice_variable(grad_list, len(self.pserver_endpoints))
+            param_blocks = slice_variable(param_list,
                                           len(self.pserver_endpoints))
         else:
-            # when we do NOT align var to block, we will always split params
+            # when we do NOT slice var up into blocks, we will always slice params
             # grads into one block.
-            grad_blocks = split_variable(grad_list, 1)
-            param_blocks = split_variable(param_list, 1)
+            grad_blocks = slice_variable(grad_list, 1)
+            param_blocks = slice_variable(param_list, 1)
         assert (len(grad_blocks) == len(param_blocks))
 
         # origin_varname -> [splited_var]
         self.param_var_mapping = self._create_vars_from_blocklist(
             self.origin_program, param_blocks)
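
To make the two branches concrete, the sketch below counts how many blocks one hypothetical parameter would be cut into under each setting; `num_blocks` simply mirrors the `split_count` logic from `slice_variable` above, and the numbers are assumptions:

```python
import math

def num_blocks(var_numel, slice_count, min_block_size=8192):
    # Mirrors the split_count computation in slice_variable.
    max_count = max(int(math.floor(var_numel / float(min_block_size))), 1)
    return min(max_count, slice_count)

pserver_count = 3      # stand-in for len(self.pserver_endpoints)
param_numel = 100000   # a hypothetical dense parameter

print(num_blocks(param_numel, pserver_count))  # slice_var_up=True  -> 3 blocks
print(num_blocks(param_numel, 1))              # slice_var_up=False -> 1 block
```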
@@ -237,7 +240,7 @@ class DistributeTranspiler:
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  align_var_to_block=True,
+                  slice_var_up=True,
                   split_method=RoundRobin,
                   sync_mode=True):
         """
@@ -271,7 +274,7 @@ class DistributeTranspiler:
         self.has_distributed_lookup_table = self._has_distributed_lookup_table()
 
         # split and create vars, then put splited vars in dicts for later use.
-        self._init_splited_vars(split_method, align_var_to_block)
+        self._init_splited_vars(slice_var_up)
 
         # step 3.1: insert send op to send gradient vars to parameter servers
         ps_dispatcher.reset()
@@ -283,13 +286,13 @@ class DistributeTranspiler:
         # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
         # shuffle the map will avoid the uneven distribution above
         grad_var_mapping_items = self.grad_var_mapping.items()
-        if not align_var_to_block:
+        if not slice_var_up:
             np.random.shuffle(grad_var_mapping_items)
 
         for orig_varname, splited_vars in grad_var_mapping_items:
             eplist = ps_dispatcher.dispatch(splited_vars)
 
-            if not align_var_to_block:
+            if not slice_var_up:
                 assert (len(splited_vars) == 1)
             if len(splited_vars) == 1:
......
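
Finally, a toy illustration of why the shuffle in the last hunk helps: with a stand-in round-robin dispatcher (not the transpiler's RoundRobin class) and single-block variables, a fixed ordering always pins the same variable to the same pserver, while shuffling randomizes which pserver receives the heavy one:

```python
import numpy as np

def toy_round_robin(var_names, pservers):
    # Stand-in dispatcher: the i-th variable goes to the (i mod N)-th pserver.
    return {name: pservers[i % len(pservers)] for i, name in enumerate(var_names)}

pservers = ["pserver1", "pserver2"]
grads = ["fc_w@GRAD", "fc_b@GRAD"]  # names echo the comment in the diff

# Fixed order: fc_w@GRAD always lands on pserver1; if fc_w is much larger
# than fc_b, pserver1 is consistently overloaded across trainers.
print(toy_round_robin(grads, pservers))

# With slice_var_up=False every var is a single block, so shuffling the list
# (as the transpiler does) randomizes which pserver gets the big variable.
np.random.shuffle(grads)
print(toy_round_robin(grads, pservers))
```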