test_learning_rate_scheduler.py 39.1 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

17
import copy
18
import math
19
import numpy as np
20
import unittest
21

22
import paddle
23
import paddle.fluid as fluid
24
import paddle.fluid.layers as layers
25
import paddle.fluid.framework as framework
Q
QI JUN 已提交
26
import paddle.fluid.core as core
Q
Qiao Longfei 已提交
27 28 29 30 31 32 33


def exponential_decay(learning_rate,
                      global_step,
                      decay_steps,
                      decay_rate,
                      staircase=False):
    """Pure-Python reference for ``layers.exponential_decay``.

    Returns ``learning_rate * decay_rate ** (global_step / decay_steps)``.
    With ``staircase=True`` the exponent is floored, so the rate drops in
    discrete steps every ``decay_steps`` steps.
    """
    # float() forces true division so integer arguments are not
    # floor-divided under Python 2 (only print_function is imported from
    # __future__); this matches natural_exp_decay / inverse_time_decay.
    exponent = float(global_step) / float(decay_steps)
    if staircase:
        exponent = math.floor(exponent)
    return learning_rate * decay_rate**exponent


def natural_exp_decay(learning_rate,
                      global_step,
                      decay_steps,
                      decay_rate,
                      staircase=False):
    """Pure-Python reference for ``layers.natural_exp_decay``.

    Computes ``learning_rate * exp(-decay_rate * global_step / decay_steps)``;
    the ratio ``global_step / decay_steps`` is floored when ``staircase``.
    """
    progress = float(global_step) / float(decay_steps)
    progress = math.floor(progress) if staircase else progress
    return learning_rate * math.exp(-1 * decay_rate * progress)


def inverse_time_decay(learning_rate,
                       global_step,
                       decay_steps,
                       decay_rate,
                       staircase=False):
    """Pure-Python reference for ``layers.inverse_time_decay``.

    Computes ``learning_rate / (1 + decay_rate * global_step / decay_steps)``;
    with ``staircase`` the ratio ``global_step / decay_steps`` is floored.
    """
    ratio = float(global_step) / float(decay_steps)
    if staircase:
        ratio = math.floor(ratio)
    denominator = 1 + decay_rate * ratio
    return learning_rate / denominator


62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
def polynomial_decay(learning_rate,
                     global_step,
                     decay_steps,
                     end_learning_rate=0.0001,
                     power=1.0,
                     cycle=False):
    """Pure-Python reference for ``layers.polynomial_decay``.

    Returns ``(learning_rate - end_learning_rate) *
    (1 - global_step / decay_steps) ** power + end_learning_rate``.
    """
    if not cycle:
        # Without cycling, the step is clamped once decay_steps is reached.
        global_step = min(global_step, decay_steps)
    else:
        # With cycling, the horizon is stretched to the next multiple of
        # decay_steps; a zero multiple (step 0) is bumped up to one period.
        periods = math.ceil(global_step / float(decay_steps))
        decay_steps = decay_steps * (periods if periods != 0 else 1)
    remaining = 1 - float(global_step) / float(decay_steps)
    span = learning_rate - end_learning_rate
    return span * (remaining**power) + end_learning_rate


def piecewise_decay(global_step, boundaries, values):
    """Pure-Python reference for ``layers.piecewise_decay``.

    ``values[i]`` is returned for the first ``i`` with
    ``global_step < boundaries[i]``; past every boundary the last value
    applies. ``values`` must be exactly one longer than ``boundaries``.
    """
    assert len(boundaries) + 1 == len(values)
    for boundary, value in zip(boundaries, values):
        if global_step < boundary:
            return value
    return values[-1]
Q
Qiao Longfei 已提交
85

86

S
shippingwang 已提交
87 88 89 90 91 92 93
def cosine_decay(global_step, learning_rate, step_each_epoch, epochs):
    """Pure-Python reference for ``layers.cosine_decay``.

    The step is converted to a whole epoch index first, then
    ``learning_rate * 0.5 * (cos(epoch * pi / epochs) + 1)`` is returned.
    """
    epoch = math.floor(global_step / step_each_epoch)
    cosine = math.cos(epoch * math.pi / epochs)
    return learning_rate * 0.5 * (cosine + 1)


94 95 96 97 98 99 100 101
def noam_decay(global_step, d_model, warmup_steps, learning_rate=1.0):
    """Pure-Python reference for ``layers.noam_decay``.

    Returns ``learning_rate * d_model**-0.5 *
    min(global_step**-0.5, warmup_steps**-1.5 * global_step)``.
    ``global_step`` must be positive (``math.pow(0, -0.5)`` raises).
    """
    inv_sqrt_term = math.pow(global_step, -0.5)
    warmup_term = math.pow(warmup_steps, -1.5) * global_step
    return learning_rate * math.pow(d_model, -0.5) * min(inv_sqrt_term,
                                                         warmup_term)


102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
def linear_lr_warmup(global_step, warmup_steps, start_lr, end_lr):
    """Pure-Python reference for ``layers.linear_lr_warmup``.

    Linearly interpolates from ``start_lr`` (step 0) towards ``end_lr``
    (step ``warmup_steps``):
    ``start_lr + (end_lr - start_lr) * global_step / warmup_steps``.
    """
    # float() forces true division so integer steps are not floor-divided
    # under Python 2 (only print_function is imported from __future__);
    # consistent with natural_exp_decay / inverse_time_decay.
    progress = float(global_step) / float(warmup_steps)
    return start_lr + (end_lr - start_lr) * progress


def multi_step_decay(global_step, learning_rate, milestones, decay_rate=0.1):
    """Pure-Python reference for ``fluid.dygraph.MultiStepDecay``.

    The exponent applied to ``decay_rate`` is the index of the first
    milestone that ``global_step`` is still below, or ``len(milestones)``
    once every milestone has been passed.
    """
    exponent = next(
        (idx for idx, milestone in enumerate(milestones)
         if global_step < milestone),
        len(milestones))
    return learning_rate * math.pow(decay_rate, exponent)


