Commit 34f28185 authored by: qiaolongfei

distribute transpiler support async config

Parent 42a15a43
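
This commit threads a `sync_mode` flag through `DistributeTranspiler` so the per-trainer gradient plumbing is only generated for synchronous training. A rough usage sketch follows; the `sync_mode` keyword is an assumed name taken from the attribute used in the diff, not confirmed API, so verify it against the actual `DistributeTranspiler.transpile()` signature before relying on it.

```python
# Illustrative sketch only: how a caller might request async updates from the
# transpiler after this change. The sync_mode keyword is an assumed name,
# inferred from self.sync_mode in the diff; check the real transpile() signature.
import paddle.fluid as fluid

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2,
    sync_mode=False)  # assumed flag: False selects the async update path
pserver_prog = t.get_pserver_program("127.0.0.1:6174")
```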
......@@ -358,7 +358,7 @@ class DistributeTranspiler:
type=v.type,
dtype=v.dtype,
shape=v.shape)
if self.trainer_num > 1:
if self.sync_mode and self.trainer_num > 1:
for trainer_id in xrange(self.trainer_num):
var = pserver_program.global_block().create_var(
name="%s.trainer_%d" % (orig_var_name, trainer_id),
......@@ -688,17 +688,6 @@ class DistributeTranspiler:
self.table_name)],
persistable=False)
# create grad vars in pserver program
table_grad_var = self.table_param_grad[1]
table_grad_list = [
pserver_program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, index, pserver_index),
type=table_grad_var.type,
shape=table_grad_var.shape,
dtype=table_grad_var.dtype) for index in range(self.trainer_num)
]
# create table optimize block in pserver program
table_opt_op = [
op for op in self.optimize_ops
......@@ -708,11 +697,24 @@ class DistributeTranspiler:
# only support sgd now
assert table_opt_op.type == "sgd"
# append sum op for table_grad_list
table_opt_block.append_op(
type="sum",
inputs={"X": table_grad_list},
outputs={"Out": [grad_var]})
if self.sync_mode:
# create grad vars in pserver program
table_grad_var = self.table_param_grad[1]
table_grad_list = [
pserver_program.global_block().create_var(
name="%s.trainer_%d.pserver_%d" %
(table_grad_var.name, index, pserver_index),
type=table_grad_var.type,
shape=table_grad_var.shape,
dtype=table_grad_var.dtype)
for index in range(self.trainer_num)
]
# append sum op for table_grad_list
table_opt_block.append_op(
type="sum",
inputs={"X": table_grad_list},
outputs={"Out": [grad_var]})
lr_var = pserver_program.global_block().vars[table_opt_op.input(
"LearningRate")[0]]
......@@ -751,7 +753,7 @@ class DistributeTranspiler:
for varname, splited in block_map.iteritems():
orig_var = program.global_block().var(varname)
if len(splited) == 1:
if add_trainer_suffix:
if self.sync_mode and add_trainer_suffix:
new_var_name = "%s.trainer_%d" % \
(orig_var.name, self.trainer_id)
program.global_block().rename_var(varname, new_var_name)
......@@ -775,7 +777,7 @@ class DistributeTranspiler:
if len(orig_shape) >= 2:
splited_shape.extend(orig_shape[1:])
new_var_name = ""
if add_trainer_suffix:
if self.sync_mode and add_trainer_suffix:
new_var_name = "%s.block%d.trainer_%d" % \
(varname, i, self.trainer_id)
else:
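
Both hunks above make the `.trainer_%d` suffix conditional on `sync_mode`: the suffix only matters when the pserver has to keep per-trainer copies of a split gradient apart before merging them. A small sketch of the resulting naming scheme (the helper is hypothetical):

```python
# Hypothetical helper summarizing the naming switched above; not transpiler code.
def split_grad_var_name(varname, block_idx, num_blocks, trainer_id,
                        sync_mode, add_trainer_suffix):
    if num_blocks == 1:
        if sync_mode and add_trainer_suffix:
            return "%s.trainer_%d" % (varname, trainer_id)
        return varname
    if sync_mode and add_trainer_suffix:
        return "%s.block%d.trainer_%d" % (varname, block_idx, trainer_id)
    return "%s.block%d" % (varname, block_idx)

print(split_grad_var_name("fc_0.w_0@GRAD", 1, 3, 0, True, True))
# fc_0.w_0@GRAD.block1.trainer_0
print(split_grad_var_name("fc_0.w_0@GRAD", 1, 3, 0, False, True))
# fc_0.w_0@GRAD.block1
```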
......@@ -907,7 +909,7 @@ class DistributeTranspiler:
pserver_block.vars[self._orig_varname(grad_block.name)]
grad_to_block_id.append(merged_var.name + ":" + str(
optimize_block.idx))
if self.trainer_num > 1:
if self.sync_mode and self.trainer_num > 1:
vars2merge = []
for i in xrange(self.trainer_num):
per_trainer_name = "%s.trainer_%d" % \
......@@ -925,6 +927,7 @@ class DistributeTranspiler:
inputs={"X": merged_var},
outputs={"Out": merged_var},
attrs={"scale": 1.0 / float(self.trainer_num)})
new_inputs[key] = merged_var
elif key == "Param":
# param is already created on global program
......
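
The last two hunks guard the pserver-side gradient merge: only in sync mode are the per-trainer copies summed and rescaled by `1.0 / trainer_num` before feeding the optimizer, while in async mode the gradient variable is used as received. A simplified numpy sketch of that merge, not the actual op construction:

```python
# Numpy sketch of the merge guarded above: sum + scale in sync mode,
# pass-through in async mode. Arrays stand in for pserver variables.
import numpy as np

def merge_grad(per_trainer_grads, sync_mode):
    if sync_mode and len(per_trainer_grads) > 1:
        merged = np.sum(per_trainer_grads, axis=0)             # "sum" op
        return merged * (1.0 / float(len(per_trainer_grads)))  # "scale" op
    # async: the single gradient that just arrived feeds the optimizer as-is
    return per_trainer_grads[0]

grads = [np.full((2, 2), 1.0), np.full((2, 2), 3.0)]
print(merge_grad(grads, sync_mode=True))   # averaged to 2.0
print(merge_grad(grads, sync_mode=False))  # first arrival, 1.0
```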