未验证 提交 69712ef2 编写于 作者: Y Yancey 提交者: GitHub

Merge pull request #8316 from Yancey1989/optimize_block

create optimize block in pserver program
......@@ -347,7 +347,8 @@ class DistributeTranspiler:
j -= 1
return False
def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
program = optimize_block.program
new_inputs = dict()
# update param/grad shape first, then other inputs like
# moment can use the updated shape
......@@ -371,11 +372,11 @@ class DistributeTranspiler:
if self.trainers > 1:
vars2merge = self._create_var_for_trainers(
program.global_block(), grad_block, self.trainers)
program.global_block().append_op(
optimize_block.append_op(
type="sum",
inputs={"X": vars2merge},
outputs={"Out": merged_var})
program.global_block().append_op(
optimize_block.append_op(
type="scale",
inputs={"X": merged_var},
outputs={"Out": merged_var},
......@@ -412,25 +413,18 @@ class DistributeTranspiler:
dtype=var.dtype,
shape=new_shape)
new_inputs[key] = tmpvar
# create var in pserver program global block.
# TODO(typhoonzero): put blocks in one program to avoid create two
# variables.
pserver_program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=new_shape)
# change output's ParamOut variable
outputs = self._get_output_map_from_op(program.global_block(), opt_op)
outputs["ParamOut"] = new_inputs["Param"]
program.global_block().append_op(
optimize_block.append_op(
type=opt_op.type,
inputs=new_inputs,
outputs=outputs,
attrs=opt_op.attrs)
def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
program = optimize_block.program
# Append the ops for parameters that do not need to be optimized/updated
inputs = self._get_input_map_from_op(self.program.global_block().vars,
opt_op)
......@@ -440,23 +434,17 @@ class DistributeTranspiler:
else:
varlist = [var]
for var in varlist:
# TODO(typhoonzero): will remove below line later.
if not program.global_block().vars.has_key(var.name):
program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
if not pserver_program.global_block().vars.has_key(var.name):
pserver_program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
outputs = self._get_output_map_from_op(self.program.global_block().vars,
opt_op)
program.global_block().append_op(
optimize_block.append_op(
type=opt_op.type,
inputs=inputs,
outputs=outputs,
......@@ -489,7 +477,7 @@ class DistributeTranspiler:
dtype=v.dtype,
shape=v.shape)
# step6
optimize_sub_program = Program()
optimize_block = pserver_program.create_block(0)
# Iterate through the ops and append ops as needed
for idx, opt_op in enumerate(self.optimize_ops):
is_op_on_pserver = self._is_op_on_pserver(endpoint,
......@@ -497,18 +485,17 @@ class DistributeTranspiler:
if not is_op_on_pserver:
continue
if "Grad" in opt_op.desc.input_arg_names():
self._append_pserver_ops(optimize_sub_program, pserver_program,
opt_op, endpoint)
self._append_pserver_ops(optimize_block, opt_op, endpoint)
else:
self._append_pserver_non_opt_ops(optimize_sub_program,
pserver_program, opt_op)
self._append_pserver_non_opt_ops(optimize_block, opt_op)
# Append the listen_and_serv op
pserver_program.global_block().append_op(
type="listen_and_serv",
inputs={},
outputs={},
attrs={
"OptimizeBlock": optimize_sub_program.global_block(),
"OptimizeBlock": optimize_block,
"endpoint": endpoint,
"ParamList": [
p.name
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册