From 0e942df008a2540c729d985bc1abdce8a98c7168 Mon Sep 17 00:00:00 2001
From: Olatunji Ruwase
Date: Wed, 16 Sep 2020 10:13:06 -0700
Subject: [PATCH] Add Linear warmup+decay lr schedule (#414)

Update lr schedule unit tests
---
 deepspeed/runtime/lr_schedules.py |  56 ++++++++++-
 tests/unit/test_lr_schedulers.py  | 162 ++++++++++++++++++++++++++++--
 2 files changed, 210 insertions(+), 8 deletions(-)
 mode change 100644 => 100755 tests/unit/test_lr_schedulers.py

diff --git a/deepspeed/runtime/lr_schedules.py b/deepspeed/runtime/lr_schedules.py
index 76c4582f..5ec106c2 100755
--- a/deepspeed/runtime/lr_schedules.py
+++ b/deepspeed/runtime/lr_schedules.py
@@ -19,7 +19,8 @@ LR_SCHEDULE = 'lr_schedule'
 LR_RANGE_TEST = 'LRRangeTest'
 ONE_CYCLE = 'OneCycle'
 WARMUP_LR = 'WarmupLR'
-VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR]
+WARMUP_DECAY_LR = 'WarmupDecayLR'
+VALID_LR_SCHEDULES = [LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR]
 
 LR_RANGE_TEST_MIN_LR = 'lr_range_test_min_lr'
 LR_RANGE_TEST_STEP_RATE = 'lr_range_test_step_rate'
@@ -47,6 +48,8 @@ WARMUP_MIN_LR = 'warmup_min_lr'
 WARMUP_MAX_LR = 'warmup_max_lr'
 WARMUP_NUM_STEPS = 'warmup_num_steps'
+TOTAL_NUM_STEPS = 'total_num_steps'
+
 
 
 def add_tuning_arguments(parser):
     group = parser.add_argument_group('Convergence Tuning',
@@ -714,3 +717,54 @@ class WarmupLR(object):
                                 FileNotFoundError(param_value)))
             return list(param_value)
         return [param_value] * len(optimizer.param_groups)
+
+
+class WarmupDecayLR(WarmupLR):
+    """Increase the learning rate of each parameter group from min lr to max lr
+        over warmup_num_steps steps, and then decay at linear rate over the remaining training steps.
+
+        Args:
+            optimizer (Optimizer): Wrapped optimizer.
+            total_num_steps (int): total number of training steps
+            warmup_min_lr (float or list): minimum learning rate. Default: 0
+            warmup_max_lr (float or list): maximum learning rate. Default: 0.001
+            warmup_num_steps (int): number of steps to warm up from min_lr to max_lr. Default: 1000
+            last_batch_iteration (int): The index of the last batch. Default: -1.
+        Example:
+            >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+            >>> scheduler = WarmupDecayLR(optimizer, 1000000)
+            >>> data_loader = torch.utils.data.DataLoader(...)
+            >>> for epoch in range(10):
+            >>>     for batch in data_loader:
+            >>>         train_batch(...)
+            >>>         scheduler.step()
+
+    """
+    def __init__(self,
+                 optimizer: Optimizer,
+                 total_num_steps: int,
+                 warmup_min_lr: float = 0.0,
+                 warmup_max_lr: float = 0.001,
+                 warmup_num_steps: int = 1000,
+                 last_batch_iteration: int = -1):
+
+        self.total_num_steps = total_num_steps
+        super(WarmupDecayLR,
+              self).__init__(optimizer,
+                             warmup_min_lr,
+                             warmup_max_lr,
+                             warmup_num_steps,
+                             last_batch_iteration)
+        if self.total_num_steps < self.warmup_num_steps:
+            logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
+                total_num_steps,
+                warmup_num_steps))
+
+    def _get_gamma(self):
+        if self.last_batch_iteration < self.warmup_num_steps:
+            return self.inverse_log_warm_up * math.log(self.last_batch_iteration + 1)
+        return max(
+            0.0,
+            float(self.total_num_steps - self.last_batch_iteration) /
+            float(max(1.0,
+                      self.total_num_steps - self.warmup_num_steps)))
diff --git a/tests/unit/test_lr_schedulers.py b/tests/unit/test_lr_schedulers.py
old mode 100644
new mode 100755
index 0c388627..bf630b1c
--- a/tests/unit/test_lr_schedulers.py
+++ b/tests/unit/test_lr_schedulers.py
@@ -6,17 +6,25 @@ import json
 import os
 from common import distributed_test
 from simple_model import SimpleModel, SimpleOptimizer, random_dataloader, args_from_dict
+from deepspeed.runtime.lr_schedules import LR_RANGE_TEST, ONE_CYCLE, WARMUP_LR, WARMUP_DECAY_LR
+from deepspeed.runtime.lr_schedules import WARMUP_MIN_LR, WARMUP_MAX_LR, WARMUP_NUM_STEPS, TOTAL_NUM_STEPS
+from deepspeed.runtime.lr_schedules import CYCLE_MIN_LR, CYCLE_MAX_LR
 
 
 @pytest.mark.parametrize("scheduler_type,params",
-                         [("WarmupLR",
+                         [(WARMUP_LR,
                            {}),
-                          ("OneCycle",
+                          (WARMUP_DECAY_LR,
                            {
-                              'cycle_min_lr': 0,
-                              'cycle_max_lr': 0
+                              WARMUP_NUM_STEPS: 10,
+                              TOTAL_NUM_STEPS: 20
                            }),
-                          ("LRRangeTest",
+                          (ONE_CYCLE,
+                           {
+                              CYCLE_MIN_LR: 0,
+                              CYCLE_MAX_LR: 0
+                           }),
+                          (LR_RANGE_TEST,
                            {})])
 def test_get_lr_before_train(tmpdir, scheduler_type, params):
     config_dict = {
@@ -42,8 +50,8 @@ def test_get_lr_before_train(tmpdir, scheduler_type, params):
     @distributed_test(world_size=[1])
     def _test_get_lr_before_train(args, model, hidden_dim):
         model, _, _, lr_scheduler = deepspeed.initialize(args=args,
-                                                       model=model,
-                                                       model_parameters=model.parameters())
+                                                         model=model,
+                                                         model_parameters=model.parameters())
         data_loader = random_dataloader(model=model,
                                         total_samples=50,
                                         hidden_dim=hidden_dim,
@@ -57,3 +65,143 @@ def test_get_lr_before_train(tmpdir, scheduler_type, params):
             model.step()
 
     _test_get_lr_before_train(args=args, model=model, hidden_dim=hidden_dim)
+
+
+@pytest.mark.parametrize("warmup_num_steps", [10, 15, 19, 33])
+def test_lr_warmup_schedule(tmpdir, warmup_num_steps):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            },
+        },
+        "scheduler": {
+            "type": WARMUP_LR,
+            "params": {
+                WARMUP_MIN_LR: 0.1,
+                WARMUP_MAX_LR: 0.2,
+                WARMUP_NUM_STEPS: warmup_num_steps
+            }
+        },
+        "gradient_clipping": 1.0
+    }
+
+    total_num_steps = 2 * warmup_num_steps
+
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1])
+    def _test_lr_warmup_schedule(args, model, hidden_dim, schedule_params, num_steps):
+        model, _, _, lr_scheduler = deepspeed.initialize(args=args,
+                                                         model=model,
+                                                         model_parameters=model.parameters())
+
+        data_loader = random_dataloader(model=model,
+                                        total_samples=num_steps * 2,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device,
+                                        dtype=torch.float)
+        step_lrs = []
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+            step_lrs.append(lr_scheduler.get_lr())
+
+        # Verify initial lr
+        assert step_lrs[0] == [schedule_params[WARMUP_MIN_LR]]
+
+        # Verify warmup completion
+        warmup_num_steps = schedule_params[WARMUP_NUM_STEPS]
+        warmup_max_lr = [schedule_params[WARMUP_MAX_LR]]
+        assert step_lrs[warmup_num_steps] == warmup_max_lr
+
+        # Verify post-warmup completion
+        assert all([warmup_max_lr == lr for lr in step_lrs[warmup_num_steps:]])
+
+    _test_lr_warmup_schedule(args=args,
+                             model=model,
+                             hidden_dim=hidden_dim,
+                             schedule_params=config_dict["scheduler"]["params"],
+                             num_steps=total_num_steps)
+
+
+@pytest.mark.parametrize("warmup_num_steps", [10, 15, 19, 33])
+def test_lr_warmup_decay_schedule(tmpdir, warmup_num_steps):
+    config_dict = {
+        "train_batch_size": 2,
+        "steps_per_print": 1,
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            },
+        },
+        "scheduler": {
+            "type": WARMUP_DECAY_LR,
+            "params": {
+                WARMUP_MIN_LR: 0.1,
+                WARMUP_MAX_LR: 0.2,
+                WARMUP_NUM_STEPS: warmup_num_steps,
+                TOTAL_NUM_STEPS: warmup_num_steps * 2
+            }
+        },
+        "gradient_clipping": 1.0
+    }
+
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+
+    model = SimpleModel(hidden_dim, empty_grad=False)
+
+    @distributed_test(world_size=[1])
+    def _test_lr_warmup_decay_schedule(args,
+                                       model,
+                                       hidden_dim,
+                                       schedule_params,
+                                       num_steps):
+        model, _, _, lr_scheduler = deepspeed.initialize(args=args,
+                                                         model=model,
+                                                         model_parameters=model.parameters())
+
+        data_loader = random_dataloader(model=model,
+                                        total_samples=num_steps * 2,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device,
+                                        dtype=torch.float)
+        step_lrs = []
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+            step_lrs.append(lr_scheduler.get_lr())
+
+        # Verify initial lr
+        assert step_lrs[0] == [schedule_params[WARMUP_MIN_LR]]
+
+        # Verify lr at warmup completion
+        warmup_num_steps = schedule_params[WARMUP_NUM_STEPS]
+        warmup_max_lr = [schedule_params[WARMUP_MAX_LR]]
+        assert step_lrs[warmup_num_steps] == warmup_max_lr
+
+        # Verify decay phase
+        previous_lr = warmup_max_lr
+        for lr in step_lrs[warmup_num_steps + 1:]:
+            assert lr < previous_lr
+            previous_lr = lr
+
+    schedule_params = config_dict["scheduler"]["params"]
+
+    total_num_steps = schedule_params[TOTAL_NUM_STEPS]
+
+    _test_lr_warmup_decay_schedule(args=args,
+                                   model=model,
+                                   hidden_dim=hidden_dim,
+                                   schedule_params=schedule_params,
+                                   num_steps=total_num_steps)
-- 
GitLab
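
Note: the new schedule is selected purely through the DeepSpeed config, using the same keys exercised by the tests above (warmup_min_lr, warmup_max_lr, warmup_num_steps, total_num_steps). The sketch below shows a minimal client-side config; the numeric values are illustrative only, and the model/launcher setup that would consume the file is assumed rather than shown (in the tests it is supplied by args_from_dict and deepspeed.initialize).

# Sketch: write a DeepSpeed config that enables the new WarmupDecayLR schedule.
# Values are illustrative, not taken from the patch.
import json

ds_config = {
    "train_batch_size": 2,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.00015
        }
    },
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
            "warmup_min_lr": 0.0,
            "warmup_max_lr": 0.00015,
            "warmup_num_steps": 1000,
            "total_num_steps": 100000
        }
    }
}

# The unit tests route an equivalent dict through args_from_dict; outside the
# test harness the dict is typically written to a JSON file and passed to
# deepspeed.initialize via the --deepspeed_config argument.
with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=4)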
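
The shape of the schedule can also be sanity-checked outside DeepSpeed. The sketch below re-implements the factor computed by WarmupDecayLR._get_gamma as a free function for illustration; it assumes the inherited WarmupLR convention inverse_log_warm_up = 1 / log(warmup_num_steps) and the interpolation between warmup_min_lr and warmup_max_lr, neither of which is shown in this patch, and it uses the warmup_num_steps=10 / total_num_steps=20 pair from the parametrized test.

# Sketch: standalone re-implementation of the warmup+decay multiplier from
# WarmupDecayLR._get_gamma, for illustration only.
import math


def warmup_decay_gamma(step, warmup_num_steps, total_num_steps):
    # Assumption: WarmupLR defines inverse_log_warm_up = 1 / log(warmup_num_steps).
    inverse_log_warm_up = 1.0 / math.log(warmup_num_steps)
    if step < warmup_num_steps:
        # Log-shaped warmup, matching the inherited WarmupLR behavior.
        return inverse_log_warm_up * math.log(step + 1)
    # Linear decay from 1.0 at the end of warmup to 0.0 at total_num_steps.
    return max(0.0,
               float(total_num_steps - step) /
               float(max(1.0, total_num_steps - warmup_num_steps)))


# With warmup_num_steps=10 and total_num_steps=20 (the WARMUP_NUM_STEPS: 10,
# TOTAL_NUM_STEPS: 20 case in the tests), the multiplier reaches 1.0 at step 10
# and decays linearly to 0.0 at step 20.
print([round(warmup_decay_gamma(step, 10, 20), 2) for step in range(21)])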