Commit fff44af8 authored by minqiyang

Support simple optimizer

test=develop
Parent 68e9b841
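This change wires `_create_optimization_pass` into the imperative-mode branch of `Optimizer.minimize` and passes `stop_gradient=True` to every `append_op` call that emits an optimizer update op (sgd, momentum, adam, and so on), so those update ops are flagged to take no part in gradient computation. As a rough, self-contained sketch of the idea (the `Op`, `Block`, and `backward_ops` names below are illustrative stand-ins, not Paddle APIs):

class Op:
    def __init__(self, op_type, stop_gradient=False):
        self.type = op_type
        self.stop_gradient = stop_gradient


class Block:
    def __init__(self):
        self.ops = []

    def append_op(self, op_type, stop_gradient=False):
        # The caller decides whether the op should take part in gradient
        # computation, mirroring the stop_gradient=True flag added in this diff.
        self.ops.append(Op(op_type, stop_gradient))


def backward_ops(block):
    # Only ops that participate in gradient computation get a grad op.
    return [op.type + "_grad" for op in reversed(block.ops)
            if not op.stop_gradient]


block = Block()
block.append_op("mul")                      # forward op: gets a grad op
block.append_op("elementwise_add")          # forward op: gets a grad op
block.append_op("sgd", stop_gradient=True)  # parameter update: skipped

print(backward_ops(block))  # ['elementwise_add_grad', 'mul_grad']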
@@ -321,6 +321,9 @@ class Optimizer(object):
                 import sys
                 sys.stdout.flush()
                 params_grads.append((param, grad_var))
+            optimize_ops = self._create_optimization_pass(params_grads, loss,
+                                                          startup_program)
         else:
             params_grads = append_backward(loss, parameter_list, no_grad_set,
                                            [error_clip_callback])
@@ -336,11 +339,12 @@ class Optimizer(object):
             params_grads = append_regularization_ops(params_grads,
                                                      self.regularization)
             optimize_ops = self._create_optimization_pass(params_grads, loss,
                                                           startup_program)
             if table_optimize_op is not None:
                 optimize_ops.append(table_optimize_op)
                 params_grads.append(table_param_and_grad)
         return optimize_ops, params_grads
@@ -389,7 +393,8 @@ class SGDOptimizer(Optimizer):
                 "Grad": param_and_grad[1],
                 "LearningRate": self._create_param_lr(param_and_grad)
             },
-            outputs={"ParamOut": param_and_grad[0]})
+            outputs={"ParamOut": param_and_grad[0]},
+            stop_gradient=True)
         return sgd_op
@@ -473,7 +478,8 @@ class MomentumOptimizer(Optimizer):
                 "VelocityOut": velocity_acc
             },
             attrs={"mu": self._momentum,
-                   "use_nesterov": self._use_nesterov})
+                   "use_nesterov": self._use_nesterov},
+            stop_gradient=True)
         return momentum_op
@@ -558,7 +564,8 @@ class LarsMomentumOptimizer(Optimizer):
                 "mu": self._momentum,
                 "lars_coeff": self._lars_coeff,
                 "lars_weight_decay": self._lars_weight_decay
-            })
+            },
+            stop_gradient=True)
         return momentum_op
@@ -633,7 +640,8 @@ class AdagradOptimizer(Optimizer):
             },
             outputs={"ParamOut": param_and_grad[0],
                      "MomentOut": moment_acc},
-            attrs={"epsilon": self._epsilon})
+            attrs={"epsilon": self._epsilon},
+            stop_gradient=True)
         return adagrad_op
@@ -763,7 +771,8 @@ class AdamOptimizer(Optimizer):
                 "beta2": self._beta2,
                 "epsilon": self._epsilon,
                 "lazy_mode": self._lazy_mode
-            })
+            },
+            stop_gradient=True)
         return adam_op
@@ -785,13 +794,15 @@ class AdamOptimizer(Optimizer):
                     type="scale",
                     inputs={"X": beta1_pow_acc},
                     outputs={"Out": beta1_pow_acc},
-                    attrs={"scale": self._beta1})
+                    attrs={"scale": self._beta1},
+                    stop_gradient=True)
                 main_block.append_op(
                     type="scale",
                     inputs={"X": beta2_pow_acc},
                     outputs={"Out": beta2_pow_acc},
-                    attrs={"scale": self._beta2})
+                    attrs={"scale": self._beta2},
+                    stop_gradient=True)


 class AdamaxOptimizer(Optimizer):
@@ -902,7 +913,8 @@ class AdamaxOptimizer(Optimizer):
                 "beta1": self._beta1,
                 "beta2": self._beta2,
                 "epsilon": self._epsilon
-            })
+            },
+            stop_gradient=True)
         return adamax_op
@@ -922,7 +934,8 @@ class AdamaxOptimizer(Optimizer):
                     type="scale",
                     inputs={"X": beta1_pow_acc},
                     outputs={"Out": beta1_pow_acc},
-                    attrs={"scale": self._beta1})
+                    attrs={"scale": self._beta1},
+                    stop_gradient=True)


 class DecayedAdagradOptimizer(Optimizer):
@@ -1004,7 +1017,8 @@ class DecayedAdagradOptimizer(Optimizer):
             },
             outputs={"ParamOut": param_and_grad[0],
                      "MomentOut": moment_acc},
-            attrs={"epsilon": self._epsilon})
+            attrs={"epsilon": self._epsilon},
+            stop_gradient=True)
         return decayed_adagrad_op
@@ -1100,7 +1114,8 @@ class AdadeltaOptimizer(Optimizer):
                 "AvgSquaredUpdateOut": avg_squared_update_acc
             },
             attrs={"epsilon": self._epsilon,
-                   "rho": self._rho})
+                   "rho": self._rho},
+            stop_gradient=True)
         return adadelta_op
@@ -1249,7 +1264,8 @@ class RMSPropOptimizer(Optimizer):
                 "decay": self._rho,
                 "momentum": self._momentum,
                 "centered": self._centered
-            })
+            },
+            stop_gradient=True)
         return rmsprop_op
@@ -1370,7 +1386,8 @@ class FtrlOptimizer(Optimizer):
             },
             attrs={"l1": self._l1,
                    "l2": self._l1,
-                   "lr_power": self._lr_power})
+                   "lr_power": self._lr_power},
+            stop_gradient=True)
         return ftrl_op
@@ -1534,7 +1551,8 @@ class ModelAverage(Optimizer):
                 "average_window": self.average_window,
                 "min_average_window": self.min_average_window,
                 "max_average_window": self.max_average_window,
-            })
+            },
+            stop_gradient=True)

     @contextmanager
     def apply(self, executor, need_restore=True):
...