提交 560d960b 编写于 作者: X xuwei06

Fix a minor bug for distributed_spliter.round_robin

Also fixed typo and comments.
上级 b2a1c9e8
...@@ -17,7 +17,7 @@ import framework ...@@ -17,7 +17,7 @@ import framework
from framework import Program, default_main_program, default_startup_program, Parameter, Variable from framework import Program, default_main_program, default_startup_program, Parameter, Variable
import optimizer import optimizer
from layer_helper import LayerHelper from layer_helper import LayerHelper
from distributed_spliter import * import distributed_splitter as splitter
import math import math
from . import core from . import core
import debuger import debuger
...@@ -138,7 +138,7 @@ class DistributeTranspiler: ...@@ -138,7 +138,7 @@ class DistributeTranspiler:
program=None, program=None,
pservers="127.0.0.1:6174", pservers="127.0.0.1:6174",
trainers=1, trainers=1,
split_method=round_robin): split_method=splitter.round_robin):
""" """
Transpile the program to distributed data-parallelism programs. Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server The main_program will be transformed to use a remote parameter server
......
...@@ -17,8 +17,10 @@ def hash_name(varlist, pserver_endpoints): ...@@ -17,8 +17,10 @@ def hash_name(varlist, pserver_endpoints):
""" """
hash variable names to several endpoints. hash variable names to several endpoints.
:param varlist: a list of Variables Args:
:return: a map of pserver endpoint -> varname varlist(list): a list of Variables
Returns(dict): a map of pserver endpoint -> varname
""" """
def _hash_block(block_str, total): def _hash_block(block_str, total):
...@@ -34,9 +36,14 @@ def hash_name(varlist, pserver_endpoints): ...@@ -34,9 +36,14 @@ def hash_name(varlist, pserver_endpoints):
def round_robin(varlist, pserver_endpoints): def round_robin(varlist, pserver_endpoints):
""" """
distribute variables to several endpoints. Distribute variables to several endpoints.
Args:
varlist(list): a list of variables
pserver_endpoints(list): a list of pserver endpoints
Returns(list[int]): the endpoint for each variable
""" """
assert (len(varlist) > len(pserver_endpoints)) assert (len(varlist) >= len(pserver_endpoints))
eplist = [] eplist = []
pserver_idx = 0 pserver_idx = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册