提交 560d960b 编写于 作者: X xuwei06

Fix a minor bug for distributed_spliter.round_robin

Also fixed typo and comments.
上级 b2a1c9e8
...@@ -17,7 +17,7 @@ import framework ...@@ -17,7 +17,7 @@ import framework
from framework import Program, default_main_program, default_startup_program, Parameter, Variable from framework import Program, default_main_program, default_startup_program, Parameter, Variable
import optimizer import optimizer
from layer_helper import LayerHelper from layer_helper import LayerHelper
from distributed_spliter import * import distributed_splitter as splitter
import math import math
from . import core from . import core
import debuger import debuger
...@@ -36,7 +36,7 @@ class VarBlock: ...@@ -36,7 +36,7 @@ class VarBlock:
class UnionFind(object): class UnionFind(object):
""" Union-find data struct. """ Union-find data struct.
Union-find is a data struct that keeps track of a set of elements partitioned Union-find is a data struct that keeps track of a set of elements partitioned
into a number of disjoint (non-overlapping) subsets. into a number of disjoint (non-overlapping) subsets.
...@@ -138,7 +138,7 @@ class DistributeTranspiler: ...@@ -138,7 +138,7 @@ class DistributeTranspiler:
program=None, program=None,
pservers="127.0.0.1:6174", pservers="127.0.0.1:6174",
trainers=1, trainers=1,
split_method=round_robin): split_method=splitter.round_robin):
""" """
Transpile the program to distributed data-parallelism programs. Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server The main_program will be transformed to use a remote parameter server
...@@ -303,7 +303,7 @@ class DistributeTranspiler: ...@@ -303,7 +303,7 @@ class DistributeTranspiler:
# If two ops are connected, we could add these two ops # If two ops are connected, we could add these two ops
# into one set. # into one set.
ufind = self._create_ufind(self.optimize_ops) ufind = self._create_ufind(self.optimize_ops)
# step 4.2 # step 4.2
# Iterate through the ops and append optimize op which # Iterate through the ops and append optimize op which
# located on current pserver # located on current pserver
opt_op_on_pserver = [] opt_op_on_pserver = []
...@@ -312,7 +312,7 @@ class DistributeTranspiler: ...@@ -312,7 +312,7 @@ class DistributeTranspiler:
opt_op_on_pserver.append(op) opt_op_on_pserver.append(op)
# step 4.3 # step 4.3
# Iterate through the ops, and if an op and the optimize ops # Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then # which located on current pserver are in one set, then
# append it into the sub program. # append it into the sub program.
# We try to put optimization program run parallelly, assume # We try to put optimization program run parallelly, assume
...@@ -752,7 +752,7 @@ class DistributeTranspiler: ...@@ -752,7 +752,7 @@ class DistributeTranspiler:
def _is_opt_op(self, op): def _is_opt_op(self, op):
# NOTE: It's a HACK implement. # NOTE: It's a HACK implement.
# optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc... # optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
if "Param" in op.input_names and \ if "Param" in op.input_names and \
"LearningRate" in op.input_names: "LearningRate" in op.input_names:
return True return True
......
...@@ -17,8 +17,10 @@ def hash_name(varlist, pserver_endpoints): ...@@ -17,8 +17,10 @@ def hash_name(varlist, pserver_endpoints):
""" """
hash variable names to several endpoints. hash variable names to several endpoints.
:param varlist: a list of Variables Args:
:return: a map of pserver endpoint -> varname varlist(list): a list of Variables
Returns(dict): a map of pserver endpoint -> varname
""" """
def _hash_block(block_str, total): def _hash_block(block_str, total):
...@@ -34,9 +36,14 @@ def hash_name(varlist, pserver_endpoints): ...@@ -34,9 +36,14 @@ def hash_name(varlist, pserver_endpoints):
def round_robin(varlist, pserver_endpoints): def round_robin(varlist, pserver_endpoints):
""" """
distribute variables to several endpoints. Distribute variables to several endpoints.
Args:
varlist(list): a list of variables
pserver_endpoints(list): a list of pserver endpoints
Returns(list[int]): the endpoint for each variable
""" """
assert (len(varlist) > len(pserver_endpoints)) assert (len(varlist) >= len(pserver_endpoints))
eplist = [] eplist = []
pserver_idx = 0 pserver_idx = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册