提交 560d960b 编写于 作者: X xuwei06

Fix a minor bug for distributed_spliter.round_robin

Also fixed typo and comments.
上级 b2a1c9e8
......@@ -17,7 +17,7 @@ import framework
from framework import Program, default_main_program, default_startup_program, Parameter, Variable
import optimizer
from layer_helper import LayerHelper
from distributed_spliter import *
import distributed_splitter as splitter
import math
from . import core
import debuger
......@@ -138,7 +138,7 @@ class DistributeTranspiler:
program=None,
pservers="127.0.0.1:6174",
trainers=1,
split_method=round_robin):
split_method=splitter.round_robin):
"""
Transpile the program to distributed data-parallelism programs.
The main_program will be transformed to use a remote parameter server
......
......@@ -17,8 +17,10 @@ def hash_name(varlist, pserver_endpoints):
"""
hash variable names to several endpoints.
:param varlist: a list of Variables
:return: a map of pserver endpoint -> varname
Args:
varlist(list): a list of Variables
Returns(dict): a map of pserver endpoint -> varname
"""
def _hash_block(block_str, total):
......@@ -34,9 +36,14 @@ def hash_name(varlist, pserver_endpoints):
def round_robin(varlist, pserver_endpoints):
"""
distribute variables to several endpoints.
Distribute variables to several endpoints.
Args:
varlist(list): a list of variables
pserver_endpoints(list): a list of pserver endpoints
Returns(list[int]): the endpoint for each variable
"""
assert (len(varlist) > len(pserver_endpoints))
assert (len(varlist) >= len(pserver_endpoints))
eplist = []
pserver_idx = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册