def step_decay(global_step, learning_rate, step_size, decay_rate=0.1):
    """Pure-Python reference for ``fluid.dygraph.StepDecay``.

    The rate is multiplied by ``decay_rate`` once per completed block of
    ``step_size`` steps.
    """
    completed_blocks = global_step // step_size
    return learning_rate * math.pow(decay_rate, completed_blocks)


120 121 122 123
def lambda_decay(global_step, learning_rate, lr_lambda):
    """Pure-Python reference for ``fluid.dygraph.LambdaDecay``.

    Scales ``learning_rate`` by the user factor ``lr_lambda(global_step)``.
    """
    factor = lr_lambda(global_step)
    return learning_rate * factor


124
class TestLearningRateDecayDygraph(unittest.TestCase):
    """Dygraph-mode tests for the fluid learning-rate schedulers."""

    def test_LR_state_dict(self):
        """Scheduler state must survive an optimizer state_dict round trip.

        Three Adam optimizers (ExponentialDecay, StepDecay and
        ReduceLROnPlateau) are trained for 10 epochs. Each optimizer's state
        dict is saved to / loaded from "save_path", and a fresh optimizer that
        loads it must report the same scheduler counters and current rate.
        """
        with fluid.dygraph.guard():
            x = np.random.uniform(-1, 1, [3, 10]).astype("float32")
            linear = fluid.dygraph.Linear(10, 10)
            input_var = fluid.dygraph.to_variable(x)

            Exponential_scheduler = fluid.dygraph.ExponentialDecay(
                learning_rate=0.1,
                decay_steps=10000,
                decay_rate=0.5,
                staircase=True)
            Step_scheduler = fluid.dygraph.StepDecay(0.5, step_size=3)
            Reducelr_scheduler = fluid.dygraph.ReduceLROnPlateau(
                learning_rate=1.0, decay_rate=0.5, patience=5, cooldown=3)

            adam1 = fluid.optimizer.Adam(
                learning_rate=Exponential_scheduler,
                parameter_list=linear.parameters())
            adam2 = fluid.optimizer.Adam(
                learning_rate=Step_scheduler,
                parameter_list=linear.parameters())
            adam3 = fluid.optimizer.Adam(
                learning_rate=Reducelr_scheduler,
                parameter_list=linear.parameters())
            print(adam3.state_dict())

            for epoch in range(10):
                out = linear(input_var)
                loss = fluid.layers.reduce_mean(out)
                loss.backward()
                adam1.minimize(loss)
                adam2.minimize(loss)
                adam3.minimize(loss)
                linear.clear_gradients()

                # StepDecay and ReduceLROnPlateau are advanced explicitly
                # once per epoch.
                Step_scheduler.epoch()
                Reducelr_scheduler.step(loss)

            fluid.dygraph.save_dygraph(linear.state_dict(), "save_path")

            # Fresh schedulers with identical hyper-parameters; their state
            # should come entirely from set_dict().
            Exponential_scheduler_test = fluid.dygraph.ExponentialDecay(
                learning_rate=0.1,
                decay_steps=10000,
                decay_rate=0.5,
                staircase=True)
            Step_scheduler_test = fluid.dygraph.StepDecay(0.5, step_size=3)
            Reducelr_scheduler_test = fluid.dygraph.ReduceLROnPlateau(
                learning_rate=1.0, decay_rate=0.5, patience=5, cooldown=3)

            fluid.dygraph.save_dygraph(adam1.state_dict(), "save_path")
            _, opt_state = fluid.dygraph.load_dygraph("save_path")
            adam_test = fluid.optimizer.Adam(
                learning_rate=Exponential_scheduler_test,
                parameter_list=linear.parameters())
            adam_test.set_dict(opt_state)
            self.assertEqual(adam_test._learning_rate.step_num,
                             adam1._learning_rate.step_num,
                             "step_num is different before and after set_dict")

            fluid.dygraph.save_dygraph(adam2.state_dict(), "save_path")
            _, opt_state = fluid.dygraph.load_dygraph("save_path")
            adam_test = fluid.optimizer.Adam(
                learning_rate=Step_scheduler_test,
                parameter_list=linear.parameters())
            adam_test.set_dict(opt_state)
            self.assertEqual(adam_test._learning_rate.epoch_num,
                             adam2._learning_rate.epoch_num,
                             "epoch_num is different before and after set_dict")
            self.assertEqual(
                adam_test._learning_rate(),
                adam2._learning_rate(),
                "current learning rate is different before and after set_dict")

            fluid.dygraph.save_dygraph(adam3.state_dict(), "save_path")
            _, opt_state = fluid.dygraph.load_dygraph("save_path")
            adam_test = fluid.optimizer.Adam(
                learning_rate=Reducelr_scheduler_test,
                parameter_list=linear.parameters())
            adam_test.set_dict(opt_state)
            self.assertEqual(adam_test._learning_rate.best_loss,
                             adam3._learning_rate.best_loss.numpy()[0],
                             "best_loss is different before and after set_dict")
            self.assertEqual(
                adam_test._learning_rate.cooldown_counter,
                adam3._learning_rate.cooldown_counter,
                "cooldown_counter is different before and after set_dict")
            self.assertEqual(
                adam_test._learning_rate.num_bad_epochs,
                adam3._learning_rate.num_bad_epochs,
                "num_bad_epochs is different before and after set_dict")
            self.assertEqual(adam_test._learning_rate.epoch_num,
                             adam3._learning_rate.epoch_num,
                             "epoch is different before and after set_dict")
            self.assertEqual(
                adam_test._learning_rate(),
                adam3._learning_rate(),
                "current learning rate is different before and after set_dict")

    def test_NoamDecay(self):
        """layers.noam_decay in dygraph matches the Python reference (steps 1-5)."""
        with fluid.dygraph.guard():
            d_model = 0.01
            warmup_steps = 200
            learning_rate = 2.0
            lr = fluid.layers.noam_decay(d_model, warmup_steps, learning_rate)
            for step in range(5):
                # Step of NoamDecay starts from 1.
                step += 1
                right_result = noam_decay(step, d_model, warmup_steps,
                                          learning_rate)
                fluid_result = lr()

                self.assertAlmostEqual(
                    right_result,
                    fluid_result[0],
                    msg='Failed lr scheduler in step {0}, Python result is {1}, Fluid result is {2}'.
                    format(step, right_result, fluid_result[0]))

    def test_LinearLrWarmup(self):
        """Warmup over a polynomial decay, plus TypeError on a bad lr type."""
        with fluid.dygraph.guard():
            lr = fluid.layers.polynomial_decay(
                learning_rate=1.0,
                decay_steps=10,
                end_learning_rate=0.0,
                power=1.0)
            lr = fluid.layers.linear_lr_warmup(
                learning_rate=lr, warmup_steps=2, start_lr=0.0, end_lr=1.0)

            right_result = [0.5, 0.9, 0.8, 0.7, 0.6]
            for i in range(5):

                t = lr()

                self.assertTrue(
                    np.allclose((t.numpy())[0].item(), right_result[i]))

            with self.assertRaises(TypeError):
                lr = fluid.layers.linear_lr_warmup(
                    learning_rate="fake_lr",
                    warmup_steps=2,
                    start_lr=0.0,
                    end_lr=1.0)

    def test_MultiStepDecay(self):
        """MultiStepDecay matches multi_step_decay; invalid args are rejected."""
        with fluid.dygraph.guard():
            learning_rate = 0.5
            milestones = [2, 4, 8]
            decay_rate = 0.2
            linear = fluid.dygraph.Linear(10, 10)

            scheduler = fluid.dygraph.MultiStepDecay(learning_rate, milestones,
                                                     decay_rate)

            adam = fluid.optimizer.AdamOptimizer(
                learning_rate=scheduler, parameter_list=linear.parameters())
            for epoch in range(10):
                right_result = multi_step_decay(epoch, learning_rate,
                                                milestones, decay_rate)
                fluid_result = adam.current_step_lr()
                scheduler.epoch()
                self.assertAlmostEqual(
                    right_result,
                    fluid_result,
                    msg='Failed lr scheduler in epoch {0}, Python result is {1}, Fluid result is {2}'.
                    format(epoch, right_result, fluid_result))

            # Milestones must be increasing.
            with self.assertRaises(ValueError):
                lr = fluid.dygraph.MultiStepDecay(learning_rate, [30, 50, 20],
                                                  0.1)

            # decay_rate of 1 is rejected.
            with self.assertRaises(ValueError):
                lr = fluid.dygraph.MultiStepDecay(learning_rate, [20, 30, 50],
                                                  1)

            with self.assertRaises(TypeError):
                lr = fluid.dygraph.MultiStepDecay("test", [20, 30, 50])

            with self.assertRaises(ValueError):
                lr = fluid.dygraph.MultiStepDecay(-1, [20, 30, 50])

    def test_StepDecay(self):
        """StepDecay matches step_decay; invalid args are rejected."""
        with fluid.dygraph.guard():
            learning_rate = 0.5
            step_size = 3
            decay_rate = 0.2
            scheduler = fluid.dygraph.StepDecay(learning_rate, step_size,
                                                decay_rate)
            for epoch in range(10):
                right_result = step_decay(epoch, learning_rate, step_size,
                                          decay_rate)
                fluid_result = scheduler().numpy()[0]
                scheduler.epoch()
                self.assertAlmostEqual(
                    right_result,
                    fluid_result,
                    msg='Failed lr scheduler in epoch {0}, Python result is {1}, Fluid result is {2}'.
                    format(epoch, right_result, fluid_result))

            with self.assertRaises(TypeError):
                lr = fluid.dygraph.StepDecay(learning_rate, "test", 0.1)

            with self.assertRaises(ValueError):
                lr = fluid.dygraph.StepDecay(learning_rate, 20, 2)

    def test_LambdaDecay(self):
        """LambdaDecay matches lambda_decay; a non-callable lambda is rejected."""
        with fluid.dygraph.guard():
            learning_rate = 0.5
            lr_lambda = lambda x: 0.95**x
            scheduler = fluid.dygraph.LambdaDecay(learning_rate, lr_lambda)

            linear = fluid.dygraph.nn.Linear(10, 10)
            adam = fluid.optimizer.Adam(
                scheduler, parameter_list=linear.parameters())

            for epoch in range(30):
                right_result = lambda_decay(epoch, learning_rate, lr_lambda)
                fluid_result = scheduler().numpy()[0]
                scheduler.epoch()
                self.assertAlmostEqual(
                    right_result,
                    fluid_result,
                    msg='Failed lr scheduler in epoch {0}, Python result is {1}, Fluid result is {2}'.
                    format(epoch, right_result, fluid_result))

            with self.assertRaises(TypeError):
                lr = fluid.dygraph.LambdaDecay(learning_rate, "test")

