diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index 82018132cc8b8600958e5cd52df5844e3d37638e..f4d68a798efa26d43702aa1c555f6046f0e6a6a5 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -207,6 +207,7 @@ def load_dygraph(model_path, keep_name_table=False): # NOTE: `jit.save` doesn't save optimizer state else: # Load state dict by `save_dygraph` save format + para_dict = {} if os.path.exists(params_file_path): with open(params_file_path, 'rb') as f: para_dict = pickle.load(f) if six.PY2 else pickle.load( diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index 6b94eab92c321c5eb451010bb3243eba11caae28..14e83fccd655527d8f3012365e4757d23236a445 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -504,6 +504,19 @@ class TestAdamOpV2(unittest.TestCase): shape=[1], value=lr, dtype='float32') adam.set_lr(lr_var) + def test_adam_op_invalid_input(self): + paddle.disable_static() + linear = paddle.nn.Linear(10, 10) + with self.assertRaises(ValueError): + adam = paddle.optimizer.Adam( + 0.1, beta1=-1, parameters=linear.parameters()) + with self.assertRaises(ValueError): + adam = paddle.optimizer.Adam( + 0.1, beta2=-1, parameters=linear.parameters()) + with self.assertRaises(ValueError): + adam = paddle.optimizer.Adam( + 0.1, epsilon=-1, parameters=linear.parameters()) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_adamax_op.py b/python/paddle/fluid/tests/unittests/test_adamax_op.py index a6d1be7616c73019cd8f66dcf0c108cd58ec600b..8ce7656acfae77987b284e29cd85b35d264b20e2 100644 --- a/python/paddle/fluid/tests/unittests/test_adamax_op.py +++ b/python/paddle/fluid/tests/unittests/test_adamax_op.py @@ -184,5 +184,21 @@ def adamax_step(inputs, attributes): return param_out, moment_out, inf_norm_out +class TestAdamaxOpV2(unittest.TestCase): + def test_adamax_op_invalid_input(self): + import paddle + paddle.disable_static() + linear = paddle.nn.Linear(10, 10) + with self.assertRaises(ValueError): + adam = paddle.optimizer.Adamax( + 0.1, beta1=-1, parameters=linear.parameters()) + with self.assertRaises(ValueError): + adam = paddle.optimizer.Adamax( + 0.1, beta2=-1, parameters=linear.parameters()) + with self.assertRaises(ValueError): + adam = paddle.optimizer.Adamax( + 0.1, epsilon=-1, parameters=linear.parameters()) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_adamw_op.py b/python/paddle/fluid/tests/unittests/test_adamw_op.py index 0a7cf54e2e0f15e51ba1b6f7526837f53c7cc2e0..cce24b57d2ca50e96e3ae0cf6d8912a8aea79a31 100644 --- a/python/paddle/fluid/tests/unittests/test_adamw_op.py +++ b/python/paddle/fluid/tests/unittests/test_adamw_op.py @@ -76,6 +76,19 @@ class TestAdamWOp(unittest.TestCase): rets = exe.run(train_prog, feed={"data": data_np}, fetch_list=[loss]) assert rets[0] is not None + def test_adamw_op_invalid_input(self): + paddle.disable_static() + linear = paddle.nn.Linear(10, 10) + with self.assertRaises(ValueError): + adam = paddle.optimizer.AdamW( + 0.1, beta1=-1, parameters=linear.parameters()) + with self.assertRaises(ValueError): + adam = paddle.optimizer.AdamW( + 0.1, beta2=-1, parameters=linear.parameters()) + with self.assertRaises(ValueError): + adam = paddle.optimizer.AdamW( + 0.1, epsilon=-1, parameters=linear.parameters()) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py b/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py index 619e9e8e90783365b5f0d718783a14468520c8d4..887e50f07c55cc991d7816609253039ce0d48d7d 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_optimizer_v2.py @@ -401,9 +401,7 @@ class TestOptimizerLearningRate(unittest.TestCase): a = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32") linear = fluid.dygraph.nn.Linear(10, 10) - a = fluid.dygraph.to_variable(a) - b = linear(a) loss = fluid.layers.reduce_mean(b) diff --git a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py index f7b9d4214d36a422a3ec94dc410e58c6c827ef4c..ddac7f6b98b19d204d20ccdff75c6d4fcae50d4d 100644 --- a/python/paddle/fluid/tests/unittests/test_rmsprop_op.py +++ b/python/paddle/fluid/tests/unittests/test_rmsprop_op.py @@ -276,6 +276,19 @@ class TestRMSPropV2(unittest.TestCase): learning_rate=0.1, momentum=None) + def test_rmsprop_op_invalid_input(self): + paddle.disable_static() + linear = paddle.nn.Linear(10, 10) + with self.assertRaises(ValueError): + adam = paddle.optimizer.RMSProp( + 0.1, epsilon=-1, parameters=linear.parameters()) + with self.assertRaises(ValueError): + adam = paddle.optimizer.RMSProp( + 0.1, momentum=-1, parameters=linear.parameters()) + with self.assertRaises(ValueError): + adam = paddle.optimizer.RMSProp( + 0.1, rho=-1, parameters=linear.parameters()) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 0da8053fe8a3495f5d3188a737638531347de648..3150b8c2d0363274dfb6fd3465110c89339cd4c9 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -45,8 +45,8 @@ class Adam(Optimizer): Related paper: `Adam: A Method for Stochastic Optimization `_ Args: - learning_rate (float|LearningRateDecay, optional): The learning rate used to update ``Parameter``. - It can be a float value or a LearningRateDecay. The default value is 0.001. + learning_rate (float|_LRScheduler, optional): The learning rate used to update ``Parameter``. + It can be a float value or a _LRScheduler. The default value is 0.001. beta1 (float|Tensor, optional): The exponential decay rate for the 1st moment estimates. It should be a float number or a Tensor with shape [1] and data type as float32. The default value is 0.9. @@ -55,7 +55,7 @@ class Adam(Optimizer): The default value is 0.999. epsilon (float, optional): A small float value for numerical stability. The default value is 1e-08. - parameters (list, optional): List of ``Tensor`` names to update to minimize ``loss``. \ + parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \ This parameter is required in dygraph mode. \ The default value is None in static mode, at this time all parameters will be updated. weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ @@ -143,6 +143,12 @@ class Adam(Optimizer): assert beta1 is not None assert beta2 is not None assert epsilon is not None + if not 0 <= beta1 < 1: + raise ValueError("Invaild value of beta1, expect beta1 in [0,1).") + if not 0 <= beta2 < 1: + raise ValueError("Invaild value of beta2, expect beta2 in [0,1).") + if not 0 <= epsilon: + raise ValueError("Invaild value of epsilon, expect epsilon >= 0.") super(Adam, self).__init__( learning_rate=learning_rate, parameters=parameters, diff --git a/python/paddle/optimizer/adamax.py b/python/paddle/optimizer/adamax.py index 73a78b17cbba55c1ee90a2708f6c163940158a51..cca120efd450768520d9cf027f6a36aaad121d9e 100644 --- a/python/paddle/optimizer/adamax.py +++ b/python/paddle/optimizer/adamax.py @@ -47,15 +47,15 @@ class Adamax(Optimizer): it is added here for numerical stability to prevent the division by 0 error. Args: - learning_rate (float|LearningRateDecay, optional): The learning rate used to update ``Parameter``. - It can be a float value or a LearningRateDecay. The default value is 0.001. + learning_rate (float|_LRScheduler, optional): The learning rate used to update ``Parameter``. + It can be a float value or a _LRScheduler. The default value is 0.001. beta1 (float, optional): The exponential decay rate for the 1st moment estimates. The default value is 0.9. beta2 (float, optional): The exponential decay rate for the 2nd moment estimates. The default value is 0.999. epsilon (float, optional): A small float value for numerical stability. The default value is 1e-08. - parameters (list, optional): List of ``Tensor`` names to update to minimize ``loss``. \ + parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \ This parameter is required in dygraph mode. \ The default value is None in static mode, at this time all parameters will be updated. weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ @@ -118,6 +118,12 @@ class Adamax(Optimizer): assert beta1 is not None assert beta2 is not None assert epsilon is not None + if not 0 <= beta1 < 1: + raise ValueError("Invaild value of beta1, expect beta1 in [0,1).") + if not 0 <= beta2 < 1: + raise ValueError("Invaild value of beta2, expect beta2 in [0,1).") + if not 0 <= epsilon: + raise ValueError("Invaild value of epsilon, expect epsilon >= 0.") super(Adamax, self).__init__( learning_rate=learning_rate, parameters=parameters, diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index f498fcbffa24ec188b57ceb2d3c6884fc1e135d2..edaca7e8301676c8734eb3e60924844bea0121d9 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -19,112 +19,7 @@ import paddle __all__ = ['AdamW'] -class DecoupledWeightDecay(object): - def __init__(self, coeff=0.0, apply_decay_param_fun=None, **kwargs): - if not isinstance(coeff, float) and \ - not isinstance(coeff, framework.Variable): - raise TypeError("coeff should be float or Tensor.") - self._params_name = set() - self._apply_decay_param_fun = apply_decay_param_fun - self._coeff = coeff - super(DecoupledWeightDecay, self).__init__(**kwargs) - - def _scale_parameters(self, params_and_grads): - """ - Adds weight decay ops. - scaled_parameter = parameter * coeff - - Args: - params_and_grads: A list of (parameters, gradients) pairs, - the parameters need to decay. - Raises: - Exception: The type of coeff and parameter is not consistent. - """ - if isinstance(self._coeff, float) and self._coeff == 0.0: - return - - scaled_params = [] - for param, grad in params_and_grads: - # If no gradient then we don't need to do anything - if grad is None: - continue - if self._apply_decay_param_fun is not None \ - and not self._apply_decay_param_fun(param.name): - continue - - if isinstance(self._coeff, float): - assert param.dtype is not paddle.fluid.core.VarDesc.VarType.FP32, \ - "the type of coeff(float) and parameter(%s) is not consistent."%(self._coeff.dtype) - else: - assert self._coeff.dtype == param.dtype, \ - "the type of coeff(%s) and parameter(%s) is not consistent."%(self._coeff.dtype, param.dtype) - - with param.block.program._optimized_guard( - [param, grad]), framework.name_scope('weight decay'): - assert param.name not in self._params_name - scaled_params.append((param, grad, param * self._coeff)) - self._params_name.add(param.name) - return scaled_params - - def backward(self, **kargs): - return super(DecoupledWeightDecay, self).backward(**kargs) - - def _apply_optimize(self, **kargs): - return super(DecoupledWeightDecay, self)._apply_optimize(**kargs) - - def minimize(self, - loss, - startup_program=None, - parameters=None, - no_grad_set=None): - params_grads = self.backward( - loss=loss, - startup_program=startup_program, - parameters=parameters, - no_grad_set=no_grad_set) - scaled_params = self._scale_parameters(params_grads) - for p_grad_sgrad in scaled_params: - param, grad, scaled_param = p_grad_sgrad - with param.block.program._optimized_guard( - [param, grad]), framework.name_scope('weight decay'): - updated_param = paddle.fluid.layers.elementwise_sub( - x=param, y=scaled_param) - paddle.fluid.layers.assign(input=updated_param, output=param) - - optimize_ops = self._apply_optimize( - loss=loss, - params_grads=params_grads, - startup_program=startup_program) - return optimize_ops, params_grads - - @framework.dygraph_only - def step(self): - parameter_list = self._parameter_list - self._dtype = None - params_grads = [] - for param in self._parameter_list: - if not param.trainable: - continue - if param._grad_ivar() is not None: - grad_var = param._grad_ivar() - params_grads.append((param, grad_var)) - - scaled_params = self._scale_parameters(params_grads) - for p_grad_sgrad in scaled_params: - param, grad, scaled_param = p_grad_sgrad - with param.block.program._optimized_guard( - [param, grad]), framework.name_scope('weight decay'): - updated_param = paddle.fluid.layers.elementwise_sub( - x=param, y=scaled_param) - paddle.fluid.layers.assign(input=updated_param, output=param) - optimize_ops = self._apply_optimize( - loss=None, startup_program=None, params_grads=params_grads) - - def __str__(self): - return " ".join(["Weight Decay, params:", ",".join(self._params_name)]) - - -class AdamW(DecoupledWeightDecay, Adam): +class AdamW(Adam): """ The AdamW optimizer is implemented based on the AdamW Optimization in paper `DECOUPLED WEIGHT DECAY REGULARIZATION `_. @@ -145,8 +40,8 @@ class AdamW(DecoupledWeightDecay, Adam): Args: - learning_rate (float|LearningRateDecay, optional): The learning rate used to update ``Parameter``. - It can be a float value or a LearningRateDecay. The default value is 0.001. + learning_rate (float|_LRScheduler, optional): The learning rate used to update ``Parameter``. + It can be a float value or a _LRScheduler. The default value is 0.001. parameters (list, optional): List of ``Tensor`` names to update to minimize ``loss``. \ This parameter is required in dygraph mode. \ The default value is None in static mode, at this time all parameters will be updated. @@ -157,9 +52,9 @@ class AdamW(DecoupledWeightDecay, Adam): It should be a float number or a Tensor with shape [1] and data type as float32. The default value is 0.999. epsilon (float, optional): A small float value for numerical stability. - weight_decay (float|Tensor): The weight decay coefficient, it can be float or Tensor. The default value is 0.0. The default value is 1e-08. - apply_decay_param_fun (function|None): If it is not None, + weight_decay (float|Tensor, optional): The weight decay coefficient, it can be float or Tensor. The default value is 0.01. + apply_decay_param_fun (function|None, optional): If it is not None, only tensors that makes apply_decay_param_fun(Tensor)==True will be updated. It only works when we want to specify tensors. Default: None. @@ -208,26 +103,129 @@ class AdamW(DecoupledWeightDecay, Adam): def __init__(self, learning_rate=0.001, - parameters=None, beta1=0.9, beta2=0.999, epsilon=1e-8, - weight_decay=0.0, + parameters=None, + weight_decay=0.01, apply_decay_param_fun=None, grad_clip=None, name=None, lazy_mode=False): - args_dict = { - "learning_rate": learning_rate, - "parameters": parameters, - "beta1": beta1, - "beta2": beta2, - "epsilon": epsilon, - "grad_clip": grad_clip, - "name": name, - "lazy_mode": lazy_mode - } + assert learning_rate is not None + assert beta1 is not None + assert beta2 is not None + assert epsilon is not None + if not 0 <= beta1 < 1: + raise ValueError("Invaild value of beta1, expect beta1 in [0,1).") + if not 0 <= beta2 < 1: + raise ValueError("Invaild value of beta2, expect beta2 in [0,1).") + if not 0 <= epsilon: + raise ValueError("Invaild value of epsilon, expect epsilon >= 0.") + coeff = weight_decay + if not isinstance(coeff, float) and \ + not isinstance(coeff, framework.Variable): + raise TypeError("coeff should be float or Tensor.") + self._params_name = set() + self._apply_decay_param_fun = apply_decay_param_fun + self._coeff = coeff super(AdamW, self).__init__( - weight_decay, - apply_decay_param_fun=apply_decay_param_fun, - **args_dict) + learning_rate=learning_rate, + parameters=parameters, + beta1=beta1, + beta2=beta2, + epsilon=epsilon, + grad_clip=grad_clip, + name=name, + lazy_mode=lazy_mode) + + def _scale_parameters(self, params_and_grads): + """ + Adds weight decay ops. + scaled_parameter = parameter * coeff + + Args: + params_and_grads: A list of (parameters, gradients) pairs, + the parameters need to decay. + Raises: + Exception: The type of coeff and parameter is not consistent. + """ + + scaled_params = [] + for param, grad in params_and_grads: + # If no gradient then we don't need to do anything + if grad is None: + continue + if self._apply_decay_param_fun is not None \ + and not self._apply_decay_param_fun(param.name): + continue + + if isinstance(self._coeff, float): + assert param.dtype is not paddle.fluid.core.VarDesc.VarType.FP32, \ + "the type of coeff(float) and parameter(%s) is not consistent."%(self._coeff.dtype) + else: + assert self._coeff.dtype == param.dtype, \ + "the type of coeff(%s) and parameter(%s) is not consistent."%(self._coeff.dtype, param.dtype) + if isinstance(self._learning_rate, float): + learning_rate = self._learning_rate + else: + self._learning_rate() + with param.block.program._optimized_guard( + [param, grad]), framework.name_scope('weight decay'): + if param.name not in self._params_name: + scaled_params.append( + (param, grad, param * self._coeff * learning_rate)) + self._params_name.add(param.name) + param = param * self._coeff + return scaled_params + + def minimize(self, + loss, + startup_program=None, + parameters=None, + no_grad_set=None): + params_grads = self.backward( + loss=loss, + startup_program=startup_program, + parameters=parameters, + no_grad_set=no_grad_set) + scaled_params = self._scale_parameters(params_grads) + for p_grad_sgrad in scaled_params: + param, grad, scaled_param = p_grad_sgrad + with param.block.program._optimized_guard( + [param, grad]), framework.name_scope('weight decay'): + updated_param = paddle.fluid.layers.elementwise_sub( + x=param, y=scaled_param) + paddle.fluid.layers.assign(input=updated_param, output=param) + + optimize_ops = self._apply_optimize( + loss=loss, + params_grads=params_grads, + startup_program=startup_program) + return optimize_ops, params_grads + + @framework.dygraph_only + def step(self): + parameter_list = self._parameter_list + self._dtype = None + params_grads = [] + for param in self._parameter_list: + if not param.trainable: + continue + if param._grad_ivar() is not None: + grad_var = param._grad_ivar() + params_grads.append((param, grad_var)) + + scaled_params = self._scale_parameters(params_grads) + for p_grad_sgrad in scaled_params: + param, grad, scaled_param = p_grad_sgrad + with param.block.program._optimized_guard( + [param, grad]), framework.name_scope('weight decay'): + updated_param = paddle.fluid.layers.elementwise_sub( + x=param, y=scaled_param) + param.set_value(updated_param.numpy()) + optimize_ops = self._apply_optimize( + loss=None, startup_program=None, params_grads=params_grads) + + def __str__(self): + return " ".join(["Weight Decay, params:", ",".join(self._params_name)]) diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py index 972e5a312ea7cf85694fddc027a556728652fcb0..1bd9a1f144ed4b5c69d76070eadc317e2063e25b 100644 --- a/python/paddle/optimizer/optimizer.py +++ b/python/paddle/optimizer/optimizer.py @@ -80,7 +80,6 @@ class Optimizer(object): .. code-block:: python #Take the subclass adam as an example - #Optimizer import paddle import numpy as np @@ -170,7 +169,7 @@ class Optimizer(object): import paddle paddle.disable_static() - emb = paddle.nn.Embedding([10, 10]) + emb = paddle.nn.Embedding(10, 10) adam = paddle.optimizer.Adam(0.001, parameters=emb.parameters()) state_dict = adam.state_dict() @@ -200,7 +199,7 @@ class Optimizer(object): import paddle paddle.disable_static() - emb = paddle.nn.Embedding([10, 10]) + emb = paddle.nn.Embedding(10, 10) state_dict = emb.state_dict() paddle.framework.save(state_dict, "paddle_dy") @@ -215,6 +214,8 @@ class Optimizer(object): adam.set_state_dict(opti_state_dict) ''' + if isinstance(self._learning_rate, _LRScheduler): + self._learning_rate.set_dict(state_dict["LR_Scheduler"]) if isinstance(self._learning_rate, _LRScheduler): self._learning_rate.set_state_dict(state_dict["LR_Scheduler"]) @@ -270,6 +271,7 @@ class Optimizer(object): main_prog = framework.default_main_program() main_prog.lr_sheduler = self._learning_rate main_prog.lr_var = lr_var + self._learning_rate_map[framework.default_main_program( )] = lr_var @@ -300,7 +302,7 @@ class Optimizer(object): this API cannot be invoked, because it will lead to conflict. Args: - value (float|Tensor): the value of learning rate + value (float): the value of learning rate Returns: None @@ -358,6 +360,7 @@ class Optimizer(object): Get current step learning rate. The return value is all the same When _LRScheduler is not used, otherwise return the current step learning rate. + Returns: float: The learning rate of the current step. @@ -368,7 +371,7 @@ class Optimizer(object): import paddle # example1: _LRScheduler is not used, return value is all the same paddle.disable_static() - emb = paddle.nn.Embedding([10, 10]) + emb = paddle.nn.Embedding(10, 10) adam = paddle.optimizer.Adam(0.001, parameters = emb.parameters()) lr = adam.get_lr() print(lr) # 0.001 @@ -655,7 +658,7 @@ class Optimizer(object): paddle.disable_static() value = np.arange(26).reshape(2, 13).astype("float32") a = paddle.to_tensor(value) - linear = paddle.nn.Linear(13, 5, dtype="float32") + linear = paddle.nn.Linear(13, 5) # This can be any optimizer supported by dygraph. adam = paddle.optimizer.Adam(learning_rate = 0.01, parameters = linear.parameters()) @@ -798,7 +801,7 @@ class Optimizer(object): paddle.disable_static() value = np.arange(26).reshape(2, 13).astype("float32") a = paddle.to_tensor(value) - linear = paddle.nn.Linear(13, 5, dtype="float32") + linear = paddle.nn.Linear(13, 5) # This can be any optimizer supported by dygraph. adam = paddle.optimizer.Adam(learning_rate = 0.01, parameters = linear.parameters()) @@ -836,36 +839,33 @@ class Optimizer(object): tuple: tuple (optimize_ops, params_grads), A list of operators appended by minimize and a list of (param, grad) tensor pairs, param is ``Parameter``, grad is the gradient value corresponding to the parameter. - The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to + In static graph mode, the returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to indicate program pruning. If so, the program will be pruned by ``feed`` and ``fetch_list`` before run, see details in ``Executor``. Examples: .. code-block:: python - + import paddle - import paddle.fluid as fluid - - place = fluid.CPUPlace() - main = fluid.Program() - with fluid.program_guard(main): - x = fluid.data(name='x', shape=[None, 13], dtype='float32') - y = fluid.data(name='y', shape=[None, 1], dtype='float32') - y_predict = fluid.layers.fc(input=x, size=1, act=None) - cost = fluid.layers.square_error_cost(input=y_predict, label=y) - avg_cost = fluid.layers.mean(cost) - - adam_optimizer = paddle.optimizer.Adam(0.01) - adam_optimizer.minimize(avg_cost) - - fetch_list = [avg_cost] - train_reader = paddle.batch( - paddle.dataset.uci_housing.train(), batch_size=1) - feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) - exe = fluid.Executor(place) - exe.run(fluid.default_startup_program()) - for data in train_reader(): - exe.run(main, feed=feeder.feed(data), fetch_list=fetch_list) + import numpy as np + + paddle.disable_static() + inp = np.random.uniform(-0.1, 0.1, [10, 10]).astype("float32") + linear = paddle.nn.Linear(10, 10) + inp = paddle.to_tensor(inp) + out = linear(inp) + loss = paddle.mean(out) + + beta1 = paddle.to_tensor([0.9], dtype="float32") + beta2 = paddle.to_tensor([0.99], dtype="float32") + + adam = paddle.optimizer.Adam(learning_rate=0.1, + parameters=linear.parameters(), + weight_decay=0.01) + out.backward() + adam.minimize(loss) + adam.clear_grad() + """ assert isinstance(loss, Variable), "The loss should be an Tensor." @@ -885,7 +885,7 @@ class Optimizer(object): @framework.dygraph_only def step(self): """ - Execute the optimizer once. + Execute the optimizer and update parameters once. Returns: None @@ -898,7 +898,7 @@ class Optimizer(object): paddle.disable_static() value = np.arange(26).reshape(2, 13).astype("float32") a = paddle.to_tensor(value) - linear = paddle.nn.Linear(13, 5, dtype="float32") + linear = paddle.nn.Linear(13, 5) # This can be any optimizer supported by dygraph. adam = paddle.optimizer.Adam(learning_rate = 0.01, parameters = linear.parameters()) diff --git a/python/paddle/optimizer/rmsprop.py b/python/paddle/optimizer/rmsprop.py index 0bc4c9bfd53dc15449f03d6de6c8942e977bf562..2609972d85ccdc2a867765431fefe21b9ba2de16 100644 --- a/python/paddle/optimizer/rmsprop.py +++ b/python/paddle/optimizer/rmsprop.py @@ -69,8 +69,8 @@ class RMSProp(Optimizer): Parameters: - learning_rate (float|LearningRateDecay): The learning rate used to update ``Parameter``. - It can be a float value or a LearningRateDecay. + learning_rate (float|_LRScheduler): The learning rate used to update ``Parameter``. + It can be a float value or a _LRScheduler. rho(float): rho is :math: `\\rho` in equation, default is 0.95. epsilon(float): :math: `\\epsilon` in equation is smoothing term to avoid division by zero, default is 1e-6. @@ -80,7 +80,7 @@ class RMSProp(Optimizer): the gradient; if False, by the uncentered second moment. Setting this to True may help with training, but is slightly more expensive in terms of computation and memory. Defaults to False. - parameters (list, optional): List of ``Tensor`` names to update to minimize ``loss``. \ + parameters (list, optional): List of ``Tensor`` to update to minimize ``loss``. \ This parameter is required in dygraph mode. \ The default value is None in static mode, at this time all parameters will be updated. weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization. \ @@ -147,6 +147,12 @@ class RMSProp(Optimizer): raise ValueError("epsilon is not set.") if momentum is None: raise ValueError("momentum is not set.") + if not 0.0 <= epsilon: + raise ValueError("Invalid value of epsilon, expect epsilon >= 0.") + if not 0.0 <= momentum: + raise ValueError("Invalid value of momentum, expect momentum >= 0.") + if not 0.0 <= rho: + raise ValueError("Invalid value of rho, expect rho >= 0.") super(RMSProp, self).__init__( learning_rate=learning_rate,