Unverified commit 69712ef2, authored by Yancey, committed by GitHub

Merge pull request #8316 from Yancey1989/optimize_block

create optimize block in pserver program
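In short, this change makes the transpiler append the parameter-server optimize ops to a sub-block created inside the pserver program itself (`pserver_program.create_block(0)`) instead of building a separate `Program` and mirroring its variables, and it hands that block to `listen_and_serv` through the `OptimizeBlock` attribute. The enclosing program stays reachable as `optimize_block.program`, which is what the rewritten `_append_pserver_*` helpers rely on. Below is a minimal sketch of that pattern, assuming the `paddle.v2.fluid` API of this era; the `merged_grad` variable, its shape, and the trainer count are made up purely for illustration and are not part of this PR.

```python
# Sketch only: mirrors the block structure this PR adopts, not the transpiler itself.
import paddle.v2.fluid as fluid

pserver_program = fluid.Program()

# Sub-block whose parent is the global block (index 0); ops appended here
# become the per-iteration optimize logic on the parameter server.
optimize_block = pserver_program.create_block(0)

# The enclosing program is reachable from the block, so helpers that receive
# only the block can still create or look up variables in it.
assert optimize_block.program is pserver_program

# Hypothetical merged-gradient variable, just to have something to scale,
# mirroring the "scale" op in the diff (1/N averaging across N trainers).
merged_var = pserver_program.global_block().create_var(
    name="merged_grad", shape=[8, 8], dtype="float32")
optimize_block.append_op(
    type="scale",
    inputs={"X": merged_var},
    outputs={"Out": merged_var},
    attrs={"scale": 1.0 / 2})

# The server entry op then references this block directly via the
# "OptimizeBlock" attribute instead of the global block of a second Program.
# Not shown here: listen_and_serv is only available in builds with the
# distributed ops compiled in, and it takes further attributes that the
# truncated diff below does not include.
```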
@@ -347,7 +347,8 @@ class DistributeTranspiler:
                 j -= 1
             return False
 
-    def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
+    def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
+        program = optimize_block.program
         new_inputs = dict()
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
@@ -371,11 +372,11 @@ class DistributeTranspiler:
                 if self.trainers > 1:
                     vars2merge = self._create_var_for_trainers(
                         program.global_block(), grad_block, self.trainers)
-                    program.global_block().append_op(
+                    optimize_block.append_op(
                         type="sum",
                         inputs={"X": vars2merge},
                         outputs={"Out": merged_var})
-                    program.global_block().append_op(
+                    optimize_block.append_op(
                         type="scale",
                         inputs={"X": merged_var},
                         outputs={"Out": merged_var},
@@ -412,25 +413,18 @@ class DistributeTranspiler:
                 dtype=var.dtype,
                 shape=new_shape)
             new_inputs[key] = tmpvar
-            # create var in pserver program global block.
-            # TODO(typhoonzero): put blocks in one program to avoid create two
-            # variables.
-            pserver_program.global_block().create_var(
-                name=var.name,
-                persistable=var.persistable,
-                dtype=var.dtype,
-                shape=new_shape)
 
         # change output's ParamOut variable
         outputs = self._get_output_map_from_op(program.global_block(), opt_op)
         outputs["ParamOut"] = new_inputs["Param"]
-        program.global_block().append_op(
+        optimize_block.append_op(
            type=opt_op.type,
            inputs=new_inputs,
            outputs=outputs,
            attrs=opt_op.attrs)
 
-    def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
+    def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
+        program = optimize_block.program
         # Append the ops for parameters that do not need to be optimized/updated
         inputs = self._get_input_map_from_op(self.program.global_block().vars,
                                              opt_op)
@@ -440,23 +434,17 @@ class DistributeTranspiler:
             else:
                 varlist = [var]
             for var in varlist:
-                # TODO(typhoonzero): will remove below line later.
-                program.global_block().create_var(
-                    name=var.name,
-                    persistable=var.persistable,
-                    dtype=var.dtype,
-                    shape=var.shape)
-                if not pserver_program.global_block().vars.has_key(var.name):
-                    pserver_program.global_block().create_var(
+                if not program.global_block().vars.has_key(var.name):
+                    program.global_block().create_var(
                         name=var.name,
                         persistable=var.persistable,
                         dtype=var.dtype,
                         shape=var.shape)
 
         outputs = self._get_output_map_from_op(self.program.global_block().vars,
                                                opt_op)
-        program.global_block().append_op(
+        optimize_block.append_op(
            type=opt_op.type,
            inputs=inputs,
            outputs=outputs,
@@ -489,7 +477,7 @@ class DistributeTranspiler:
                 dtype=v.dtype,
                 shape=v.shape)
         # step6
-        optimize_sub_program = Program()
+        optimize_block = pserver_program.create_block(0)
         # Iterate through the ops and append ops as needed
         for idx, opt_op in enumerate(self.optimize_ops):
             is_op_on_pserver = self._is_op_on_pserver(endpoint,
@@ -497,18 +485,17 @@ class DistributeTranspiler:
             if not is_op_on_pserver:
                 continue
             if "Grad" in opt_op.desc.input_arg_names():
-                self._append_pserver_ops(optimize_sub_program, pserver_program,
-                                         opt_op, endpoint)
+                self._append_pserver_ops(optimize_block, opt_op, endpoint)
             else:
-                self._append_pserver_non_opt_ops(optimize_sub_program,
-                                                 pserver_program, opt_op)
+                self._append_pserver_non_opt_ops(optimize_block, opt_op)
+
         # Append the listen_and_serv op
         pserver_program.global_block().append_op(
             type="listen_and_serv",
             inputs={},
             outputs={},
             attrs={
-                "OptimizeBlock": optimize_sub_program.global_block(),
+                "OptimizeBlock": optimize_block,
                 "endpoint": endpoint,
                 "ParamList": [
                     p.name
...