350

351 352
class TestLearningRateDecay(unittest.TestCase):
    """Static-graph tests: fluid decay layers vs. the Python references."""

    def check_decay(self, python_decay_fn, fluid_decay_fn, kwargs):
        """Run check_decay_with_place on CPU and, if compiled in, CUDA:0."""
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for place in places:
            self.check_decay_with_place(place, python_decay_fn, fluid_decay_fn,
                                        kwargs)

    def check_decay_with_place(self, place, python_decay_fn, fluid_decay_fn,
                               kwargs):
        """Build ``fluid_decay_fn(**kwargs)`` and compare 10 steps on ``place``.

        Each executor run yields the layer's learning rate, which is compared
        with ``python_decay_fn(global_step=step, **kwargs)``.
        """
        main_prog = fluid.Program()
        startup_prog = fluid.Program()

        with fluid.program_guard(main_prog, startup_prog):
            decayed_lr = fluid_decay_fn(**kwargs)

        # Run on the requested place. This used to be overwritten with
        # fluid.CPUPlace(), so the CUDA place added by check_decay was
        # silently never exercised.
        exe = fluid.Executor(place)

        exe.run(startup_prog)

        for step in range(10):
            # Step of NoamDecay starts from 1.
            if python_decay_fn.__name__ == 'noam_decay':
                step += 1
            lr_val, = exe.run(main_prog, feed={}, fetch_list=[decayed_lr])
            python_decayed_lr = python_decay_fn(
                global_step=float(step), **kwargs)
            self.assertAlmostEqual(
                python_decayed_lr,
                lr_val[0],
                msg='Failed lr scheduler is {0}, step {1}, Python result is {2}, Fluid result is {3}'.
                format(python_decay_fn.__name__,
                       str(step), str(python_decayed_lr), str(lr_val[0])))

    def test_decay(self):
        """Check every (python reference, fluid layer, kwargs) combination."""
        common_kwargs_true = {
            "learning_rate": 1.0,
            "decay_steps": 5,
            "decay_rate": 0.5,
            "staircase": True
        }
        common_kwargs_false = copy.deepcopy(common_kwargs_true)
        common_kwargs_false["staircase"] = False

        decay_fns = [
            (exponential_decay, layers.exponential_decay, common_kwargs_true),
            (exponential_decay, layers.exponential_decay, common_kwargs_false),
            (natural_exp_decay, layers.natural_exp_decay, common_kwargs_true),
            (natural_exp_decay, layers.natural_exp_decay, common_kwargs_false),
            (inverse_time_decay, layers.inverse_time_decay, common_kwargs_true),
            (inverse_time_decay, layers.inverse_time_decay,
             common_kwargs_false),
            (polynomial_decay, layers.polynomial_decay, {
                "learning_rate": 1.0,
                "decay_steps": 5,
                "cycle": True
            }),
            (polynomial_decay, layers.polynomial_decay, {
                "learning_rate": 1.0,
                "decay_steps": 5,
                "cycle": False
            }),
            (piecewise_decay, layers.piecewise_decay, {
                "boundaries": [3, 6, 9],
                "values": [0.1, 0.2, 0.3, 0.4]
            }),
            (cosine_decay, layers.cosine_decay, {
                "learning_rate": 0.1,
                "step_each_epoch": 100,
                "epochs": 120
            }),
            (noam_decay, layers.noam_decay, {
                "d_model": 0.01,
                "warmup_steps": 200,
                "learning_rate": 2.0
            }),
        ]

        for py_decay_fn, fluid_decay_fn, kwargs in decay_fns:
            print("class=" + self.__class__.__name__ + " decay_fn=" +
                  py_decay_fn.__name__ + " kwargs=" + str(kwargs))
            main_program = framework.Program()
            startup_program = framework.Program()
            with framework.program_guard(main_program, startup_program):
                self.check_decay(py_decay_fn, fluid_decay_fn, kwargs)
