Commit fff44af8 authored by minqiyang

Support simple optimizer

test=develop
Parent 68e9b841
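The change repeated across every optimizer class below is mechanical: the `block.append_op(...)` call that creates the parameter-update op (and, for Adam/Adamax, the `scale` ops that advance the beta-power accumulators) now also passes `stop_gradient=True`, presumably so the newly supported simple/imperative path does not trace gradients through the update ops themselves. A minimal sketch of how the SGD hunk reads once the change is applied is shown below; the method body is assembled from the diff context, so details outside the shown lines may differ from the full file.

```python
def _append_optimize_op(self, block, param_and_grad):
    # Create the SGD update op; stop_gradient=True is the flag added by this
    # commit on every optimizer's update op.
    sgd_op = block.append_op(
        type=self.type,
        inputs={
            "Param": param_and_grad[0],
            "Grad": param_and_grad[1],
            "LearningRate": self._create_param_lr(param_and_grad)
        },
        outputs={"ParamOut": param_and_grad[0]},
        stop_gradient=True)

    return sgd_op
```

The same one-line addition is applied, hunk by hunk, to Momentum, LarsMomentum, Adagrad, Adam, Adamax, DecayedAdagrad, Adadelta, RMSProp, Ftrl, and ModelAverage.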
@@ -321,6 +321,9 @@ class Optimizer(object):
                 import sys
                 sys.stdout.flush()
                 params_grads.append((param, grad_var))
+
+            optimize_ops = self._create_optimization_pass(params_grads, loss,
+                                                          startup_program)
         else:
             params_grads = append_backward(loss, parameter_list, no_grad_set,
                                            [error_clip_callback])
@@ -336,11 +339,12 @@ class Optimizer(object):
             params_grads = append_regularization_ops(params_grads,
                                                      self.regularization)

-        optimize_ops = self._create_optimization_pass(params_grads, loss,
-                                                      startup_program)
-        if table_optimize_op is not None:
-            optimize_ops.append(table_optimize_op)
-            params_grads.append(table_param_and_grad)
+            optimize_ops = self._create_optimization_pass(params_grads, loss,
+                                                          startup_program)
+            if table_optimize_op is not None:
+                optimize_ops.append(table_optimize_op)
+                params_grads.append(table_param_and_grad)
+
         return optimize_ops, params_grads
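Taken together, the two hunks above restructure `Optimizer.minimize` so that each branch creates its own optimization pass: the new fast path builds `params_grads` directly from parameters and their gradient variables and calls `_create_optimization_pass` itself, while the existing `append_backward` / regularization / lookup-table path keeps its own call inside the `else:` branch. A rough sketch of the resulting control flow follows; the branch condition and the two helper names are placeholders (the diff does not show them), and the elided middle of the static-graph path is only indicated by a comment.

```python
# Sketch of Optimizer.minimize after this commit (method of class Optimizer).
def minimize(self, loss, startup_program=None, parameter_list=None,
             no_grad_set=None):
    if running_imperatively():                      # placeholder condition
        params_grads = []
        for param in trainable_parameters(loss):    # placeholder helper
            grad_var = grad_variable_of(param)      # placeholder helper
            params_grads.append((param, grad_var))

        optimize_ops = self._create_optimization_pass(params_grads, loss,
                                                      startup_program)
    else:
        params_grads = append_backward(loss, parameter_list, no_grad_set,
                                       [error_clip_callback])
        # ... lines elided in the diff produce table_param_and_grad and
        # table_optimize_op ...
        params_grads = append_regularization_ops(params_grads,
                                                 self.regularization)

        optimize_ops = self._create_optimization_pass(params_grads, loss,
                                                      startup_program)
        if table_optimize_op is not None:
            optimize_ops.append(table_optimize_op)
            params_grads.append(table_param_and_grad)

    return optimize_ops, params_grads
```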
@@ -389,7 +393,8 @@ class SGDOptimizer(Optimizer):
                 "Grad": param_and_grad[1],
                 "LearningRate": self._create_param_lr(param_and_grad)
             },
-            outputs={"ParamOut": param_and_grad[0]})
+            outputs={"ParamOut": param_and_grad[0]},
+            stop_gradient=True)

         return sgd_op
@@ -473,7 +478,8 @@ class MomentumOptimizer(Optimizer):
                 "VelocityOut": velocity_acc
             },
             attrs={"mu": self._momentum,
-                   "use_nesterov": self._use_nesterov})
+                   "use_nesterov": self._use_nesterov},
+            stop_gradient=True)

         return momentum_op
@@ -558,7 +564,8 @@ class LarsMomentumOptimizer(Optimizer):
                 "mu": self._momentum,
                 "lars_coeff": self._lars_coeff,
                 "lars_weight_decay": self._lars_weight_decay
-            })
+            },
+            stop_gradient=True)

         return momentum_op
@@ -633,7 +640,8 @@ class AdagradOptimizer(Optimizer):
             },
             outputs={"ParamOut": param_and_grad[0],
                      "MomentOut": moment_acc},
-            attrs={"epsilon": self._epsilon})
+            attrs={"epsilon": self._epsilon},
+            stop_gradient=True)

         return adagrad_op
@@ -763,7 +771,8 @@ class AdamOptimizer(Optimizer):
                 "beta2": self._beta2,
                 "epsilon": self._epsilon,
                 "lazy_mode": self._lazy_mode
-            })
+            },
+            stop_gradient=True)

         return adam_op
@@ -785,13 +794,15 @@ class AdamOptimizer(Optimizer):
                 type="scale",
                 inputs={"X": beta1_pow_acc},
                 outputs={"Out": beta1_pow_acc},
-                attrs={"scale": self._beta1})
+                attrs={"scale": self._beta1},
+                stop_gradient=True)

             main_block.append_op(
                 type="scale",
                 inputs={"X": beta2_pow_acc},
                 outputs={"Out": beta2_pow_acc},
-                attrs={"scale": self._beta2})
+                attrs={"scale": self._beta2},
+                stop_gradient=True)


 class AdamaxOptimizer(Optimizer):
@@ -902,7 +913,8 @@ class AdamaxOptimizer(Optimizer):
                 "beta1": self._beta1,
                 "beta2": self._beta2,
                 "epsilon": self._epsilon
-            })
+            },
+            stop_gradient=True)

         return adamax_op
@@ -922,7 +934,8 @@ class AdamaxOptimizer(Optimizer):
                 type="scale",
                 inputs={"X": beta1_pow_acc},
                 outputs={"Out": beta1_pow_acc},
-                attrs={"scale": self._beta1})
+                attrs={"scale": self._beta1},
+                stop_gradient=True)


 class DecayedAdagradOptimizer(Optimizer):
@@ -1004,7 +1017,8 @@ class DecayedAdagradOptimizer(Optimizer):
             },
             outputs={"ParamOut": param_and_grad[0],
                      "MomentOut": moment_acc},
-            attrs={"epsilon": self._epsilon})
+            attrs={"epsilon": self._epsilon},
+            stop_gradient=True)

         return decayed_adagrad_op
@@ -1100,7 +1114,8 @@ class AdadeltaOptimizer(Optimizer):
                 "AvgSquaredUpdateOut": avg_squared_update_acc
             },
             attrs={"epsilon": self._epsilon,
                    "rho": self._rho})
-            attrs={"epsilon": self._epsilon,
-                   "rho": self._rho})
+            attrs={"epsilon": self._epsilon,
+                   "rho": self._rho},
+            stop_gradient=True)

         return adadelta_op
@@ -1249,7 +1264,8 @@ class RMSPropOptimizer(Optimizer):
                 "decay": self._rho,
                 "momentum": self._momentum,
                 "centered": self._centered
-            })
+            },
+            stop_gradient=True)

         return rmsprop_op
@@ -1370,7 +1386,8 @@ class FtrlOptimizer(Optimizer):
             },
             attrs={"l1": self._l1,
                    "l2": self._l1,
-                   "lr_power": self._lr_power})
+                   "lr_power": self._lr_power},
+            stop_gradient=True)

         return ftrl_op
@@ -1534,7 +1551,8 @@ class ModelAverage(Optimizer):
                 "average_window": self.average_window,
                 "min_average_window": self.min_average_window,
                 "max_average_window": self.max_average_window,
-            })
+            },
+            stop_gradient=True)

     @contextmanager
     def apply(self, executor, need_restore=True):