未验证 提交 8e301b6e 编写于 作者: O Olatunji Ruwase 提交者: GitHub

Support client lr schedulers that are not subclass of torch _LRScheduler (#1337)

上级 c0b27fb0
......@@ -546,14 +546,14 @@ class DeepSpeedEngine(Module):
f'DeepSpeed using configured LR scheduler = {self.scheduler_name()}')
self.lr_scheduler = lr_scheduler
else:
if isinstance(client_lr_scheduler, _LRScheduler):
if self.global_rank == 0:
logger.info('DeepSpeed using client LR scheduler')
self.lr_scheduler = client_lr_scheduler
elif isinstance(client_lr_scheduler, Callable):
if isinstance(client_lr_scheduler, Callable):
if self.global_rank == 0:
logger.info('DeepSpeed using client callable to create LR scheduler')
self.lr_scheduler = client_lr_scheduler(self.basic_optimizer)
else:
if self.global_rank == 0:
logger.info('DeepSpeed using client LR scheduler')
self.lr_scheduler = client_lr_scheduler
log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])
......@@ -673,9 +673,6 @@ class DeepSpeedEngine(Module):
assert self.dynamic_loss_scale(), \
'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name())
assert isinstance(self.client_lr_scheduler, (type(None), _LRScheduler, Callable)), \
f'Client LR Scheduler is of unexpected type {type(self.client_lr_scheduler)}'
# Detect invalid combinations of client optimizer and client scheduler
if isinstance(self.client_lr_scheduler, _LRScheduler):
assert isinstance(self.client_optimizer, Optimizer), \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册