From 560d960b2709c38922e913e7bd00e5362d75e1f2 Mon Sep 17 00:00:00 2001
From: xuwei06
Date: Fri, 6 Apr 2018 15:42:52 -0700
Subject: [PATCH] Fix a minor bug for distributed_spliter.round_robin

Also fixed typo and comments.

---
 python/paddle/fluid/distribute_transpiler.py      | 12 ++++++------
 ...ributed_spliter.py => distributed_splitter.py} | 15 +++++++++++----
 2 files changed, 17 insertions(+), 10 deletions(-)
 rename python/paddle/fluid/{distributed_spliter.py => distributed_splitter.py} (78%)

diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py
index 31bedb592f1..7a2a81be9f2 100644
--- a/python/paddle/fluid/distribute_transpiler.py
+++ b/python/paddle/fluid/distribute_transpiler.py
@@ -17,7 +17,7 @@ import framework
 from framework import Program, default_main_program, default_startup_program, Parameter, Variable
 import optimizer
 from layer_helper import LayerHelper
-from distributed_spliter import *
+import distributed_splitter as splitter
 import math
 from . import core
 import debuger
@@ -36,7 +36,7 @@ class VarBlock:
 
 class UnionFind(object):
     """ Union-find data struct.
-    
+
     Union-find is a data struct that keeps track of a set of elements
     partitioned into a number of disjoint (non-overlapping) subsets.
 
@@ -138,7 +138,7 @@ class DistributeTranspiler:
                   program=None,
                   pservers="127.0.0.1:6174",
                   trainers=1,
-                  split_method=round_robin):
+                  split_method=splitter.round_robin):
         """
         Transpile the program to distributed data-parallelism programs.
         The main_program will be transformed to use a remote parameter server
@@ -303,7 +303,7 @@ class DistributeTranspiler:
         # If two ops are connected, we could add these two ops
         # into one set.
         ufind = self._create_ufind(self.optimize_ops)
-        # step 4.2 
+        # step 4.2
         # Iterate through the ops and append optimize op which
         # located on current pserver
         opt_op_on_pserver = []
@@ -312,7 +312,7 @@ class DistributeTranspiler:
                 opt_op_on_pserver.append(op)
         # step 4.3
         # Iterate through the ops, and if an op and the optimize ops
-        # which located on current pserver are in one set, then 
+        # which located on current pserver are in one set, then
         # append it into the sub program.
 
         # We try to put optimization program run parallelly, assume
@@ -752,7 +752,7 @@ class DistributeTranspiler:
 
     def _is_opt_op(self, op):
         # NOTE: It's a HACK implement.
-        # optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc... 
+        # optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
         if "Param" in op.input_names and \
                 "LearningRate" in op.input_names:
             return True
diff --git a/python/paddle/fluid/distributed_spliter.py b/python/paddle/fluid/distributed_splitter.py
similarity index 78%
rename from python/paddle/fluid/distributed_spliter.py
rename to python/paddle/fluid/distributed_splitter.py
index d288b27ba00..060c1df8ad2 100644
--- a/python/paddle/fluid/distributed_spliter.py
+++ b/python/paddle/fluid/distributed_splitter.py
@@ -17,8 +17,10 @@ def hash_name(varlist, pserver_endpoints):
     """
     hash variable names to several endpoints.
 
-    :param varlist: a list of Variables
-    :return: a map of pserver endpoint -> varname
+    Args:
+        varlist(list): a list of Variables
+
+    Returns(dict): a map of pserver endpoint -> varname
     """
 
     def _hash_block(block_str, total):
@@ -34,9 +36,14 @@
 
 
 def round_robin(varlist, pserver_endpoints):
     """
-    distribute variables to several endpoints.
+    Distribute variables to several endpoints.
+    Args:
+        varlist(list): a list of variables
+        pserver_endpoints(list): a list of pserver endpoints
+
+    Returns(list[int]): the endpoint for each variable
     """
-    assert (len(varlist) > len(pserver_endpoints))
+    assert (len(varlist) >= len(pserver_endpoints))
     eplist = []
     pserver_idx = 0
-- 
GitLab
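Why the assertion changed from ">" to ">=": round-robin assignment cycles through the pserver endpoints, so it only needs at least as many variables as pservers to give every pserver work; when the two counts are equal, each endpoint receives exactly one variable, a legitimate case the old strict ">" wrongly rejected. The sketch below illustrates this. Only the docstring and the assert line are confirmed by the patch, so the loop body is an assumed reconstruction of the round-robin scheme, not necessarily the exact upstream code.

# Minimal sketch of round_robin under the fixed assertion. The loop body
# is assumed from the function's name and docstring; only the assert and
# the docstring appear in the diff above.
def round_robin(varlist, pserver_endpoints):
    """Assign each variable an endpoint, cycling through the pservers."""
    # ">=" admits len(varlist) == len(pserver_endpoints): each endpoint
    # then gets exactly one variable. The old ">" rejected that case.
    assert len(varlist) >= len(pserver_endpoints)
    eplist = []
    pserver_idx = 0
    for var in varlist:
        eplist.append(pserver_endpoints[pserver_idx])
        # Wrap around after handing a variable to the last endpoint.
        pserver_idx = (pserver_idx + 1) % len(pserver_endpoints)
    return eplist

if __name__ == "__main__":
    # Three variables across three pservers: passes only with ">=".
    print(round_robin(["w0", "w1", "w2"],
                      ["127.0.0.1:6174", "127.0.0.1:6175", "127.0.0.1:6176"]))
    # -> ['127.0.0.1:6174', '127.0.0.1:6175', '127.0.0.1:6176']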