未验证 提交 5e725dc5 编写于 作者: Y yuyang18

Hide Optimizer methods

上级 6c83dcd6
...@@ -29,7 +29,7 @@ __all__ = [ ...@@ -29,7 +29,7 @@ __all__ = [
'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl', 'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer', 'SGDOptimizer', 'MomentumOptimizer', 'AdagradOptimizer', 'AdamOptimizer',
'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer', 'AdamaxOptimizer', 'DecayedAdagradOptimizer', 'RMSPropOptimizer',
'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'Optimizer', 'RMSPropOptimizer' 'FtrlOptimizer', 'Adadelta', 'ModelAverage', 'RMSPropOptimizer'
] ]
...@@ -67,7 +67,7 @@ class Optimizer(object): ...@@ -67,7 +67,7 @@ class Optimizer(object):
self._LARS_weight_decay = LARS_weight_decay self._LARS_weight_decay = LARS_weight_decay
def _create_global_learning_rate(self): def _create_global_learning_rate(self):
lr = self.global_learning_rate() lr = self._global_learning_rate()
if isinstance(lr, framework.Variable): if isinstance(lr, framework.Variable):
return return
...@@ -86,7 +86,7 @@ class Optimizer(object): ...@@ -86,7 +86,7 @@ class Optimizer(object):
dtype='float32' if self._dtype == None else self._dtype, dtype='float32' if self._dtype == None else self._dtype,
persistable=True) persistable=True)
def global_learning_rate(self, program=None): def _global_learning_rate(self, program=None):
""" """
get global decayed learning rate get global decayed learning rate
:return: :return:
...@@ -110,9 +110,9 @@ class Optimizer(object): ...@@ -110,9 +110,9 @@ class Optimizer(object):
return param_lr return param_lr
else: else:
if param_lr == 1.0: if param_lr == 1.0:
return self.global_learning_rate() return self._global_learning_rate()
else: else:
return self.global_learning_rate() * param_lr return self._global_learning_rate() * param_lr
def _create_accumulators(self, block, parameters): def _create_accumulators(self, block, parameters):
"""Create all accumulators needed by the parameters """Create all accumulators needed by the parameters
...@@ -185,7 +185,7 @@ class Optimizer(object): ...@@ -185,7 +185,7 @@ class Optimizer(object):
format(name, param.name)) format(name, param.name))
return self._accumulators[name][param.name] return self._accumulators[name][param.name]
def create_optimization_pass(self, def _create_optimization_pass(self,
parameters_and_grads, parameters_and_grads,
loss, loss,
startup_program=None): startup_program=None):
...@@ -221,7 +221,7 @@ class Optimizer(object): ...@@ -221,7 +221,7 @@ class Optimizer(object):
self._create_global_learning_rate() self._create_global_learning_rate()
if self._LARS_weight_decay > 0.0: if self._LARS_weight_decay > 0.0:
layers.append_LARS(parameters_and_grads, layers.append_LARS(parameters_and_grads,
self.global_learning_rate(), self._global_learning_rate(),
self._LARS_weight_decay) self._LARS_weight_decay)
optimize_ops = [] optimize_ops = []
...@@ -262,7 +262,7 @@ class Optimizer(object): ...@@ -262,7 +262,7 @@ class Optimizer(object):
params_grads = append_regularization_ops(params_grads, params_grads = append_regularization_ops(params_grads,
self.regularization) self.regularization)
optimize_ops = self.create_optimization_pass(params_grads, loss, optimize_ops = self._create_optimization_pass(params_grads, loss,
startup_program) startup_program)
return optimize_ops, params_grads return optimize_ops, params_grads
......
...@@ -97,7 +97,7 @@ class TestMomentumOptimizer(unittest.TestCase): ...@@ -97,7 +97,7 @@ class TestMomentumOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out) params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1) self.assertEqual(len(params_grads), 1)
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0) self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
opts = momentum_optimizer.create_optimization_pass( opts = momentum_optimizer._create_optimization_pass(
params_grads, mul_out, init_program) params_grads, mul_out, init_program)
self.assertEqual(len(opts), 3) self.assertEqual(len(opts), 3)
sgd_op = opts[-1] sgd_op = opts[-1]
...@@ -151,7 +151,7 @@ class TestMomentumOptimizer(unittest.TestCase): ...@@ -151,7 +151,7 @@ class TestMomentumOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out) params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1) self.assertEqual(len(params_grads), 1)
self.assertEqual(len(momentum_optimizer.get_accumulators()), 0) self.assertEqual(len(momentum_optimizer.get_accumulators()), 0)
opts = momentum_optimizer.create_optimization_pass( opts = momentum_optimizer._create_optimization_pass(
params_grads, mul_out, init_program) params_grads, mul_out, init_program)
self.assertEqual(len(opts), 3) self.assertEqual(len(opts), 3)
sgd_op = opts[-1] sgd_op = opts[-1]
...@@ -214,8 +214,8 @@ class TestAdagradOptimizer(unittest.TestCase): ...@@ -214,8 +214,8 @@ class TestAdagradOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out) params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1) self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0) self.assertEqual(len(adagrad_optimizer.get_accumulators()), 0)
opts = adagrad_optimizer.create_optimization_pass(params_grads, mul_out, opts = adagrad_optimizer._create_optimization_pass(
init_program) params_grads, mul_out, init_program)
self.assertEqual(len(opts), 3) self.assertEqual(len(opts), 3)
self.assertEqual([op.type for op in opts], self.assertEqual([op.type for op in opts],
["fill_constant", "elementwise_mul", "adagrad"]) ["fill_constant", "elementwise_mul", "adagrad"])
...@@ -278,7 +278,7 @@ class TestAdamOptimizer(unittest.TestCase): ...@@ -278,7 +278,7 @@ class TestAdamOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out) params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1) self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adam_optimizer.get_accumulators()), 0) self.assertEqual(len(adam_optimizer.get_accumulators()), 0)
opts = adam_optimizer.create_optimization_pass(params_grads, mul_out, opts = adam_optimizer._create_optimization_pass(params_grads, mul_out,
init_program) init_program)
self.assertEqual(len(opts), 5) self.assertEqual(len(opts), 5)
self.assertEqual( self.assertEqual(
...@@ -345,7 +345,7 @@ class TestAdamaxOptimizer(unittest.TestCase): ...@@ -345,7 +345,7 @@ class TestAdamaxOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out) params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1) self.assertEqual(len(params_grads), 1)
self.assertEqual(len(adamax_optimizer.get_accumulators()), 0) self.assertEqual(len(adamax_optimizer.get_accumulators()), 0)
opts = adamax_optimizer.create_optimization_pass(params_grads, mul_out, opts = adamax_optimizer._create_optimization_pass(params_grads, mul_out,
init_program) init_program)
self.assertEqual(len(opts), 4) self.assertEqual(len(opts), 4)
self.assertEqual( self.assertEqual(
...@@ -409,7 +409,7 @@ class TestDecayedAdagradOptimizer(unittest.TestCase): ...@@ -409,7 +409,7 @@ class TestDecayedAdagradOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out) params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1) self.assertEqual(len(params_grads), 1)
self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0) self.assertEqual(len(decayed_adagrad_optimizer.get_accumulators()), 0)
opts = decayed_adagrad_optimizer.create_optimization_pass( opts = decayed_adagrad_optimizer._create_optimization_pass(
params_grads, mul_out, init_program) params_grads, mul_out, init_program)
self.assertEqual(len(opts), 3) self.assertEqual(len(opts), 3)
self.assertEqual( self.assertEqual(
...@@ -475,7 +475,7 @@ class TestFtrlOptimizer(unittest.TestCase): ...@@ -475,7 +475,7 @@ class TestFtrlOptimizer(unittest.TestCase):
params_grads = append_backward(mean_out) params_grads = append_backward(mean_out)
self.assertEqual(len(params_grads), 1) self.assertEqual(len(params_grads), 1)
self.assertEqual(len(ftrl_optimizer.get_accumulators()), 0) self.assertEqual(len(ftrl_optimizer.get_accumulators()), 0)
opts = ftrl_optimizer.create_optimization_pass(params_grads, mul_out, opts = ftrl_optimizer._create_optimization_pass(params_grads, mul_out,
init_program) init_program)
self.assertEqual(len(opts), 3) self.assertEqual(len(opts), 3)
self.assertEqual([op.type for op in opts], self.assertEqual([op.type for op in opts],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册