Q
Qiao Longfei 已提交
433 434


435
class TestLinearWamrupLearningRateDecay(TestLearningRateDecay):
    """Re-runs TestLearningRateDecay.test_decay with every decay wrapped in
    ``layers.linear_lr_warmup``.

    Deriving from TestLearningRateDecay is required: as a bare
    unittest.TestCase this class had no test_* method, so the override below
    never ran.
    """

    def check_decay_with_place(self, place, python_decay_fn, fluid_decay_fn,
                               kwargs):
        """Compare 20 steps of warmup+decay on ``place`` with Python.

        For the first ``warmup_steps`` steps the expected value is the linear
        warmup from ``start_lr`` to ``end_lr``; afterwards it is the wrapped
        decay's Python reference.
        """
        main_prog = fluid.Program()
        startup_prog = fluid.Program()

        warmup_steps = 10
        start_lr = 0.1 / 3.
        end_lr = 0.1

        with fluid.program_guard(main_prog, startup_prog):
            decayed_lr = layers.linear_lr_warmup(
                fluid_decay_fn(**kwargs), warmup_steps, start_lr, end_lr)

        # Honour the requested place instead of forcing CPUPlace().
        exe = fluid.Executor(place)
        exe.run(startup_prog)

        for step in range(20):
            # Step of NoamDecay starts from 1. Key on the Python reference's
            # name (as the base class does) rather than the fluid layer's.
            if python_decay_fn.__name__ == 'noam_decay':
                step += 1
            lr_val, = exe.run(main_prog, feed={}, fetch_list=[decayed_lr])
            if step < warmup_steps:
                python_decayed_lr = linear_lr_warmup(
                    float(step), warmup_steps, start_lr, end_lr)
            else:
                python_decayed_lr = python_decay_fn(
                    global_step=float(step), **kwargs)
            self.assertAlmostEqual(
                python_decayed_lr,
                lr_val[0],
                msg='Test {0} Failed, step {1}, Python result is {2}, Fluid result is {3}'.
                format(python_decay_fn.__name__,
                       str(step), str(python_decayed_lr), str(lr_val[0])))


Q
qingqing01 已提交
472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525
class TestLinearWamrupLearningRateDecayWithScalarInput(unittest.TestCase):
    """linear_lr_warmup wrapping a plain Python number.

    The rate warms up linearly for 10 steps and then holds the scalar ``lr``.
    """

    def run_scalar_lr(self, place, lr, start_lr, end_lr):
        """Run 20 steps on ``place`` and compare with the expected schedule."""
        main_prog, startup_prog = fluid.Program(), fluid.Program()
        warmup_steps = 10

        with fluid.program_guard(main_prog, startup_prog):
            decayed_lr = layers.linear_lr_warmup(lr, warmup_steps, start_lr,
                                                 end_lr)

        exe = fluid.Executor(place)
        exe.run(startup_prog)

        for step in range(20):
            lr_val, = exe.run(main_prog, feed={}, fetch_list=[decayed_lr])
            in_warmup = step < warmup_steps
            expected_lr = (linear_lr_warmup(float(step), warmup_steps,
                                            start_lr, end_lr)
                           if in_warmup else lr)
            self.assertAlmostEqual(
                expected_lr,
                lr_val[0],
                msg='Test failed, step {0}, expected {1}, but got {2}'.format(
                    step, expected_lr, lr_val[0]))

    def test_scalar_lr(self):
        """Float, int-end_lr and all-int inputs, on CPU (and CUDA:0 if built)."""
        # (lr, start_lr, end_lr)
        cases = [
            (0.2, 0.1 / 3., 0.2),  # float
            (2., 0.1 / 3., 1),  # int end_lr
            (1, 0, 1),  # int
        ]
        for lr, start_lr, end_lr in cases:
            places = [fluid.CPUPlace()]
            if core.is_compiled_with_cuda():
                places.append(fluid.CUDAPlace(0))
            for place in places:
                self.run_scalar_lr(place, lr, start_lr, end_lr)


