未验证 提交 0627ee83 编写于 作者: W Wu Yi 提交者: GitHub

Merge pull request #14314 from typhoonzero/fix_pserver_weight_decay_multi_inputs

Fix pserver weight decay for ops with multiple inputs
...@@ -1706,11 +1706,25 @@ to transpile() call.") ...@@ -1706,11 +1706,25 @@ to transpile() call.")
outputs=outputs, outputs=outputs,
attrs=opt_op.all_attrs()) attrs=opt_op.all_attrs())
def _is_splited_grad_var(self, var, var_dict): def _get_pserver_grad_param_var(self, var, var_dict):
"""
Return pserver side grad/param variable, return None
if the variable is not grad/param, e.g.
a@GRAD -> a@GRAD.block0
a@GRAD -> a@GRAD (a is not splited)
fc_0.w_0 -> fc_0.w_0.block_0
fc_0.w_0 -> fc_0.w_0 (weight is not splited)
_generated_var_123 -> None
"""
grad_block = None grad_block = None
for _, g in six.iteritems(var_dict): for _, g in six.iteritems(var_dict):
if self._orig_varname(g.name) == self._orig_varname(var.name): if self._orig_varname(g.name) == self._orig_varname(var.name):
# skip per trainer vars
if g.name.find(".trainer_") == -1: if g.name.find(".trainer_") == -1:
# only param or grads have splited blocks
if self._orig_varname(g.name) in self.grad_name_to_param_name or\
self._orig_varname(g.name) in self.param_name_to_grad_name:
grad_block = g grad_block = g
break break
return grad_block return grad_block
...@@ -1745,32 +1759,38 @@ to transpile() call.") ...@@ -1745,32 +1759,38 @@ to transpile() call.")
for key, varlist in six.iteritems(inputs): for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for i in range(len(varlist)):
# for ops like clipping and weight decay, get the splited var var = varlist[i]
# for ops like clipping and weight decay, get the splited var (xxx.block0)
# for inputs/outputs # for inputs/outputs
grad_block = self._is_splited_grad_var( grad_block = self._get_pserver_grad_param_var(
var, program.global_block().vars) var, program.global_block().vars)
if grad_block: if grad_block:
inputs[key] = grad_block varlist[i] = grad_block
elif var.name not in program.global_block().vars: elif var.name not in program.global_block().vars:
program.global_block().create_var( tmpvar = program.global_block()._clone_variable(var)
name=var.name, varlist[i] = tmpvar
persistable=var.persistable, else:
dtype=var.dtype, varlist[i] = program.global_block().vars[var.name]
shape=var.shape) inputs[key] = varlist
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op) self.origin_program.global_block().vars, opt_op)
for key, varlist in six.iteritems(outputs): for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list): if not isinstance(varlist, list):
varlist = [varlist] varlist = [varlist]
for var in varlist: for i in range(len(varlist)):
grad_block = self._is_splited_grad_var( var = varlist[i]
grad_block = self._get_pserver_grad_param_var(
var, program.global_block().vars) var, program.global_block().vars)
if grad_block: if grad_block:
outputs[key] = grad_block varlist[i] = grad_block
elif var.name not in program.global_block().vars: elif var.name not in program.global_block().vars:
program.global_block()._clone_variable(var) tmpvar = program.global_block()._clone_variable(var)
varlist[i] = tmpvar
else:
varlist[i] = program.global_block().vars[var.name]
outputs[key] = varlist
return optimize_block.append_op( return optimize_block.append_op(
type=opt_op.type, type=opt_op.type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册