Commit 49cce3fd authored by Qiao Longfei

fix dist sparse l2 decay

test=develop
Parent dc8eca82
@@ -235,7 +235,6 @@ class DistSeResneXt2x2(TestDistRunnerBase):
         bd = [step * e for e in epochs]
         base_lr = 0.1
-        lr = []
         lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
         optimizer = fluid.optimizer.Momentum(
......
@@ -744,12 +744,6 @@ class DistributeTranspiler(object):
             elif op not in lr_ops:
                 self._append_pserver_non_opt_ops(block, op)
 
-        def __op_have_grad_input__(op):
-            for varname in op.input_arg_names:
-                if varname.find("@GRAD") >= 0:
-                    return varname
-            return ""
-
         def __clone_lr_op_sub_block__(op, program, lr_block):
             if not op.has_attr('sub_block'):
                 return
@@ -800,7 +794,7 @@ class DistributeTranspiler(object):
         merged_var = None
         for _, op in enumerate(self.optimize_ops):
             # find the origin grad var before clipping/L2Decay,
-            # merged_var should be the input var name of L2Decaybuil
+            # merged_var should be the input var name of L2Decay
             grad_varname_for_block = op.attr(OP_ROLE_VAR_ATTR_NAME)[1]
             if op.attr(OP_ROLE_VAR_ATTR_NAME)[
                     0] == optimize_target_param_name:
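The comment fixed above documents how the transpiler recovers the pre-clipping/pre-L2Decay gradient name from the optimizer op's role-var attribute. A minimal sketch of that lookup, in plain Python rather than Paddle's API, assuming the attribute is exposed as (param_name, grad_name) pairs and using made-up names:

# Hypothetical helper, not part of DistributeTranspiler: pick the gradient
# name that belongs to the parameter currently being transpiled.
def find_origin_grad_name(role_var_pairs, target_param_name):
    # role_var_pairs: assumed list of (param_name, grad_name) tuples,
    # mirroring op.attr(OP_ROLE_VAR_ATTR_NAME)[0] / [1] in the hunk above.
    for param_name, grad_name in role_var_pairs:
        if param_name == target_param_name:
            return grad_name  # gradient as produced before clipping / L2 decay
    return None

# Example: two (param, grad) pairs, look up the second parameter's gradient.
pairs = [("fc_0.w_0", "fc_0.w_0@GRAD"), ("emb.w_0", "emb.w_0@GRAD")]
print(find_origin_grad_name(pairs, "emb.w_0"))  # -> "emb.w_0@GRAD"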
@@ -1278,9 +1272,8 @@ class DistributeTranspiler(object):
         # create table param and grad var in pserver program
         # create table optimize block in pserver program
         table_opt_op = [
-            op for op in self.optimize_ops
-            if 'Param' in op.input_names and op.input("Param")[0] ==
-            self.table_name
+            op for op in self.optimize_ops if 'Param' in op.input_names and
+            op.input("Param")[0] == self.table_name
         ][0]
         origin_param_var = self.origin_program.global_block().vars[
@@ -1676,7 +1669,16 @@ class DistributeTranspiler(object):
                 if self.config.enable_dc_asgd:
                     new_inputs[key] = dc
                 else:
-                    new_inputs[key] = merged_var
+                    # Note!! This is for l2decay on sparse gradient, because it will create a new tensor for
+                    # decayed gradient but not inplace modify the origin one
+                    origin_grad_name = opt_op.input(key)[0]
+                    if core.kNewGradSuffix(
+                    ) in origin_grad_name and pserver_block.has_var(
+                            origin_grad_name):
+                        new_grad = pserver_block.var(origin_grad_name)
+                        new_inputs[key] = new_grad
+                    else:
+                        new_inputs[key] = merged_var
             elif key == "Param":
                 param_block = _get_param_block(opt_op)
                 if not param_block:
......
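The hunk above is the heart of the fix: L2 decay applied to a sparse gradient writes the decayed result into a new variable (its name carries the suffix returned by core.kNewGradSuffix()) rather than modifying the original in place, so the pserver-side optimizer op should read that new variable when it exists in the pserver block and fall back to merged_var otherwise. Below is a hedged sketch of that selection rule, using a placeholder suffix string and a plain dict in place of core.kNewGradSuffix() and the pserver block's variable table; the names are illustrative, not Paddle's API:

# Placeholder for core.kNewGradSuffix(); the real suffix string is assumed here.
NEW_GRAD_SUFFIX = "@GRAD@MERGED"

def pick_grad_input(origin_grad_name, merged_var, pserver_vars):
    # Prefer the decayed gradient tensor that the regularizer created as a
    # new variable; otherwise keep feeding the merged gradient, as before the fix.
    if NEW_GRAD_SUFFIX in origin_grad_name and origin_grad_name in pserver_vars:
        return pserver_vars[origin_grad_name]
    return merged_var

# Example: the sparse embedding gradient got a new decayed variable,
# while the dense fc gradient did not.
pserver_vars = {"emb.w_0@GRAD@MERGED": "decayed-sparse-grad"}
print(pick_grad_input("emb.w_0@GRAD@MERGED", "merged-grad", pserver_vars))  # -> decayed-sparse-grad
print(pick_grad_input("fc_0.w_0@GRAD", "merged-grad", pserver_vars))        # -> merged-grad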