526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556
def reduce_lr_on_plateau(decay_rate, threshold, cooldown, patience, m, n, loss,
                         var_list):
    """Pure-Python reference for one ReduceLROnPlateau step.

    ``var_list`` is mutated in place and holds
    ``[best, current_lr, cooldown_counter, num_bad_epochs]``.
    ``m`` is the mode ('min'/'max') and ``n`` the threshold mode
    ('rel'/'abs'); any combination other than the first three listed below
    uses the max/abs rule. Returns the (possibly reduced) current lr.
    """
    improvement_rules = {
        ('min', 'rel'): lambda cur, best: cur < best - best * threshold,
        ('min', 'abs'): lambda cur, best: cur < best - threshold,
        ('max', 'rel'): lambda cur, best: cur > best + best * threshold,
    }
    improved = improvement_rules.get(
        (m, n), lambda cur, best: cur > best + threshold)

    # While cooling down, consume one cooldown epoch and keep the lr.
    if var_list[2] > 0:
        var_list[2] -= 1
        return var_list[1]

    if not improved(loss, var_list[0]):
        var_list[3] += 1
        if var_list[3] > patience:
            # Patience exhausted: start a cooldown and reduce the lr unless
            # the reduction would be negligible (<= 1e-8).
            var_list[2] = cooldown
            var_list[3] = 0
            reduced_lr = var_list[1] * decay_rate
            if var_list[1] - reduced_lr > 1e-8:
                var_list[1] = reduced_lr
    else:
        var_list[0] = loss
        var_list[3] = 0

    return var_list[1]


class TestReduceLROnPlateauDecay(unittest.TestCase):
    """paddle.optimizer.ReduceLROnPlateau vs. reduce_lr_on_plateau."""

    def test_ReduceLR(self):
        """Argument validation, then static + dygraph checks per place/mode."""
        # the decay factor must be less than 1.0
        with self.assertRaises(ValueError):
            paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, factor=2.0)
        # the mode must be "min" or "max"
        with self.assertRaises(ValueError):
            paddle.optimizer.ReduceLROnPlateau(learning_rate=1.0, mode="test")
        # the threshold_mode must be "rel" or "abs"
        with self.assertRaises(ValueError):
            paddle.optimizer.ReduceLROnPlateau(
                learning_rate=1.0, threshold_mode="test")
        with self.assertRaises(TypeError):
            paddle.optimizer.ReduceLROnPlateau(learning_rate="test")
        with self.assertRaises(TypeError):
            paddle.optimizer.ReduceLROnPlateau(learning_rate=0.5).step("test")

        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))

        for place in places:
            for m, n in zip(['min', 'max', 'min', 'max'],
                            ['rel', 'rel', 'abs', 'abs']):
                kwargs = {
                    'learning_rate': 1.0,
                    'mode': m,
                    'factor': 0.5,
                    'patience': 3,
                    'threshold': 1e-4,
                    'threshold_mode': n,
                    'cooldown': 1,
                    'min_lr': 0,
                    'epsilon': 1e-8,
                    'verbose': False,
                }
                paddle.enable_static()
                self._test_static(place, kwargs)
                paddle.disable_static(place)
                self._test_dygraph(place, kwargs)
                paddle.enable_static()

    def _test_static(self, place, kwargs):
        """Static graph: loss = sin(x) with x incremented each run.

        After each epoch the scheduler is stepped with the fetched loss and
        its lr must equal the Python reference. A cloned test program is then
        run for 10 more epochs with the same check.
        """
        paddle.enable_static()

        best = float("-10000") if kwargs['mode'] == "max" else float("10000")
        current_lr = 1.0
        cooldown_counter = 0
        num_bad_epochs = 0
        # Mutable state for reduce_lr_on_plateau:
        # [best, current_lr, cooldown_counter, num_bad_epochs].
        var_list = [best, current_lr, cooldown_counter, num_bad_epochs]

        main_prog = fluid.Program()
        start_prog = fluid.Program()
        with fluid.program_guard(main_prog, start_prog):
            x = fluid.layers.create_global_var(
                [1], 1, 'float32', persistable=True)
            paddle.increment(x)
            loss = paddle.sin(x)
            scheduler = paddle.optimizer.ReduceLROnPlateau(**kwargs)
            adam = fluid.optimizer.Adam(learning_rate=scheduler)
            adam.minimize(loss)
            lr_var = adam._global_learning_rate()
            test_prog = main_prog.clone()

        exe = fluid.Executor(place)
        exe.run(start_prog)

        for epoch in range(20):
            for batch_id in range(1):
                # The fetched lr is not compared; the check below reads the
                # scheduler directly after step().
                out, _ = exe.run(main_prog,
                                 fetch_list=[loss.name, lr_var.name])
                expected_lr = reduce_lr_on_plateau(
                    kwargs['factor'], kwargs['threshold'], kwargs['cooldown'],
                    kwargs['patience'], kwargs['mode'],
                    kwargs['threshold_mode'], out[0], var_list)

            scheduler.step(out[0])
            actual_lr = scheduler()
            self.assertEqual(actual_lr, np.array(expected_lr))

        for epoch in range(10):
            for batch_id in range(1):
                out, _ = exe.run(test_prog,
                                 fetch_list=[loss.name, lr_var.name])
                expected_lr = reduce_lr_on_plateau(
                    kwargs['factor'], kwargs['threshold'], kwargs['cooldown'],
                    kwargs['patience'], kwargs['mode'],
                    kwargs['threshold_mode'], out[0], var_list)
            scheduler.step(out[0])
            actual_lr = scheduler()
            self.assertEqual(actual_lr, np.array(expected_lr))

    def _test_dygraph(self, place, kwargs):
        """Dygraph: loss = sin(epoch); then verify a state_dict round trip."""
        paddle.disable_static(place)

        best = float("-10000") if kwargs['mode'] == "max" else float("10000")
        current_lr = 1.0
        cooldown_counter = 0
        num_bad_epochs = 0
        # Mutable state for reduce_lr_on_plateau:
        # [best, current_lr, cooldown_counter, num_bad_epochs].
        var_list = [best, current_lr, cooldown_counter, num_bad_epochs]

        linear = paddle.nn.Linear(10, 10)
        scheduler = paddle.optimizer.ReduceLROnPlateau(**kwargs)
        sgd = paddle.optimizer.SGD(learning_rate=scheduler,
                                   parameter_list=linear.parameters())

        for epoch in range(20):
            for batch_id in range(1):
                x = paddle.to_tensor(epoch).astype('float32')
                loss = paddle.sin(x)
                loss.backward()
                sgd.minimize(loss)

            scheduler.step(loss)
            # get lr from paddle
            current_lr = scheduler()
            # get lr from python
            expected_lr = reduce_lr_on_plateau(
                kwargs['factor'], kwargs['threshold'], kwargs['cooldown'],
                kwargs['patience'], kwargs['mode'], kwargs['threshold_mode'],
                loss, var_list)
            self.assertEqual(current_lr, expected_lr)

        # A fresh scheduler restored via the optimizer state dict must carry
        # the same internal counters.
        state_dict = sgd.state_dict()
        scheduler1 = paddle.optimizer.ReduceLROnPlateau(**kwargs)
        sgd1 = paddle.optimizer.SGD(learning_rate=scheduler1,
                                    parameter_list=linear.parameters())
        sgd1.set_dict(state_dict)
        self.assertEqual(scheduler.cooldown_counter,
                         scheduler1.cooldown_counter)
        self.assertEqual(scheduler.best.numpy()[0], scheduler1.best)
        self.assertEqual(scheduler.num_bad_epochs, scheduler1.num_bad_epochs)
        self.assertEqual(scheduler.last_epoch, scheduler1.last_epoch)
        self.assertEqual(scheduler.last_lr, scheduler1.last_lr)


