未验证 提交 0627ee83 编写于 作者: W Wu Yi 提交者: GitHub

Merge pull request #14314 from typhoonzero/fix_pserver_weight_decay_multi_inputs

fix pserver weight decay multi inputs
......@@ -1706,13 +1706,27 @@ to transpile() call.")
outputs=outputs,
attrs=opt_op.all_attrs())
def _is_splited_grad_var(self, var, var_dict):
def _get_pserver_grad_param_var(self, var, var_dict):
    """
    Return the pserver-side grad/param variable that corresponds to
    ``var``, or None if ``var`` is not a grad/param, e.g.

        a@GRAD -> a@GRAD.block0
        a@GRAD -> a@GRAD (a is not split)
        fc_0.w_0 -> fc_0.w_0.block_0
        fc_0.w_0 -> fc_0.w_0 (weight is not split)
        _generated_var_123 -> None

    Args:
        var: object with a ``name`` attribute; the variable to look up.
        var_dict: mapping whose values expose a ``name`` attribute; the
            candidate pserver-side variables. Only the values are used.

    Returns:
        The first value of ``var_dict`` (in iteration order) that
          * has the same original name as ``var``,
          * is not a per-trainer copy (its name has no ".trainer_"), and
          * whose original name is a key of ``self.grad_name_to_param_name``
            or ``self.param_name_to_grad_name``;
        None when no value qualifies.
    """
    if not var_dict:
        return None

    # Loop invariant, so resolve it once instead of once per candidate.
    # NOTE(review): assumes self._orig_varname is a side-effect-free
    # name mapping -- confirm against its definition.
    target_name = self._orig_varname(var.name)

    for candidate in var_dict.values():
        origin_name = self._orig_varname(candidate.name)
        if origin_name != target_name:
            continue
        # skip per trainer vars
        if ".trainer_" in candidate.name:
            continue
        # only params or grads have split blocks; any other variable
        # sharing the same original name must not be returned.
        if (origin_name in self.grad_name_to_param_name or
                origin_name in self.param_name_to_grad_name):
            return candidate
    return None
def _clone_lr_op(self, program, block, op):
......@@ -1745,32 +1759,38 @@ to transpile() call.")
for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
# for ops like clipping and weight decay, get the splited var
for i in range(len(varlist)):
var = varlist[i]
# for ops like clipping and weight decay, get the splited var (xxx.block0)
# for inputs/outputs
grad_block = self._is_splited_grad_var(
grad_block = self._get_pserver_grad_param_var(
var, program.global_block().vars)
if grad_block:
inputs[key] = grad_block
varlist[i] = grad_block
elif var.name not in program.global_block().vars:
program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
tmpvar = program.global_block()._clone_variable(var)
varlist[i] = tmpvar
else:
varlist[i] = program.global_block().vars[var.name]
inputs[key] = varlist
outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op)
for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
grad_block = self._is_splited_grad_var(
for i in range(len(varlist)):
var = varlist[i]
grad_block = self._get_pserver_grad_param_var(
var, program.global_block().vars)
if grad_block:
outputs[key] = grad_block
varlist[i] = grad_block
elif var.name not in program.global_block().vars:
program.global_block()._clone_variable(var)
tmpvar = program.global_block()._clone_variable(var)
varlist[i] = tmpvar
else:
varlist[i] = program.global_block().vars[var.name]
outputs[key] = varlist
return optimize_block.append_op(
type=opt_op.type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册