From be801d6c056c3435922e345d9d2ea105120b812d Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Wed, 31 Jan 2018 10:37:09 +0800 Subject: [PATCH] Add learning rate decay (#7892) * add basic interface for learning rate decay * add exponential_decay * add natural_exp_decay * add inverse_time_decay --- paddle/operators/elementwise_pow_op.cc | 37 ++++++ paddle/operators/elementwise_pow_op.cu | 20 +++ paddle/operators/elementwise_pow_op.h | 37 ++++++ python/paddle/v2/fluid/__init__.py | 2 + .../paddle/v2/fluid/layers/math_op_patch.py | 4 +- python/paddle/v2/fluid/layers/ops.py | 1 + python/paddle/v2/fluid/layers/tensor.py | 13 +- python/paddle/v2/fluid/learning_rate_decay.py | 125 ++++++++++++++++++ python/paddle/v2/fluid/optimizer.py | 66 +++++---- .../tests/book/test_label_semantic_roles.py | 2 +- .../v2/fluid/tests/test_elementwise_pow_op.py | 43 ++++++ .../fluid/tests/test_learning_rate_decay.py | 110 +++++++++++++++ 12 files changed, 432 insertions(+), 28 deletions(-) create mode 100644 paddle/operators/elementwise_pow_op.cc create mode 100644 paddle/operators/elementwise_pow_op.cu create mode 100644 paddle/operators/elementwise_pow_op.h create mode 100644 python/paddle/v2/fluid/learning_rate_decay.py create mode 100644 python/paddle/v2/fluid/tests/test_elementwise_pow_op.py create mode 100644 python/paddle/v2/fluid/tests/test_learning_rate_decay.py diff --git a/paddle/operators/elementwise_pow_op.cc b/paddle/operators/elementwise_pow_op.cc new file mode 100644 index 00000000000..5293cc7dd34 --- /dev/null +++ b/paddle/operators/elementwise_pow_op.cc @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/elementwise_pow_op.h" +#include "paddle/operators/elementwise_op.h" + +namespace paddle { +namespace operators { +class ElementwisePowOpMaker : public ElementwiseOpMaker { + public: + ElementwisePowOpMaker(OpProto* proto, OpAttrChecker* op_checker) + : ElementwiseOpMaker(proto, op_checker) { + SetComment("Pow", "Out = X ^ Y"); + AddComment(comment_); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(elementwise_pow, ops::ElementwiseOp, + ops::ElementwisePowOpMaker); +REGISTER_OP_CPU_KERNEL( + elementwise_pow, + ops::ElementwisePowKernel, + ops::ElementwisePowKernel); diff --git a/paddle/operators/elementwise_pow_op.cu b/paddle/operators/elementwise_pow_op.cu new file mode 100644 index 00000000000..643c978e635 --- /dev/null +++ b/paddle/operators/elementwise_pow_op.cu @@ -0,0 +1,20 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/elementwise_pow_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + elementwise_pow, + ops::ElementwisePowKernel, + ops::ElementwisePowKernel); diff --git a/paddle/operators/elementwise_pow_op.h b/paddle/operators/elementwise_pow_op.h new file mode 100644 index 00000000000..6019e709e0d --- /dev/null +++ b/paddle/operators/elementwise_pow_op.h @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/operators/elementwise_op_function.h" + +namespace paddle { +namespace operators { + +template +struct PowFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { return std::pow(a, b); } +}; + +template +class ElementwisePowKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ElementwiseComputeEx, DeviceContext, T>(ctx); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index a542e3dbabf..18c8343d098 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -26,6 +26,7 @@ import initializer import layers import nets import optimizer +import learning_rate_decay import backward import regularizer from param_attr import ParamAttr @@ -44,6 +45,7 @@ __all__ = framework.__all__ + executor.__all__ + [ 'layers', 'nets', 'optimizer', + 'learning_rate_decay', 'backward', 'regularizer', 'LoDTensor', diff --git a/python/paddle/v2/fluid/layers/math_op_patch.py b/python/paddle/v2/fluid/layers/math_op_patch.py index f359e70126f..79a130a3eb1 100644 --- a/python/paddle/v2/fluid/layers/math_op_patch.py +++ b/python/paddle/v2/fluid/layers/math_op_patch.py @@ -145,7 +145,9 @@ def monkey_patch_variable(): # a*b == b*a. Do not need to reverse explicitly ("__rmul__", "elementwise_mul", False), ("__div__", "elementwise_div", False), - ("__rdiv__", "elementwise_div", True)): + ("__rdiv__", "elementwise_div", True), + ("__pow__", "elementwise_pow", False), + ("__rpow__", "elementwise_pow", True)): setattr(Variable, method_name, _elemwise_method_creator_(method_name, op_type, reverse)) diff --git a/python/paddle/v2/fluid/layers/ops.py b/python/paddle/v2/fluid/layers/ops.py index 022a94cad44..ee3172c7b8d 100644 --- a/python/paddle/v2/fluid/layers/ops.py +++ b/python/paddle/v2/fluid/layers/ops.py @@ -56,6 +56,7 @@ __all__ = [ 'elementwise_mul', 'elementwise_max', 'elementwise_min', + 'elementwise_pow', 'clip', 'clip_by_norm', 'sequence_softmax', diff --git a/python/paddle/v2/fluid/layers/tensor.py b/python/paddle/v2/fluid/layers/tensor.py index 6e7d09459c0..c435c5206d1 100644 --- a/python/paddle/v2/fluid/layers/tensor.py +++ b/python/paddle/v2/fluid/layers/tensor.py @@ -16,12 +16,14 @@ from ..layer_helper import LayerHelper from ..param_attr import ParamAttr from ..framework import convert_np_dtype_to_dtype_ from ..framework import Variable +from ..initializer import Constant from ..core import DataType import numpy __all__ = [ 'create_tensor', 'create_parameter', + 'create_global_var', 'cast', 'concat', 'sums', @@ -58,13 +60,22 @@ def create_parameter(shape, Returns: Parameter: the created parameter """ - helper = LayerHelper("create_parameter") + helper = LayerHelper("create_parameter", **locals()) if attr is None: attr = ParamAttr() return helper.create_parameter(attr, shape, dtype, is_bias, default_initializer) +def create_global_var(shape, value, dtype, persistable=False, name=None): + helper = LayerHelper("global_var", **locals()) + var = helper.create_global_variable( + dtype=dtype, shape=shape, persistable=persistable, name=name) + helper.set_variable_initializer( + var, initializer=Constant(value=float(value))) + return var + + def cast(x, dtype): """ This function takes in the input with input_dtype diff --git a/python/paddle/v2/fluid/learning_rate_decay.py b/python/paddle/v2/fluid/learning_rate_decay.py new file mode 100644 index 00000000000..96b3e9a0d73 --- /dev/null +++ b/python/paddle/v2/fluid/learning_rate_decay.py @@ -0,0 +1,125 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import layers +from framework import Variable + +__all__ = ['exponential_decay', 'natural_exp_decay', 'inverse_time_decay'] +""" +When training a model, it's often useful to decay the +learning rate during training process, this is called +learning_rate_decay. There are many strategies to do +this, this module will provide some classical method. +User can also implement their own learning_rate_decay +strategy according to this module. +""" + + +def exponential_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False): + """Applies exponential decay to the learning rate. + + ```python + decayed_learning_rate = learning_rate * + decay_rate ^ (global_step / decay_steps) + ``` + Args: + learning_rate: A scalar float32 value or a Variable. This + will be the initial learning rate during training + global_step: A Variable that record the training step. + decay_steps: A Python `int32` number. + decay_rate: A Python `float` number. + staircase: Boolean. If set true, decay the learning rate every decay_steps. + + Returns: + The decayed learning rate + """ + if not isinstance(global_step, Variable): + raise ValueError("global_step is required for exponential_decay.") + + # update learning_rate + div_res = global_step / decay_steps + if staircase: + div_res = layers.floor(x=div_res) + return learning_rate * (decay_rate**div_res) + + +def natural_exp_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False): + """Applies natural exponential decay to the initial learning rate. + + ```python + if not staircase: + decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps)) + else: + decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps)) + ``` + Args: + learning_rate: A scalar float32 value or a Variable. This + will be the initial learning rate during training + global_step: A Variable that record the training step. + decay_steps: A Python `int32` number. + decay_rate: A Python `float` number. + staircase: Boolean. If set true, decay the learning rate every decay_steps. + + Returns: + The decayed learning rate + """ + if not isinstance(global_step, Variable): + raise ValueError("global_step is required for natural_exp_decay.") + + div_res = global_step / decay_steps + if staircase: + div_res = layers.floor(x=div_res) + return learning_rate * layers.exp(x=(-1 * decay_rate * div_res)) + + +def inverse_time_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False): + """Applies inverse time decay to the initial learning rate. + + ```python + if staircase: + decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step)) + else + decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step) + ``` + Args: + learning_rate: A scalar float32 value or a Variable. This + will be the initial learning rate during training + global_step: A Variable that record the training step. + decay_steps: A Python `int32` number. + decay_rate: A Python `float` number. + staircase: Boolean. If set true, decay the learning rate every decay_steps. + + Returns: + The decayed learning rate + """ + if not isinstance(global_step, Variable): + raise ValueError("global_step is required for inverse_time_decay.") + + div_res = global_step / decay_steps + if staircase: + div_res = layers.floor(x=div_res) + + return learning_rate / (1 + decay_rate * div_res) diff --git a/python/paddle/v2/fluid/optimizer.py b/python/paddle/v2/fluid/optimizer.py index 0c3533b8921..7844a4e2df1 100644 --- a/python/paddle/v2/fluid/optimizer.py +++ b/python/paddle/v2/fluid/optimizer.py @@ -15,6 +15,7 @@ from collections import defaultdict import framework +import layers from backward import append_backward from framework import unique_name, program_guard from initializer import Constant @@ -33,9 +34,11 @@ class Optimizer(object): but need to use one of it's implementation. """ - def __init__(self, global_step=None, regularization=None): + def __init__(self, learning_rate, global_step=None, regularization=None): + assert learning_rate is not None self._global_step = global_step self.regularization = regularization + self._global_learning_rate = learning_rate # Dictionary of accumulators. Some optimizer subclasses need to # allocate and manage extra variables associated with the parameters # to train. These variables are called accumulators. @@ -43,6 +46,28 @@ class Optimizer(object): self._accumulators = defaultdict(lambda: dict()) self.helper = None + def _create_global_learning_rate(self): + if isinstance(self._global_learning_rate, float): + self._global_learning_rate = layers.create_global_var( + name=unique_name("learning_rate"), + shape=[1], + value=float(self._global_learning_rate), + dtype='float32', + persistable=True) + + if not isinstance(self._global_learning_rate, framework.Variable): + raise ValueError("learning rate should be a Variable, " + "actual type is %s", + type(self._global_learning_rate)) + + @property + def global_learning_rate(self): + """ + get global decayed learning rate + :return: + """ + return self._global_learning_rate + def _append_optimize_op(self, block, param_and_grad): """ append optimize operator to block and return all the added optimize_op """ @@ -52,17 +77,7 @@ class Optimizer(object): # create learning rate variable for every parameter param = param_and_grad[0] param_lr = param.optimize_attr['learning_rate'] - param_lr_shape = [1] - param_lr_var = self.helper.create_global_variable( - name=unique_name("learning_rate"), - dtype='float32', - shape=param_lr_shape, - lod_level=1, - persistable=True) - param_lr = param_lr * self._learning_rate - self.helper.set_variable_initializer( - var=param_lr_var, initializer=Constant(param_lr)) - return param_lr_var + return self._global_learning_rate * param_lr def _create_accumulators(self, block, parameters): """Create all accumulators needed by the parameters @@ -163,7 +178,7 @@ class Optimizer(object): optimization. This will include parameter update ops, global step update ops and any other custom ops required by subclasses to manage their internal state. - :param startup_program: + :param startup_program: """ # This is a default implementation of create_optimization_pass that # can be shared by most optimizers. This implementation assumes that @@ -178,6 +193,7 @@ class Optimizer(object): self.helper = LayerHelper(self.__class__.__name__) self._create_accumulators(loss.block, [p[0] for p in parameters_and_grads]) + self._create_global_learning_rate() optimize_ops = [] for param_and_grad in parameters_and_grads: @@ -231,9 +247,9 @@ class SGDOptimizer(Optimizer): def __init__(self, learning_rate, **kwargs): assert learning_rate is not None - super(SGDOptimizer, self).__init__(**kwargs) + super(SGDOptimizer, self).__init__( + learning_rate=learning_rate, **kwargs) self.type = "sgd" - self._learning_rate = learning_rate def _append_optimize_op(self, block, param_and_grad): assert isinstance(block, framework.Block) @@ -259,9 +275,9 @@ class MomentumOptimizer(Optimizer): def __init__(self, learning_rate, momentum, use_nesterov=False, **kwargs): assert learning_rate is not None assert momentum is not None - super(MomentumOptimizer, self).__init__(**kwargs) + super(MomentumOptimizer, self).__init__( + learning_rate=learning_rate, **kwargs) self.type = "momentum" - self._learning_rate = learning_rate self._momentum = momentum self._use_nesterov = bool(use_nesterov) @@ -303,9 +319,9 @@ class AdagradOptimizer(Optimizer): def __init__(self, learning_rate, epsilon=1.0e-6, **kwargs): assert learning_rate is not None assert epsilon is not None - super(AdagradOptimizer, self).__init__(**kwargs) + super(AdagradOptimizer, self).__init__( + learning_rate=learning_rate, **kwargs) self.type = "adagrad" - self._learning_rate = learning_rate self._epsilon = epsilon def _create_accumulators(self, block, parameters): @@ -352,9 +368,9 @@ class AdamOptimizer(Optimizer): assert beta1 is not None assert beta2 is not None assert epsilon is not None - super(AdamOptimizer, self).__init__(**kwargs) + super(AdamOptimizer, self).__init__( + learning_rate=learning_rate, **kwargs) self.type = "adam" - self._learning_rate = learning_rate self._beta1 = beta1 self._beta2 = beta2 self._epsilon = epsilon @@ -457,9 +473,9 @@ class AdamaxOptimizer(Optimizer): assert beta1 is not None assert beta2 is not None assert epsilon is not None - super(AdamaxOptimizer, self).__init__(**kwargs) + super(AdamaxOptimizer, self).__init__( + learning_rate=learning_rate, **kwargs) self.type = "adamax" - self._learning_rate = learning_rate self._beta1 = beta1 self._beta2 = beta2 self._epsilon = epsilon @@ -535,9 +551,9 @@ class DecayedAdagradOptimizer(Optimizer): assert decay is not None assert epsilon is not None - super(DecayedAdagradOptimizer, self).__init__(**kwargs) + super(DecayedAdagradOptimizer, self).__init__( + learning_rate=learning_rate, **kwargs) self.type = "decayed_adagrad" - self._learning_rate = learning_rate self._decay = decay self._epsilon = epsilon diff --git a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py index 1a342bf1fbb..f85768de99a 100644 --- a/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py +++ b/python/paddle/v2/fluid/tests/book/test_label_semantic_roles.py @@ -175,7 +175,7 @@ def main(): paddle.reader.shuffle( paddle.dataset.conll05.test(), buf_size=8192), batch_size=BATCH_SIZE) - #place = fluid.CPUPlace() + # place = fluid.CPUPlace() place = fluid.CUDAPlace(0) feeder = fluid.DataFeeder( feed_list=[ diff --git a/python/paddle/v2/fluid/tests/test_elementwise_pow_op.py b/python/paddle/v2/fluid/tests/test_elementwise_pow_op.py new file mode 100644 index 00000000000..e31749df9ba --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_elementwise_pow_op.py @@ -0,0 +1,43 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import numpy as np +from op_test import OpTest + + +class TestElementwisePowOp(OpTest): + def setUp(self): + self.op_type = "elementwise_pow" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"), + 'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32") + } + self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + + def test_check_output(self): + self.check_output() + + +class TestElementwisePowOp_scalar(TestElementwisePowOp): + def setUp(self): + self.op_type = "elementwise_pow" + self.inputs = { + 'X': np.random.rand(2, 3, 4).astype('float32'), + 'Y': np.random.rand(1).astype('float32') + } + self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])} + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/fluid/tests/test_learning_rate_decay.py b/python/paddle/v2/fluid/tests/test_learning_rate_decay.py new file mode 100644 index 00000000000..dc348cf2d21 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_learning_rate_decay.py @@ -0,0 +1,110 @@ +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import math +import paddle.v2.fluid.framework as framework +import paddle.v2.fluid as fluid +import paddle.v2.fluid.layers as layers +import paddle.v2.fluid.learning_rate_decay as lr_decay + + +def exponential_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False): + exponent = float(global_step) / float(decay_steps) + if staircase: + exponent = math.floor(exponent) + return learning_rate * decay_rate**exponent + + +def natural_exp_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False): + exponent = float(global_step) / float(decay_steps) + if staircase: + exponent = math.floor(exponent) + return learning_rate * math.exp(-1 * decay_rate * exponent) + + +def inverse_time_decay(learning_rate, + global_step, + decay_steps, + decay_rate, + staircase=False): + temp = float(global_step) / float(decay_steps) + if staircase: + temp = math.floor(temp) + return learning_rate / (1 + decay_rate * temp) + + +class TestLearningRateDecay(unittest.TestCase): + def check_decay(self, python_decay_fn, fluid_decay_fn, staircase): + init_lr = 1.0 + decay_steps = 5 + decay_rate = 0.5 + + global_step = layers.create_global_var( + shape=[1], value=0.0, dtype='float32', persistable=True) + + decayed_lr = fluid_decay_fn( + learning_rate=init_lr, + global_step=global_step, + decay_steps=decay_steps, + decay_rate=decay_rate, + staircase=staircase) + layers.increment(global_step, 1.0) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + + exe.run(fluid.default_startup_program()) + for step in range(10): + step_val, lr_val = exe.run(fluid.default_main_program(), + feed=[], + fetch_list=[global_step, decayed_lr]) + python_decayed_lr = python_decay_fn( + learning_rate=init_lr, + global_step=step, + decay_steps=decay_steps, + decay_rate=decay_rate, + staircase=staircase) + self.assertAlmostEqual(python_decayed_lr, lr_val[0]) + + def test_decay(self): + decay_fns = [ + (exponential_decay, lr_decay.exponential_decay, True), + (exponential_decay, lr_decay.exponential_decay, False), + (natural_exp_decay, lr_decay.natural_exp_decay, True), + (natural_exp_decay, lr_decay.natural_exp_decay, False), + (inverse_time_decay, lr_decay.inverse_time_decay, True), + (inverse_time_decay, lr_decay.inverse_time_decay, False), + ] + + for py_decay_fn, fluid_decay_fn, staircase in decay_fns: + print("decay_fn=" + str(py_decay_fn) + " staircase=" + str( + staircase)) + main_program = framework.Program() + startup_program = framework.Program() + with framework.program_guard(main_program, startup_program): + self.check_decay(py_decay_fn, fluid_decay_fn, staircase) + + +if __name__ == '__main__': + unittest.main() -- GitLab