def noam_lr(epoch_num, d_model, warmup_steps, learning_rate=1.0, verbose=False):
    """Pure-Python reference for the Noam schedule at ``epoch_num``.

    ``verbose`` is accepted for signature parity and ignored. At epoch 0 the
    inverse-sqrt term is taken as 1, which avoids ``math.pow(0, -0.5)``.
    """
    inv_sqrt_term = 1 if epoch_num == 0 else math.pow(epoch_num, -0.5)
    warmup_term = math.pow(warmup_steps, -1.5) * epoch_num
    scale = learning_rate * math.pow(d_model, -0.5)
    return scale * min(inv_sqrt_term, warmup_term)


def lambda_lr(epoch_num, learning_rate, lr_lambda, verbose=False):
    """Reference LambdaLR value: the base rate scaled by ``lr_lambda(epoch_num)``.

    ``verbose`` is ignored.
    """
    factor = lr_lambda(epoch_num)
    return learning_rate * factor


def piecewise_lr(epoch_num, boundaries, values, verbose=False):
    """Reference PiecewiseLR value at step ``epoch_num``.

    Returns ``values[i]`` for the first ``i`` with ``epoch_num < boundaries[i]``.
    Once every boundary has been passed, it returns the last entry of
    ``values``, which must have exactly one more entry than ``boundaries``.
    ``verbose`` is ignored.
    """
    assert len(boundaries) + 1 == len(values)
    for boundary, value in zip(boundaries, values):
        if epoch_num < boundary:
            return value
    return values[-1]


def exponential_lr(epoch_num, learning_rate, gamma, verbose=False):
    """Reference ExponentialLR value: ``learning_rate * gamma ** epoch_num``.

    ``verbose`` is ignored.
    """
    decay = gamma**epoch_num
    return learning_rate * decay


def natural_exp_lr(epoch_num, learning_rate, gamma, verbose=False):
    """Reference NaturalExpLR value: ``learning_rate * exp(-gamma * epoch_num)``.

    ``verbose`` is ignored.
    """
    exponent = -1 * gamma * epoch_num
    return learning_rate * math.exp(exponent)


def inverse_time_lr(epoch_num, learning_rate, gamma, verbose=False):
    """Reference InverseTimeLR value: ``learning_rate / (1 + gamma * epoch_num)``.

    ``verbose`` is ignored.
    """
    denominator = 1 + gamma * epoch_num
    return learning_rate / denominator


def polynomial_lr(epoch_num,
                  learning_rate,
                  decay_steps,
                  end_lr=0.0001,
                  power=1.0,
                  cycle=False,
                  verbose=False):
    """Reference PolynomialLR value at step ``epoch_num``.

    Computes ``(learning_rate - end_lr) * (1 - epoch_num / decay_steps)**power
    + end_lr``.

    Without ``cycle``, ``epoch_num`` is clamped to ``decay_steps``, so the rate
    stays at ``end_lr`` afterwards. With ``cycle``, the decay window is widened
    to the smallest multiple of ``decay_steps`` that is >= ``epoch_num``. At
    step 0 one window is used, which avoids a zero-length window.

    ``verbose`` is ignored.
    """
    if cycle:
        div = math.ceil(epoch_num / float(decay_steps))
        if epoch_num == 0:
            # ceil(0) == 0 would make the window zero-length.
            div = 1
        decay_steps = decay_steps * div
    else:
        epoch_num = min(epoch_num, decay_steps)
    return (learning_rate - end_lr) * (
        (1 - float(epoch_num) / float(decay_steps))**power) + end_lr


# Last value produced by cosine_annealing_lr. The reference is recursive: each
# value is derived from the previous call's result, so calls are assumed to
# start at epoch 0 and advance one epoch at a time.
cosine_annealing_lr_current = None


