Commit 9cd59990 authored by minqiyang

Fix dist transpiler unordered dict issue

Parent be6ecec4
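
The underlying issue: before Python 3.7, the iteration order of a plain `dict` is an implementation detail, and with hash randomization it can differ between interpreter runs. The transpiler runs independently on every trainer and pserver, so iterating an unordered mapping could assign split variables to endpoints differently on each node. A minimal sketch of the invariant the `collections.OrderedDict` change restores (the keys below are illustrative):

```python
import collections

# grad_var_mapping is illustrative: the transpiler keeps several such
# name -> split-vars tables. With a plain dict on Python < 3.7 (and hash
# randomization), iteration order can vary between interpreter runs, so
# two nodes may disagree on which gradient block goes to which pserver.
grad_var_mapping = collections.OrderedDict()  # was: dict()
grad_var_mapping["fc_w@GRAD"] = ["fc_w@GRAD.block0", "fc_w@GRAD.block1"]
grad_var_mapping["fc_b@GRAD"] = ["fc_b@GRAD.block0"]

# OrderedDict iterates in insertion order, so every process that builds
# the table in the same order also walks it in the same order.
assert list(grad_var_mapping) == ["fc_w@GRAD", "fc_b@GRAD"]
```
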
@@ -46,6 +46,7 @@ class TranspilerTest(unittest.TestCase):
     def get_main_program(self):
         main = fluid.Program()
+        main.random_seed = 1
         with fluid.program_guard(main):
             self.net_conf()
         self.origin_prog = main.clone()
...

@@ -31,6 +31,7 @@ Steps to transpile pserver:
 import math
 import random
 import numpy as np
+import collections

 from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
 from .. import core, framework

@@ -218,8 +219,9 @@ class DistributeTranspiler(object):
         # fc_b@GRAD_trainer_0, fc_b@GRAD_trainer_1 --> pserver2
         # shuffle the map will avoid the uneven distribution above
         grad_var_mapping_items = list(self.grad_var_mapping.items())
+
         if not self.config.slice_var_up:
-            random.seed(self.trainer_num)
+            random.seed(self.origin_program.random_seed)
             random.shuffle(grad_var_mapping_items)

         for orig_varname, splited_vars in grad_var_mapping_items:

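With the mapping insertion-ordered, seeding from the program's `random_seed` (instead of `trainer_num`) makes the shuffle reproducible and lets tests pin it, as the `main.random_seed = 1` change above does. A small sketch of the invariant, with made-up item names:

```python
import random

# Hypothetical items standing in for grad_var_mapping entries.
items = ["fc_w@GRAD", "fc_b@GRAD", "emb@GRAD"]

def assignment(program_seed):
    # Each node shuffles its own copy; a shared seed plus a shared
    # starting order yields the same permutation on every node.
    copy = list(items)
    random.seed(program_seed)
    random.shuffle(copy)
    return copy

assert assignment(1) == assignment(1)  # same program seed -> same order
```
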
@@ -557,14 +559,14 @@ class DistributeTranspiler(object):
         # 1. create vars in pserver program to startup program
         pserver_vars = pserver_program.global_block().vars
-        created_var_map = dict()
+        created_var_map = collections.OrderedDict()
         for _, var in list(pserver_vars.items()):
             tmpvar = s_prog.global_block()._clone_variable(var)
             created_var_map[var.name] = tmpvar

         # 2. rename op outputs
         for op in orig_s_prog.global_block().ops:
-            new_outputs = dict()
+            new_outputs = collections.OrderedDict()
             # do not append startup op if var is not on this pserver
             op_on_pserver = False
             for key in op.output_names:

@@ -703,7 +705,7 @@ class DistributeTranspiler(object):
             self.origin_program,
             grad_blocks,
             add_trainer_suffix=self.trainer_num > 1)
-        self.grad_param_mapping = dict()
+        self.grad_param_mapping = collections.OrderedDict()
         for g, p in zip(grad_blocks, param_blocks):
             g_name, g_bid, _ = g.split(":")
             p_name, p_bid, _ = p.split(":")

@@ -711,7 +713,7 @@ class DistributeTranspiler(object):
                 self.param_var_mapping[p_name][int(p_bid)]

         # create mapping of endpoint -> split var to create pserver side program
-        self.param_grad_ep_mapping = dict()
+        self.param_grad_ep_mapping = collections.OrderedDict()
         [
             self.param_grad_ep_mapping.update({
                 ep: {

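For context, this mapping pairs each pserver endpoint with the parameter and gradient splits it will own, so its iteration order decides program layout on every node. A hedged sketch of its shape; the endpoint values are examples and the `params`/`grads` field names are assumed from the surrounding code:

```python
import collections

# Example endpoints; the real list comes from the transpile() arguments.
endpoints = ["127.0.0.1:6174", "127.0.0.1:6175"]

# endpoint -> the parameter and gradient splits that pserver will own.
# Insertion order now determines iteration order on every process.
param_grad_ep_mapping = collections.OrderedDict()
for ep in endpoints:
    param_grad_ep_mapping[ep] = {"params": [], "grads": []}
```
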
@@ -981,14 +983,14 @@ class DistributeTranspiler(object):
             block_list (list[(varname, block_id, block_size)]): List of gradient blocks.
             add_trainer_suffix (Bool): Add trainer suffix to new variable's name if set True.
         Returns:
-            var_mapping (dict(varname->[new_varname_variable])): A dict mapping
+            var_mapping (collections.OrderedDict(varname->[new_varname_variable])): A dict mapping
                 from original var name to each var split.
         """
         # varname->[(block_id, current_block_size)]
-        block_map = dict()
-        var_mapping = dict()
+        block_map = collections.OrderedDict()
+        var_mapping = collections.OrderedDict()
         for block_str in block_list:
             varname, offset, size = block_str.split(":")
             if varname not in block_map:

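The `block_list` entries parsed here are `"varname:offset:size"` strings, and `block_map` groups them per variable before the splits are created. A minimal sketch of that grouping, with made-up block strings:

```python
import collections

# Made-up "varname:offset:size" block strings of the kind parsed above.
block_list = ["fc_w:0:1000", "fc_w:1000:1000", "fc_b:0:10"]

# varname -> [(offset, size)] in first-seen order; the OrderedDict keeps
# every node enumerating variables (and naming splits) identically.
block_map = collections.OrderedDict()
for block_str in block_list:
    varname, offset, size = block_str.split(":")
    if varname not in block_map:
        block_map[varname] = []
    block_map[varname].append((int(offset), int(size)))

assert list(block_map) == ["fc_w", "fc_b"]
```
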
@@ -1181,7 +1183,7 @@ class DistributeTranspiler(object):
                                grad_to_block_id, origin_program, merged_var):
         program = optimize_block.program
         pserver_block = program.global_block()
-        new_inputs = dict()
+        new_inputs = collections.OrderedDict()
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
         for key in opt_op.input_names:

@@ -1359,7 +1361,7 @@ class DistributeTranspiler(object):
     def _get_input_map_from_op(self, varmap, op):
         """Returns a dict from op input name to the vars in varmap."""
-        iomap = dict()
+        iomap = collections.OrderedDict()
         for key in op.input_names:
             vars = []
             for varname in op.input(key):

@@ -1372,7 +1374,7 @@ class DistributeTranspiler(object):
     def _get_output_map_from_op(self, varmap, op):
         """Returns a dict from op output name to the vars in varmap."""
-        iomap = dict()
+        iomap = collections.OrderedDict()
         for key in op.output_names:
             vars = []
             for varname in op.output(key):
...