Unverified commit be801d6c, authored by Qiao Longfei, committed by GitHub

Add learning rate decay (#7892)

* add basic interface for learning rate decay
* add exponential_decay
* add natural_exp_decay
* add inverse_time_decay
Parent 80eff266
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/elementwise_pow_op.h"
#include "paddle/operators/elementwise_op.h"
namespace paddle {
namespace operators {
class ElementwisePowOpMaker : public ElementwiseOpMaker {
 public:
  ElementwisePowOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : ElementwiseOpMaker(proto, op_checker) {
    SetComment("Pow", "Out = X ^ Y");
    AddComment(comment_);
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(elementwise_pow, ops::ElementwiseOp,
                             ops::ElementwisePowOpMaker);
REGISTER_OP_CPU_KERNEL(
    elementwise_pow,
    ops::ElementwisePowKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ElementwisePowKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/elementwise_pow_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    elementwise_pow,
    ops::ElementwisePowKernel<paddle::platform::CUDADeviceContext, float>,
    ops::ElementwisePowKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include "paddle/operators/elementwise_op_function.h"
namespace paddle {
namespace operators {
template <typename T>
struct PowFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const { return std::pow(a, b); }
};

template <typename DeviceContext, typename T>
class ElementwisePowKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    ElementwiseComputeEx<PowFunctor<T>, DeviceContext, T>(ctx);
  }
};
} // namespace operators
} // namespace paddle
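The functor above simply forwards to `std::pow` element by element, with `ElementwiseComputeEx` handling shape matching. A rough NumPy sketch of the intended semantics (illustration only, not the kernel itself; it mirrors the unit test further below):

```python
import numpy as np

x = np.array([2.0, 3.0, 4.0], dtype=np.float32)
y = np.array([1.0, 2.0, 0.5], dtype=np.float32)
out = np.power(x, y)  # elementwise: [2.0, 9.0, 2.0], i.e. Out = X ^ Y
```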
@@ -26,6 +26,7 @@ import initializer
 import layers
 import nets
 import optimizer
+import learning_rate_decay
 import backward
 import regularizer
 from param_attr import ParamAttr
@@ -44,6 +45,7 @@ __all__ = framework.__all__ + executor.__all__ + [
     'layers',
     'nets',
     'optimizer',
+    'learning_rate_decay',
     'backward',
     'regularizer',
     'LoDTensor',
...
@@ -145,7 +145,9 @@ def monkey_patch_variable():
             # a*b == b*a. Do not need to reverse explicitly
             ("__rmul__", "elementwise_mul", False),
             ("__div__", "elementwise_div", False),
-            ("__rdiv__", "elementwise_div", True)):
+            ("__rdiv__", "elementwise_div", True),
+            ("__pow__", "elementwise_pow", False),
+            ("__rpow__", "elementwise_pow", True)):
         setattr(Variable, method_name,
                 _elemwise_method_creator_(method_name, op_type, reverse))
...
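With `__pow__` and `__rpow__` registered, Python's `**` operator on a fluid Variable is expected to lower to the new elementwise_pow op. A hedged sketch of what that enables (the variable names are illustrative, not taken from the commit):

```python
import paddle.v2.fluid as fluid

x = fluid.layers.data(name='x', shape=[4], dtype='float32')
y = x ** 2.0   # expected to dispatch to elementwise_pow via __pow__
z = 2.0 ** x   # expected to use __rpow__ (reversed operands)
```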
@@ -56,6 +56,7 @@ __all__ = [
     'elementwise_mul',
     'elementwise_max',
     'elementwise_min',
+    'elementwise_pow',
     'clip',
     'clip_by_norm',
     'sequence_softmax',
...
@@ -16,12 +16,14 @@ from ..layer_helper import LayerHelper
 from ..param_attr import ParamAttr
 from ..framework import convert_np_dtype_to_dtype_
 from ..framework import Variable
+from ..initializer import Constant
 from ..core import DataType
 import numpy

 __all__ = [
     'create_tensor',
     'create_parameter',
+    'create_global_var',
     'cast',
     'concat',
     'sums',
@@ -58,13 +60,22 @@ def create_parameter(shape,
     Returns:
         Parameter: the created parameter
     """
-    helper = LayerHelper("create_parameter")
+    helper = LayerHelper("create_parameter", **locals())
     if attr is None:
         attr = ParamAttr()
     return helper.create_parameter(attr, shape, dtype, is_bias,
                                    default_initializer)


+def create_global_var(shape, value, dtype, persistable=False, name=None):
+    helper = LayerHelper("global_var", **locals())
+    var = helper.create_global_variable(
+        dtype=dtype, shape=shape, persistable=persistable, name=name)
+    helper.set_variable_initializer(
+        var, initializer=Constant(value=float(value)))
+    return var
+
+
 def cast(x, dtype):
     """
     This function takes in the input with input_dtype
...
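The new `create_global_var` helper builds a persistable variable in the global block and initializes it with a Constant initializer; the learning-rate code below uses it for both the step counter and the materialized learning rate. A minimal usage sketch (illustrative only):

```python
import paddle.v2.fluid.layers as layers

# A single-element, persistable global variable initialized to 0.0,
# usable e.g. as a training step counter.
counter = layers.create_global_var(
    shape=[1], value=0.0, dtype='float32', persistable=True)
```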
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import layers
from framework import Variable
__all__ = ['exponential_decay', 'natural_exp_decay', 'inverse_time_decay']
"""
When training a model, it's often useful to decay the
learning rate during training process, this is called
learning_rate_decay. There are many strategies to do
this, this module will provide some classical method.
User can also implement their own learning_rate_decay
strategy according to this module.
"""
def exponential_decay(learning_rate,
                      global_step,
                      decay_steps,
                      decay_rate,
                      staircase=False):
    """Applies exponential decay to the learning rate.

    ```python
    decayed_learning_rate = learning_rate *
                            decay_rate ^ (global_step / decay_steps)
    ```

    Args:
        learning_rate: A scalar float32 value or a Variable. This
            will be the initial learning rate during training.
        global_step: A Variable that records the current training step.
        decay_steps: A Python `int32` number.
        decay_rate: A Python `float` number.
        staircase: Boolean. If set to True, decay the learning rate at
            discrete intervals, i.e. only every `decay_steps` steps.

    Returns:
        The decayed learning rate.
    """
    if not isinstance(global_step, Variable):
        raise ValueError("global_step is required for exponential_decay.")

    # update learning_rate
    div_res = global_step / decay_steps
    if staircase:
        div_res = layers.floor(x=div_res)

    return learning_rate * (decay_rate**div_res)
def natural_exp_decay(learning_rate,
                      global_step,
                      decay_steps,
                      decay_rate,
                      staircase=False):
    """Applies natural exponential decay to the initial learning rate.

    ```python
    if not staircase:
        decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
    else:
        decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))
    ```

    Args:
        learning_rate: A scalar float32 value or a Variable. This
            will be the initial learning rate during training.
        global_step: A Variable that records the current training step.
        decay_steps: A Python `int32` number.
        decay_rate: A Python `float` number.
        staircase: Boolean. If set to True, decay the learning rate at
            discrete intervals, i.e. only every `decay_steps` steps.

    Returns:
        The decayed learning rate.
    """
    if not isinstance(global_step, Variable):
        raise ValueError("global_step is required for natural_exp_decay.")

    div_res = global_step / decay_steps
    if staircase:
        div_res = layers.floor(x=div_res)

    return learning_rate * layers.exp(x=(-1 * decay_rate * div_res))
def inverse_time_decay(learning_rate,
                       global_step,
                       decay_steps,
                       decay_rate,
                       staircase=False):
    """Applies inverse time decay to the initial learning rate.

    ```python
    if staircase:
        decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_steps))
    else:
        decayed_learning_rate = learning_rate / (1 + decay_rate * (global_step / decay_steps))
    ```

    Args:
        learning_rate: A scalar float32 value or a Variable. This
            will be the initial learning rate during training.
        global_step: A Variable that records the current training step.
        decay_steps: A Python `int32` number.
        decay_rate: A Python `float` number.
        staircase: Boolean. If set to True, decay the learning rate at
            discrete intervals, i.e. only every `decay_steps` steps.

    Returns:
        The decayed learning rate.
    """
    if not isinstance(global_step, Variable):
        raise ValueError("global_step is required for inverse_time_decay.")

    div_res = global_step / decay_steps
    if staircase:
        div_res = layers.floor(x=div_res)

    return learning_rate / (1 + decay_rate * div_res)
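Since each decay function returns a Variable derived from `global_step`, a typical pattern is to create a step counter, build the decayed rate, and increment the counter once per iteration, as the unit test at the end of this commit also does. A hedged usage sketch (the parameter values are illustrative):

```python
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.learning_rate_decay as lr_decay

global_step = layers.create_global_var(
    shape=[1], value=0.0, dtype='float32', persistable=True)

# decayed_lr is itself a Variable; it is re-evaluated from the current
# value of global_step every time the main program runs.
decayed_lr = lr_decay.exponential_decay(
    learning_rate=0.1,
    global_step=global_step,
    decay_steps=1000,
    decay_rate=0.9,
    staircase=True)

layers.increment(global_step, 1.0)  # advance the counter once per iteration
```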
@@ -15,6 +15,7 @@
 from collections import defaultdict
 import framework
+import layers
 from backward import append_backward
 from framework import unique_name, program_guard
 from initializer import Constant
@@ -33,9 +34,11 @@ class Optimizer(object):
     but need to use one of its implementations.
     """

-    def __init__(self, global_step=None, regularization=None):
+    def __init__(self, learning_rate, global_step=None, regularization=None):
+        assert learning_rate is not None
         self._global_step = global_step
         self.regularization = regularization
+        self._global_learning_rate = learning_rate
         # Dictionary of accumulators. Some optimizer subclasses need to
         # allocate and manage extra variables associated with the parameters
         # to train. These variables are called accumulators.
@@ -43,6 +46,28 @@ class Optimizer(object):
         self._accumulators = defaultdict(lambda: dict())
         self.helper = None

+    def _create_global_learning_rate(self):
+        if isinstance(self._global_learning_rate, float):
+            self._global_learning_rate = layers.create_global_var(
+                name=unique_name("learning_rate"),
+                shape=[1],
+                value=float(self._global_learning_rate),
+                dtype='float32',
+                persistable=True)
+
+        if not isinstance(self._global_learning_rate, framework.Variable):
+            raise ValueError("learning rate should be a Variable, "
+                             "actual type is %s" %
+                             type(self._global_learning_rate))
+
+    @property
+    def global_learning_rate(self):
+        """
+        get the global decayed learning rate
+        :return: the global learning rate Variable
+        """
+        return self._global_learning_rate
+
     def _append_optimize_op(self, block, param_and_grad):
         """ append optimize operator to block and return all the added optimize_op
         """
@@ -52,17 +77,7 @@ class Optimizer(object):
         # create learning rate variable for every parameter
         param = param_and_grad[0]
         param_lr = param.optimize_attr['learning_rate']
-        param_lr_shape = [1]
-        param_lr_var = self.helper.create_global_variable(
-            name=unique_name("learning_rate"),
-            dtype='float32',
-            shape=param_lr_shape,
-            lod_level=1,
-            persistable=True)
-        param_lr = param_lr * self._learning_rate
-        self.helper.set_variable_initializer(
-            var=param_lr_var, initializer=Constant(param_lr))
-        return param_lr_var
+        return self._global_learning_rate * param_lr

     def _create_accumulators(self, block, parameters):
         """Create all accumulators needed by the parameters
@@ -163,7 +178,7 @@ class Optimizer(object):
           optimization. This will include parameter update ops, global step
           update ops and any other custom ops required by subclasses to manage
           their internal state.
        :param startup_program:
        """
        # This is a default implementation of create_optimization_pass that
        # can be shared by most optimizers. This implementation assumes that
@@ -178,6 +193,7 @@ class Optimizer(object):
         self.helper = LayerHelper(self.__class__.__name__)
         self._create_accumulators(loss.block,
                                   [p[0] for p in parameters_and_grads])
+        self._create_global_learning_rate()

         optimize_ops = []
         for param_and_grad in parameters_and_grads:
@@ -231,9 +247,9 @@ class SGDOptimizer(Optimizer):

     def __init__(self, learning_rate, **kwargs):
         assert learning_rate is not None
-        super(SGDOptimizer, self).__init__(**kwargs)
+        super(SGDOptimizer, self).__init__(
+            learning_rate=learning_rate, **kwargs)
         self.type = "sgd"
-        self._learning_rate = learning_rate

     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
@@ -259,9 +275,9 @@ class MomentumOptimizer(Optimizer):
     def __init__(self, learning_rate, momentum, use_nesterov=False, **kwargs):
         assert learning_rate is not None
         assert momentum is not None
-        super(MomentumOptimizer, self).__init__(**kwargs)
+        super(MomentumOptimizer, self).__init__(
+            learning_rate=learning_rate, **kwargs)
         self.type = "momentum"
-        self._learning_rate = learning_rate
         self._momentum = momentum
         self._use_nesterov = bool(use_nesterov)
@@ -303,9 +319,9 @@ class AdagradOptimizer(Optimizer):
     def __init__(self, learning_rate, epsilon=1.0e-6, **kwargs):
         assert learning_rate is not None
         assert epsilon is not None
-        super(AdagradOptimizer, self).__init__(**kwargs)
+        super(AdagradOptimizer, self).__init__(
+            learning_rate=learning_rate, **kwargs)
         self.type = "adagrad"
-        self._learning_rate = learning_rate
         self._epsilon = epsilon

     def _create_accumulators(self, block, parameters):
@@ -352,9 +368,9 @@ class AdamOptimizer(Optimizer):
         assert beta1 is not None
         assert beta2 is not None
         assert epsilon is not None
-        super(AdamOptimizer, self).__init__(**kwargs)
+        super(AdamOptimizer, self).__init__(
+            learning_rate=learning_rate, **kwargs)
         self.type = "adam"
-        self._learning_rate = learning_rate
         self._beta1 = beta1
         self._beta2 = beta2
         self._epsilon = epsilon
@@ -457,9 +473,9 @@ class AdamaxOptimizer(Optimizer):
         assert beta1 is not None
         assert beta2 is not None
         assert epsilon is not None
-        super(AdamaxOptimizer, self).__init__(**kwargs)
+        super(AdamaxOptimizer, self).__init__(
+            learning_rate=learning_rate, **kwargs)
         self.type = "adamax"
-        self._learning_rate = learning_rate
         self._beta1 = beta1
         self._beta2 = beta2
         self._epsilon = epsilon
@@ -535,9 +551,9 @@ class DecayedAdagradOptimizer(Optimizer):
         assert decay is not None
         assert epsilon is not None
-        super(DecayedAdagradOptimizer, self).__init__(**kwargs)
+        super(DecayedAdagradOptimizer, self).__init__(
+            learning_rate=learning_rate, **kwargs)
         self.type = "decayed_adagrad"
-        self._learning_rate = learning_rate
         self._decay = decay
         self._epsilon = epsilon
...
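After this change an optimizer takes `learning_rate` directly: a plain float is materialized into a persistable global Variable by `_create_global_learning_rate`, and a learning-rate Variable (for example one produced by `learning_rate_decay`) should also be accepted, since only the Variable check above applies. A hedged sketch, not taken from the commit:

```python
import paddle.v2.fluid as fluid

# A float learning rate is turned into a [1]-shaped persistable Variable internally.
sgd = fluid.optimizer.SGDOptimizer(learning_rate=0.01)
# sgd.minimize(avg_cost)  # would build optimize ops that scale by the global LR
#                         # (avg_cost is a hypothetical loss Variable)
```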
@@ -175,7 +175,7 @@ def main():
         paddle.reader.shuffle(
             paddle.dataset.conll05.test(), buf_size=8192),
         batch_size=BATCH_SIZE)
-    #place = fluid.CPUPlace()
+    # place = fluid.CPUPlace()
     place = fluid.CUDAPlace(0)
     feeder = fluid.DataFeeder(
         feed_list=[
...
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestElementwisePowOp(OpTest):
    def setUp(self):
        self.op_type = "elementwise_pow"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32")
        }
        self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}

    def test_check_output(self):
        self.check_output()


class TestElementwisePowOp_scalar(TestElementwisePowOp):
    def setUp(self):
        self.op_type = "elementwise_pow"
        self.inputs = {
            'X': np.random.rand(2, 3, 4).astype('float32'),
            'Y': np.random.rand(1).astype('float32')
        }
        self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import math
import paddle.v2.fluid.framework as framework
import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.learning_rate_decay as lr_decay
def exponential_decay(learning_rate,
                      global_step,
                      decay_steps,
                      decay_rate,
                      staircase=False):
    exponent = float(global_step) / float(decay_steps)
    if staircase:
        exponent = math.floor(exponent)
    return learning_rate * decay_rate**exponent


def natural_exp_decay(learning_rate,
                      global_step,
                      decay_steps,
                      decay_rate,
                      staircase=False):
    exponent = float(global_step) / float(decay_steps)
    if staircase:
        exponent = math.floor(exponent)
    return learning_rate * math.exp(-1 * decay_rate * exponent)


def inverse_time_decay(learning_rate,
                       global_step,
                       decay_steps,
                       decay_rate,
                       staircase=False):
    temp = float(global_step) / float(decay_steps)
    if staircase:
        temp = math.floor(temp)
    return learning_rate / (1 + decay_rate * temp)


class TestLearningRateDecay(unittest.TestCase):
    def check_decay(self, python_decay_fn, fluid_decay_fn, staircase):
        init_lr = 1.0
        decay_steps = 5
        decay_rate = 0.5

        global_step = layers.create_global_var(
            shape=[1], value=0.0, dtype='float32', persistable=True)

        decayed_lr = fluid_decay_fn(
            learning_rate=init_lr,
            global_step=global_step,
            decay_steps=decay_steps,
            decay_rate=decay_rate,
            staircase=staircase)
        layers.increment(global_step, 1.0)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)

        exe.run(fluid.default_startup_program())
        for step in range(10):
            step_val, lr_val = exe.run(fluid.default_main_program(),
                                       feed=[],
                                       fetch_list=[global_step, decayed_lr])
            python_decayed_lr = python_decay_fn(
                learning_rate=init_lr,
                global_step=step,
                decay_steps=decay_steps,
                decay_rate=decay_rate,
                staircase=staircase)
            self.assertAlmostEqual(python_decayed_lr, lr_val[0])

    def test_decay(self):
        decay_fns = [
            (exponential_decay, lr_decay.exponential_decay, True),
            (exponential_decay, lr_decay.exponential_decay, False),
            (natural_exp_decay, lr_decay.natural_exp_decay, True),
            (natural_exp_decay, lr_decay.natural_exp_decay, False),
            (inverse_time_decay, lr_decay.inverse_time_decay, True),
            (inverse_time_decay, lr_decay.inverse_time_decay, False),
        ]

        for py_decay_fn, fluid_decay_fn, staircase in decay_fns:
            print("decay_fn=" + str(py_decay_fn) + " staircase=" + str(
                staircase))
            main_program = framework.Program()
            startup_program = framework.Program()
            with framework.program_guard(main_program, startup_program):
                self.check_decay(py_decay_fn, fluid_decay_fn, staircase)


if __name__ == '__main__':
    unittest.main()