def cosine_annealing_lr(epoch_num,
                        learning_rate,
                        T_max,
                        eta_min=0,
                        verbose=False):
    """Reference CosineAnnealingLR value at ``epoch_num`` (recursive form).

    Epoch 0 resets the running value to ``learning_rate``. When
    ``(epoch_num - 1 - T_max) % (2 * T_max) == 0``, the running value is
    increased by ``(learning_rate - eta_min) * (1 - cos(pi / T_max)) / 2``.
    Otherwise it is rescaled by the ratio of consecutive cosine factors around
    ``eta_min``.

    The result is also stored in the module-level
    ``cosine_annealing_lr_current``. ``verbose`` is ignored.
    """
    global cosine_annealing_lr_current
    if epoch_num == 0:
        cosine_annealing_lr_current = learning_rate
        return cosine_annealing_lr_current

    prev = cosine_annealing_lr_current
    if (epoch_num - 1 - T_max) % (2 * T_max) == 0:
        bump = (learning_rate - eta_min) * (
            1 - math.cos(math.pi / float(T_max))) / 2
        new_lr = prev + bump
    else:
        numerator = 1 + math.cos(math.pi * epoch_num / float(T_max))
        denominator = 1 + math.cos(math.pi * (epoch_num - 1) / float(T_max))
        ratio = numerator / denominator
        new_lr = ratio * (prev - eta_min) + eta_min
    cosine_annealing_lr_current = new_lr
    return new_lr


def linear_warmup_lr(epoch_num,
                     learning_rate,
                     warmup_steps,
                     start_lr,
                     end_lr,
                     verbose=False):
    """Reference LinearLrWarmup value at step ``epoch_num``.

    During the first ``warmup_steps`` steps, the rate is interpolated linearly
    from ``start_lr`` towards ``end_lr``. After that, ``learning_rate`` is
    returned unchanged. ``verbose`` is ignored.
    """
    in_warmup = epoch_num < warmup_steps
    if not in_warmup:
        return learning_rate
    progress = float(epoch_num) / float(warmup_steps)
    return start_lr + (end_lr - start_lr) * progress


def multi_step_lr(epoch_num,
                  learning_rate,
                  milestones,
                  gamma=0.1,
                  verbose=False):
    """Reference MultiStepLR value at step ``epoch_num``.

    The base rate is multiplied by ``gamma`` once for every milestone that
    ``epoch_num`` has reached. ``milestones`` is scanned in order, so it is
    expected to be sorted. ``verbose`` is ignored.
    """
    for passed, milestone in enumerate(milestones):
        if epoch_num < milestone:
            return learning_rate * (gamma**passed)
    return learning_rate * (gamma**len(milestones))


def step_lr(epoch_num, learning_rate, step_size, gamma=0.1, verbose=False):
    """Reference StepLR value: decay by ``gamma`` every ``step_size`` steps.

    ``verbose`` is ignored.
    """
    num_decays = epoch_num // step_size
    return learning_rate * math.pow(gamma, num_decays)


