Commit b33ea7be authored by minqiyang

1. change the variable name from align_var_to_block to slice_var_up

2. replace split_method with slice_var_up in func _init_splited_vars
Parent 9d92dcea
@@ -72,7 +72,7 @@ class TestSimpleDistTranspiler(TranspilerTest):
             program=main,
             pservers=self.pserver_eps,
             trainers=self.trainers,
-            align_var_to_block=False)
+            slice_var_up=False)
         return t
...
@@ -71,7 +71,7 @@ def same_or_split_var(p_name, var_name):
     return p_name == var_name or p_name.startswith(var_name + ".block")


-def split_variable(var_list, service_count, min_block_size=8192):
+def slice_variable(var_list, slice_count, min_block_size=8192):
     """
     We may need to split a dense tensor into one or more blocks and put
     them equally onto the parameter servers. One block is a sub-tensor
@@ -83,8 +83,8 @@ def split_variable(var_list, service_count, min_block_size=8192):
     Args:
         var_list (list): List of variables.
-        service_count (int): Numel of pserver services. A pserver may have two
-            or more listening ports.
+        slice_count (int): Number of slices each variable will be split into;
+            usually the count of pserver services.
         min_block_size (int): Minimum split block size.
     Returns:
         blocks (list[(varname, block_id, current_block_size)]): A list
@@ -92,12 +92,12 @@ def split_variable(var_list, service_count, min_block_size=8192):
     """
     blocks = []
     for var in var_list:
-        split_count = service_count
+        split_count = slice_count
         var_numel = reduce(lambda x, y: x * y, var.shape)
         max_pserver_count = int(math.floor(var_numel / float(min_block_size)))
         if max_pserver_count == 0:
             max_pserver_count = 1
-        if max_pserver_count < service_count:
+        if max_pserver_count < slice_count:
             split_count = max_pserver_count
         block_size = int(math.ceil(var_numel / float(split_count)))
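To make the arithmetic in this hunk concrete, here is a minimal standalone sketch of the block-count computation; the `demo_slice` helper and the sample shapes are illustrative only and are not part of the commit:

```python
import math
from functools import reduce

def demo_slice(shape, slice_count, min_block_size=8192):
    """Illustrative restatement of the arithmetic above: how many
    blocks a tensor of `shape` is sliced into, and the block size
    in elements."""
    split_count = slice_count
    var_numel = reduce(lambda x, y: x * y, shape)
    # Cap the split count so no block falls below min_block_size
    # elements; tiny variables stay in a single block.
    max_pserver_count = int(math.floor(var_numel / float(min_block_size)))
    if max_pserver_count == 0:
        max_pserver_count = 1
    if max_pserver_count < slice_count:
        split_count = max_pserver_count
    block_size = int(math.ceil(var_numel / float(split_count)))
    return split_count, block_size

print(demo_slice((1000, 100), 4))  # (4, 25000): large fc weight, 4 pservers
print(demo_slice((100,), 4))       # (1, 100): small bias stays whole
```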
@@ -178,7 +178,7 @@ class DistributeTranspiler:
             for index in range(len(self.pserver_endpoints))
         ]

-    def _init_splited_vars(self, split_method, align_var_to_block=True):
+    def _init_splited_vars(self, slice_var_up):
         # update these mappings for further transpile:
         # 1. param_var_mapping: param var name -> [splited params vars]
         # 2. grad_var_mapping: grad var name -> [splited grads vars]
@@ -197,16 +197,19 @@ class DistributeTranspiler:
         self._update_dist_lookup_table_vars(param_list, grad_list,
                                             self.params_grads)

-        if align_var_to_block:
-            grad_blocks = split_variable(grad_list, len(self.pserver_endpoints))
-            param_blocks = split_variable(param_list,
+        if slice_var_up:
+            # when we slice var up into blocks, we will slice the var according to
+            # pserver services' count. A pserver may have two or more listening ports.
+            grad_blocks = slice_variable(grad_list, len(self.pserver_endpoints))
+            param_blocks = slice_variable(param_list,
                                           len(self.pserver_endpoints))
         else:
-            # when we do NOT align var to block, we will always split params
+            # when we do NOT slice var up into blocks, we will always slice params
             # grads into one block.
-            grad_blocks = split_variable(grad_list, 1)
-            param_blocks = split_variable(param_list, 1)
+            grad_blocks = slice_variable(grad_list, 1)
+            param_blocks = slice_variable(param_list, 1)
         assert (len(grad_blocks) == len(param_blocks))

         # origin_varname -> [splited_var]
         self.param_var_mapping = self._create_vars_from_blocklist(
             self.origin_program, param_blocks)
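Under the assumptions of the `demo_slice` sketch above, the practical difference between the two branches looks like this for a single 1000x100 parameter and four pserver endpoints:

```python
# slice_var_up=True: slice each variable across all pserver endpoints.
print(demo_slice((1000, 100), slice_count=4))  # -> (4, 25000)

# slice_var_up=False: every variable stays as exactly one block; the
# dispatcher (with the shuffling below) then places whole variables.
print(demo_slice((1000, 100), slice_count=1))  # -> (1, 100000)
```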
@@ -237,7 +240,7 @@ class DistributeTranspiler:
                  program=None,
                  pservers="127.0.0.1:6174",
                  trainers=1,
-                 align_var_to_block=True,
+                 slice_var_up=True,
                  split_method=RoundRobin,
                  sync_mode=True):
         """
@@ -271,7 +274,7 @@ class DistributeTranspiler:
         self.has_distributed_lookup_table = self._has_distributed_lookup_table()

         # split and create vars, then put splited vars in dicts for later use.
-        self._init_splited_vars(split_method, align_var_to_block)
+        self._init_splited_vars(slice_var_up)

         # step 3.1: insert send op to send gradient vars to parameter servers
         ps_dispatcher.reset()
@@ -283,13 +286,13 @@ class DistributeTranspiler:
         # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
         # shuffling the map will avoid the uneven distribution above
         grad_var_mapping_items = self.grad_var_mapping.items()
-        if not align_var_to_block:
+        if not slice_var_up:
             np.random.shuffle(grad_var_mapping_items)

         for orig_varname, splited_vars in grad_var_mapping_items:
             eplist = ps_dispatcher.dispatch(splited_vars)

-            if not align_var_to_block:
+            if not slice_var_up:
                 assert (len(splited_vars) == 1)

             if len(splited_vars) == 1:
...
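From the caller's side, the visible effect of this commit is the renamed keyword argument. A hedged usage sketch follows; the `trainer_id` argument and the `fluid` import path are assumptions from the surrounding codebase and do not appear in this diff:

```python
import paddle.fluid as fluid

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,  # assumed first argument, not shown in this diff
    program=fluid.default_main_program(),
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2,
    slice_var_up=True)  # renamed from align_var_to_block by this commit
```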