Unverified · Commit 274c375c · Author: Olatunji Ruwase · Committer: GitHub

Support Callable type for client optimizer and lr_scheduler (#1316)

* Callable option for optimizer and scheduler

* Add unit test

* Formatting

* Disable debug prints

* Use base optimizer to construct lr scheduler

* Formatting

* Remove dead import
Parent aa121291
@@ -3,13 +3,16 @@ Copyright 2020 The Microsoft DeepSpeed Team
'''
import sys
import types
from typing import Optional, Union
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from packaging import version as pkg_version
from . import ops
from . import module_inject
from .runtime.engine import DeepSpeedEngine
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.pipe.engine import PipelineEngine
from .inference.engine import InferenceEngine
@@ -56,13 +59,15 @@ sys.modules['deepspeed.pt.loss_scaler'] = deepspeed.runtime.fp16.loss_scaler
def initialize(args=None,
model=None,
optimizer=None,
model_parameters=None,
training_data=None,
lr_scheduler=None,
model: torch.nn.Module = None,
optimizer: Optional[Union[Optimizer,
DeepSpeedOptimizerCallable]] = None,
model_parameters: Optional[torch.nn.Module] = None,
training_data: Optional[torch.utils.data.Dataset] = None,
lr_scheduler: Optional[Union[_LRScheduler,
DeepSpeedSchedulerCallable]] = None,
mpu=None,
dist_init_required=None,
dist_init_required: Optional[bool] = None,
collate_fn=None,
config=None,
config_params=None):
@@ -74,16 +79,16 @@ def initialize(args=None,
model: Required: nn.Module class before applying any wrappers
optimizer: Optional: a user defined optimizer, this is typically used instead of defining
an optimizer in the DeepSpeed json config.
optimizer: Optional: a user defined Optimizer or Callable that returns an Optimizer object.
This overrides any optimizer definition in the DeepSpeed json config.
model_parameters: Optional: An iterable of torch.Tensors or dicts.
Specifies what Tensors should be optimized.
training_data: Optional: Dataset of type torch.utils.data.Dataset
lr_scheduler: Optional: Learning Rate Scheduler Object. It should define a get_lr(),
step(), state_dict(), and load_state_dict() methods
lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.
The scheduler object should define get_lr(), step(), state_dict(), and load_state_dict() methods
mpu: Optional: A model parallelism unit object that implements
get_{model,data}_parallel_{rank,group,world_size}()
......
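For context, here is a minimal sketch of the new Callable path through deepspeed.initialize. The toy model, config values, and factory names are illustrative only, and the call assumes a DeepSpeed-launched (distributed-initialized) environment:

import torch
import deepspeed

model = torch.nn.Linear(10, 10)          # toy model, illustration only
ds_config = {"train_batch_size": 1}      # minimal config, illustration only

def make_optimizer(params):
    # Optimizer factory: model parameters (or param-group dicts) in, Optimizer out.
    return torch.optim.AdamW(params, lr=1e-3)

def make_scheduler(optimizer):
    # Scheduler factory: the basic optimizer built by DeepSpeed in, scheduler out.
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.95 ** epoch)

engine, ds_optimizer, _, ds_lr_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=list(model.parameters()),
    optimizer=make_optimizer,        # Callable instead of an Optimizer instance
    lr_scheduler=make_scheduler,     # Callable instead of an _LRScheduler instance
    config=ds_config)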
@@ -13,9 +13,14 @@ from collections import defaultdict, OrderedDict
from shutil import copyfile
from torch.nn.modules import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.distributed.distributed_c10d import _get_global_rank
from tensorboardX import SummaryWriter
from typing import Callable, Dict, Optional, Union, Iterable
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
@@ -57,6 +62,10 @@ from deepspeed.profiling.flops_profiler.profiler import FlopsProfiler
MEMORY_OPT_ALLREDUCE_SIZE = 500000000
DeepSpeedOptimizerCallable = \
Callable[[Union[Iterable[Parameter], Dict[str, Iterable]]], Optimizer]
DeepSpeedSchedulerCallable = Callable[[Optimizer], _LRScheduler]
try:
import apex
from apex import amp
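The two aliases above only name the expected factory signatures. A small sketch of callables that would satisfy them (function and variable names here are illustrative):

from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import StepLR, _LRScheduler
from deepspeed.runtime.engine import DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable

def build_optimizer(params) -> Optimizer:
    # Satisfies DeepSpeedOptimizerCallable: parameters (or param-group dicts) in, Optimizer out.
    return AdamW(params, lr=1e-3)

def build_scheduler(optimizer: Optimizer) -> _LRScheduler:
    # Satisfies DeepSpeedSchedulerCallable: Optimizer in, _LRScheduler out.
    return StepLR(optimizer, step_size=10)

opt_factory: DeepSpeedOptimizerCallable = build_optimizer
sched_factory: DeepSpeedSchedulerCallable = build_scheduler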
@@ -198,6 +207,7 @@ class DeepSpeedEngine(Module):
# Configure optimizer and scheduler
self.optimizer = None
self.basic_optimizer = None
self.lr_scheduler = None
if model_parameters or optimizer:
self._configure_optimizer(optimizer, model_parameters)
@@ -536,9 +546,15 @@ class DeepSpeedEngine(Module):
f'DeepSpeed using configured LR scheduler = {self.scheduler_name()}')
self.lr_scheduler = lr_scheduler
else:
if self.global_rank == 0:
logger.info('DeepSpeed using client LR scheduler')
self.lr_scheduler = client_lr_scheduler
if isinstance(client_lr_scheduler, _LRScheduler):
if self.global_rank == 0:
logger.info('DeepSpeed using client LR scheduler')
self.lr_scheduler = client_lr_scheduler
elif isinstance(client_lr_scheduler, Callable):
if self.global_rank == 0:
logger.info('DeepSpeed using client callable to create LR scheduler')
self.lr_scheduler = client_lr_scheduler(self.basic_optimizer)
log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])
def _configure_checkpointing(self, dist_init_required):
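The scheduler branch above can be read in isolation; a standalone sketch of the same dispatch, with an illustrative helper name (not part of the PR):

from typing import Callable
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

def resolve_client_scheduler(client_lr_scheduler, basic_optimizer: Optimizer):
    # Instantiated scheduler: used as-is.
    if isinstance(client_lr_scheduler, _LRScheduler):
        return client_lr_scheduler
    # Factory: invoked with the basic optimizer that DeepSpeed constructed.
    elif isinstance(client_lr_scheduler, Callable):
        return client_lr_scheduler(basic_optimizer)
    # None: the scheduler from the json config (if any) is used instead.
    return None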
@@ -644,6 +660,9 @@ class DeepSpeedEngine(Module):
# Validate configuration based on command line arguments
def _do_sanity_check(self):
assert isinstance(self.client_optimizer, (type(None), Optimizer, Callable)), \
f'Client Optimizer is of unexpected type {type(self.client_optimizer)}'
if not self.client_optimizer:
if self.optimizer_name() is not None:
assert self._is_supported_optimizer(self.optimizer_name()), \
@@ -654,6 +673,14 @@ class DeepSpeedEngine(Module):
assert self.dynamic_loss_scale(), \
'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name())
assert isinstance(self.client_lr_scheduler, (type(None), _LRScheduler, Callable)), \
f'Client LR Scheduler is of unexpected type {type(self.client_lr_scheduler)}'
# Detect invalid combinations of client optimizer and client scheduler
if isinstance(self.client_lr_scheduler, _LRScheduler):
assert isinstance(self.client_optimizer, Optimizer), \
f'Client Optimizer (type = {type(self.client_optimizer)}) is not instantiated but Client LR Scheduler is instantiated'
def _broadcast_model(self):
def is_replicated(p):
if hasattr(p, 'ds_status') and p.ds_status is not ZeroParamStatus.AVAILABLE:
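These checks reject one specific mix: an already-instantiated scheduler combined with an optimizer that DeepSpeed has yet to build, since such a scheduler is bound to some other optimizer instance. A hypothetical illustration (mirroring the unit test added below):

import torch

model = torch.nn.Linear(10, 10)                                # toy model
optimizer_factory = lambda params: torch.optim.AdamW(params)   # not yet instantiated

# Scheduler already bound to a *different* optimizer instance.
detached_scheduler = torch.optim.lr_scheduler.LambdaLR(
    torch.optim.Adam(model.parameters()), lambda epoch: epoch // 10)

# deepspeed.initialize(model=model,
#                      model_parameters=list(model.parameters()),
#                      optimizer=optimizer_factory,
#                      lr_scheduler=detached_scheduler,
#                      config={"train_batch_size": 1})
# -> AssertionError raised by the check in _do_sanity_check above.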
@@ -771,18 +798,23 @@ class DeepSpeedEngine(Module):
# Configure optimizer
def _configure_optimizer(self, client_optimizer, model_parameters):
if client_optimizer is not None:
client_optimizer.param_groups[:] = [
pg for pg in client_optimizer.param_groups if len(pg["params"]) != 0
]
if self.global_rank == 0:
logger.info(
"Removing param_group that has no 'params' in the client Optimizer")
if isinstance(client_optimizer, Optimizer):
client_optimizer.param_groups[:] = [
pg for pg in client_optimizer.param_groups if len(pg["params"]) != 0
]
if self.global_rank == 0:
logger.info(
"Removing param_group that has no 'params' in the client Optimizer"
)
basic_optimizer = client_optimizer
if self.global_rank == 0:
logger.info('Using client Optimizer as basic optimizer')
basic_optimizer = client_optimizer
if self.global_rank == 0:
logger.info('Using client Optimizer as basic optimizer')
else:
basic_optimizer = client_optimizer(model_parameters)
if self.global_rank == 0:
logger.info('Using client callable to create basic optimizer')
else:
basic_optimizer = self._configure_basic_optimizer(model_parameters)
if self.global_rank == 0:
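Analogously to the scheduler, the optimizer resolution reduces to a three-way dispatch; a standalone sketch with an illustrative helper name (the real method also prunes empty param groups and later applies ZeRO/FP16 wrapping):

from torch.optim import Optimizer

def resolve_client_optimizer(client_optimizer, model_parameters, config_fallback):
    if isinstance(client_optimizer, Optimizer):
        # Already-instantiated optimizer: used directly as the basic optimizer.
        return client_optimizer
    elif client_optimizer is not None:
        # Factory: invoked with the model parameters to build the basic optimizer.
        return client_optimizer(model_parameters)
    # Neither: fall back to the optimizer named in the DeepSpeed json config.
    return config_fallback(model_parameters)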
@@ -792,6 +824,7 @@ class DeepSpeedEngine(Module):
self._check_for_duplicates(basic_optimizer)
self.basic_optimizer = basic_optimizer
if self.global_rank == 0:
logger.info('DeepSpeed Basic Optimizer = {}'.format(
basic_optimizer.__class__.__name__))
@@ -832,6 +865,8 @@ class DeepSpeedEngine(Module):
def _configure_basic_optimizer(self, model_parameters):
optimizer_parameters = self.optimizer_params()
if optimizer_parameters is None:
optimizer_parameters = {}
# print(optimizer_parameters.keys())
if 'max_grad_norm' in optimizer_parameters.keys():
raise ValueError(
......
@@ -261,7 +261,7 @@ class FP16_DeepSpeedZeroOptimizer(object):
see_memory_usage(f"Before moving param group {i} to CPU")
# move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.fp16_groups[i])
see_memory_usage(f"After moving param group {i} to CPU", force=True)
see_memory_usage(f"After moving param group {i} to CPU", force=False)
# Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
# This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
@@ -286,12 +286,12 @@ class FP16_DeepSpeedZeroOptimizer(object):
dist.get_world_size(group=self.real_dp_process_group[i])).cuda(
torch.cuda.current_device()))
see_memory_usage(f"After flattening and moving param group {i} to GPU",
force=True)
force=False)
if dist.get_rank(group=self.real_dp_process_group[i]) == 0:
see_memory_usage(
f"After Flattening and after emptying param group {i} cache",
force=True)
force=False)
# set model fp16 weight to slices of flattened buffer
self._update_model_fp16_weights(i)
......
import pytest
from typing import Callable
import torch
from torch.optim import Optimizer, Adam, AdamW
from torch.optim.lr_scheduler import _LRScheduler, LambdaLR
from simple_model import args_from_dict, SimpleModel
from common import distributed_test
import deepspeed
from deepspeed.ops.adam import FusedAdam
from deepspeed.runtime.lr_schedules import WARMUP_LR, WarmupLR
from deepspeed.runtime.config import ADAM_OPTIMIZER
@pytest.mark.parametrize('optimizer_type', [None, Optimizer, Callable])
def test_client_optimizer(tmpdir, optimizer_type):
def _optimizer_callable(params) -> Optimizer:
return AdamW(params=params)
hidden_dim = 10
model = SimpleModel(hidden_dim)
config_dict = {'train_batch_size': 1}
if optimizer_type is None:
client_optimizer = None
config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
elif optimizer_type is Optimizer:
client_optimizer = Adam(model.parameters())
else:
client_optimizer = _optimizer_callable
args = args_from_dict(tmpdir, config_dict)
@distributed_test(world_size=[1])
def _test_client_optimizer(args, model, client_optimizer):
_, ds_optimizer, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=list(model.parameters()),
optimizer=client_optimizer)
if client_optimizer is None:
assert isinstance(ds_optimizer, FusedAdam)
elif isinstance(client_optimizer, Optimizer):
assert ds_optimizer == client_optimizer
else:
assert isinstance(ds_optimizer, AdamW)
_test_client_optimizer(args=args, model=model, client_optimizer=client_optimizer)
@pytest.mark.parametrize('scheduler_type, optimizer_type',
                         [(None, None),
                          (None, Optimizer),
                          (None, Callable),
                          (_LRScheduler, None),
                          (_LRScheduler, Optimizer),
                          (_LRScheduler, Callable),
                          (Callable, None),
                          (Callable, Optimizer),
                          (Callable, Callable)])
def test_client_lr_scheduler(tmpdir, scheduler_type, optimizer_type):
def _my_lambda(epoch):
return epoch // 10
def _optimizer_callable(params) -> Optimizer:
return torch.optim.AdamW(params=params)
def _lr_scheduler_callable(optimizer) -> _LRScheduler:
return LambdaLR(optimizer, _my_lambda)
hidden_dim = 10
model = SimpleModel(hidden_dim)
config_dict = {'train_batch_size': 1}
client_optimizer = None
client_scheduler = None
if optimizer_type is None:
config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
elif optimizer_type is Optimizer:
client_optimizer = torch.optim.Adam(model.parameters())
else:
client_optimizer = _optimizer_callable
if scheduler_type is None:
config_dict['scheduler'] = {'type': WARMUP_LR, 'params': {}}
elif scheduler_type == _LRScheduler:
if isinstance(client_optimizer, Optimizer):
client_scheduler = LambdaLR(client_optimizer, _my_lambda)
else:
# Verify invalid combination is correctly handled
client_scheduler = LambdaLR(torch.optim.Adam(model.parameters()), _my_lambda)
else:
client_scheduler = _lr_scheduler_callable
args = args_from_dict(tmpdir, config_dict)
@distributed_test(world_size=[1])
def _test_client_lr_scheduler(args, model, optimizer, lr_scheduler):
if isinstance(lr_scheduler, _LRScheduler) and not isinstance(optimizer, Optimizer):
with pytest.raises(AssertionError):
_, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=list(model.parameters()),
optimizer=optimizer,
lr_scheduler=lr_scheduler)
else:
_, _, _, ds_lr_scheduler = deepspeed.initialize(args=args,
model=model,
model_parameters=list(model.parameters()),
optimizer=optimizer,
lr_scheduler=lr_scheduler)
if lr_scheduler is None:
assert isinstance(ds_lr_scheduler, WarmupLR)
elif isinstance(lr_scheduler, _LRScheduler):
assert ds_lr_scheduler == lr_scheduler
else:
assert isinstance(ds_lr_scheduler, LambdaLR)
_test_client_lr_scheduler(args=args,
model=model,
optimizer=client_optimizer,
lr_scheduler=client_scheduler)
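Assuming the usual pytest-based unit-test setup, both parametrized tests can be selected by name; a sketch (the test-file location is not shown in this diff):

import pytest

# Hypothetical invocation from the repo root; adjust the path/selection as needed.
pytest.main(["-k", "test_client_optimizer or test_client_lr_scheduler"])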