class TestLRScheduler(unittest.TestCase):
    """Compare paddle's LR schedulers with the pure-Python reference functions.

    Each scheduler is exercised in two modes:

    * Static-graph mode: the training program, its clone, and (on CPU)
      data-parallel compiled versions of both.
    * Dygraph mode.

    After each epoch, the learning rate reported by paddle must equal the
    reference function evaluated at the same step index.
    """

    # Number of CPU places used by the data-parallel (CompiledProgram) checks.
    _NUM_CPU_PLACES = 4

    @staticmethod
    def _random_feed(batch_size):
        """Return a random float32 feed for inputs ``x`` and ``y`` of shape [batch_size, 4, 5]."""
        return {
            'x': np.random.randn(batch_size, 4, 5).astype('float32'),
            'y': np.random.randn(batch_size, 4, 5).astype('float32')
        }

    def _check_lr_in_scopes(self, compiled_prog, lr_name, expected):
        """Assert each per-place local scope of ``compiled_prog`` holds ``expected``."""
        scopes = compiled_prog._executor.local_scopes()
        # Fail loudly if there are fewer scopes than places, rather than
        # silently checking fewer of them.
        self.assertGreaterEqual(len(scopes), self._NUM_CPU_PLACES)
        for scope in scopes[:self._NUM_CPU_PLACES]:
            out = np.array(scope.var(lr_name).get_tensor())
            self.assertEqual(out, np.array(expected))

    def _run_compiled_epochs(self, exe, compiled_prog, scheduler, lr_name,
                             python_func, kwarg, num):
        """Run 5 epochs (2 batches each) of a data-parallel compiled program.

        After each epoch, the LR in every local scope is compared with
        ``python_func(num, **kwarg)`` and the scheduler is stepped.

        Returns:
            The step index that follows the last checked epoch.
        """
        # Scale the fed batch with the number of places (3 samples per place).
        batch_size = 3 * self._NUM_CPU_PLACES
        for epoch in range(5):
            python_result = python_func(num, **kwarg)
            for batch_id in range(2):
                exe.run(compiled_prog,
                        feed=self._random_feed(batch_size),
                        fetch_list=lr_name)
            self._check_lr_in_scopes(compiled_prog, lr_name, python_result)
            scheduler.step()
            num += 1
        return num

    def _test_static(self, python_func, paddle_api, kwarg, place):
        """Check ``paddle_api`` against ``python_func`` in static-graph mode on ``place``.

        The scheduler is stepped once per epoch. The step index ``num`` keeps
        increasing across the training program, its clone and, on CPU only,
        the data-parallel compiled programs.
        """
        main_prog = fluid.Program()
        start_prog = fluid.Program()
        with fluid.program_guard(main_prog, start_prog):
            x = fluid.data(name='x', shape=[3, 4, 5])
            y = fluid.data(name='y', shape=[3, 4, 5])
            z = fluid.layers.fc(x, 100)
            loss = fluid.layers.mean(z)
            scheduler = paddle_api(**kwarg)
            adam = fluid.optimizer.Adam(learning_rate=scheduler)
            adam.minimize(loss)
            lr_var = adam._global_learning_rate()
            test_prog = main_prog.clone()

        num = 0
        exe = fluid.Executor(place)
        exe.run(start_prog)
        # Run the training program first, then its clone; the scheduler keeps
        # advancing across both.
        for prog in (main_prog, test_prog):
            for epoch in range(5):
                for batch_id in range(2):
                    out = exe.run(prog,
                                  feed=self._random_feed(3),
                                  fetch_list=lr_var.name)
                self.assertEqual(out, np.array(python_func(num, **kwarg)))
                scheduler.step()
                num += 1

        if isinstance(place, fluid.CPUPlace):
            compiled_train_prog = fluid.CompiledProgram(
                main_prog).with_data_parallel(
                    loss_name=loss.name,
                    places=fluid.cpu_places(self._NUM_CPU_PLACES))
            num = self._run_compiled_epochs(exe, compiled_train_prog,
                                            scheduler, lr_var.name,
                                            python_func, kwarg, num)

            # Created after compiled_train_prog has run; it shares that
            # program's variables via share_vars_from.
            compiled_test_prog = fluid.CompiledProgram(
                test_prog).with_data_parallel(
                    loss_name=loss.name,
                    share_vars_from=compiled_train_prog,
                    places=fluid.cpu_places(self._NUM_CPU_PLACES))
            self._run_compiled_epochs(exe, compiled_test_prog, scheduler,
                                      lr_var.name, python_func, kwarg, num)

    def _test_dygraph(self, python_func, paddle_api, kwarg, place):
        """Check ``paddle_api`` against ``python_func`` in dygraph mode.

        ``place`` is not used here. The caller selects the device with
        ``paddle.disable_static(place)`` before calling this method.
        """
        x = paddle.to_tensor(
            np.random.uniform(-1, 1, [10, 10]).astype("float32"))
        linear = paddle.nn.Linear(10, 10)
        scheduler = paddle_api(**kwarg)
        sgd = paddle.optimizer.SGD(learning_rate=scheduler,
                                   parameter_list=linear.parameters())
        for epoch in range(20):
            for batch_id in range(2):
                out = linear(x)
                loss = paddle.reduce_mean(out)
                # Backpropagate from the scalar loss that is minimized below,
                # not from the raw layer output.
                loss.backward()
                sgd.minimize(loss)
                linear.clear_gradients()

            self.assertAlmostEqual(sgd.current_step_lr(),
                                   python_func(epoch, **kwarg))
            if paddle_api.__name__ != "CosineAnnealingLR":
                scheduler.step()
            else:
                # CosineAnnealingLR is stepped with an explicit epoch index.
                scheduler.step(epoch + 1)

    def test_scheduler(self):
        """Check constructor validation, then compare every scheduler with its reference."""
        with self.assertRaises(NotImplementedError):
            paddle.optimizer.lr_scheduler._LRScheduler().step()
        with self.assertRaises(TypeError):
            paddle.optimizer.MultiStepLR(
                learning_rate="test", milestones=[1, 2, 3])
        with self.assertRaises(TypeError):
            paddle.optimizer.MultiStepLR(learning_rate=0.5, milestones='test')
        with self.assertRaises(ValueError):
            paddle.optimizer.MultiStepLR(
                learning_rate=0.5, milestones=[3, 2, 1])
        with self.assertRaises(ValueError):
            paddle.optimizer.MultiStepLR(
                learning_rate=0.5, milestones=[1, 2, 3], gamma=2)

        func_api_kwargs = [(noam_lr, paddle.optimizer.NoamLR, {
            "d_model": 0.01,
            "warmup_steps": 100,
            "verbose": False
        }), (piecewise_lr, paddle.optimizer.PiecewiseLR, {
            "boundaries": [3, 6, 9, 15, 20],
            "values": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            "verbose": False
        }), (natural_exp_lr, paddle.optimizer.NaturalExpLR, {
            "learning_rate": 0.5,
            "gamma": 0.1,
            "verbose": False
        }), (inverse_time_lr, paddle.optimizer.InverseTimeLR, {
            "learning_rate": 0.5,
            "gamma": 0.1,
            "verbose": True
        }), (polynomial_lr, paddle.optimizer.PolynomialLR, {
            "learning_rate": 0.5,
            "decay_steps": 20,
            "end_lr": 0,
            "power": 1.0,
            "cycle": False,
            "verbose": False
        }), (polynomial_lr, paddle.optimizer.PolynomialLR, {
            "learning_rate": 0.5,
            "decay_steps": 20,
            "end_lr": 0,
            "power": 1.0,
            "cycle": True,
            "verbose": False
        }), (linear_warmup_lr, paddle.optimizer.LinearLrWarmup, {
            'learning_rate': 0.5,
            'warmup_steps': 20,
            'start_lr': 0,
            'end_lr': 0.5,
            "verbose": False
        }), (exponential_lr, paddle.optimizer.ExponentialLR, {
            "learning_rate": 0.5,
            "gamma": 0.9,
            "verbose": False
        }), (multi_step_lr, paddle.optimizer.MultiStepLR, {
            "learning_rate": 0.5,
            "milestones": [3, 6, 9, 15, 20],
            "gamma": 0.8,
            "verbose": True
        }), (step_lr, paddle.optimizer.StepLR, {
            "learning_rate": 0.5,
            "step_size": 2,
            "gamma": 0.8,
            "verbose": False
        }), (lambda_lr, paddle.optimizer.LambdaLR, {
            "learning_rate": 0.5,
            "lr_lambda": lambda x: 0.95**x,
            "verbose": False
        }), (cosine_annealing_lr, paddle.optimizer.CosineAnnealingLR, {
            "learning_rate": 0.5,
            "T_max": 10,
            "verbose": True
        })]

        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))

        for python_func, paddle_api, kwarg in func_api_kwargs:
            for place in places:
                paddle.enable_static()
                self._test_static(python_func, paddle_api, kwarg, place)
                paddle.disable_static(place)
                try:
                    self._test_dygraph(python_func, paddle_api, kwarg, place)
                finally:
                    # Restore static mode even if the dygraph check fails,
                    # so later tests do not inherit dygraph mode.
                    paddle.enable_static()
1010 1011


Q
Qiao Longfei 已提交
1012 1013
if __name__ == "__main__":
    # Run this module's unittest test cases when executed as a script.
    unittest.main()