Commit ade3bbc0 authored by: Olatunji Ruwase

Support legacy optimizer fusion as config option

Parent 4f7d016d
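The new option lives under the "optimizer" section of the DeepSpeed config. Below is a minimal sketch (not part of this commit) of a config dict that opts out of the legacy fused path; only the "legacy_fusion" key is introduced here (per the constants added below, it defaults to True), and the surrounding keys and values are illustrative assumptions.

    # Hypothetical example config; only "legacy_fusion" is introduced by this commit.
    ds_config = {
        "train_batch_size": 32,        # illustrative value
        "fp16": {"enabled": True},     # illustrative; fusion applies to the fp16 optimizer wrappers
        "optimizer": {
            "type": "Adam",
            "params": {"lr": 1.5e-4},  # illustrative value
            "legacy_fusion": False     # disable the legacy fused Adam/LAMB path (default: True)
        }
    }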
@@ -157,6 +157,14 @@ def get_optimizer_gradient_clipping(param_dict):
         return None


+def get_optimizer_legacy_fusion(param_dict):
+    if OPTIMIZER in param_dict.keys() and \
+            LEGACY_FUSION in param_dict[OPTIMIZER].keys():
+        return param_dict[OPTIMIZER][LEGACY_FUSION]
+    else:
+        return LEGACY_FUSION_DEFAULT
+
+
 def get_scheduler_name(param_dict):
     if SCHEDULER in param_dict.keys() and \
             TYPE in param_dict[SCHEDULER].keys():
@@ -261,6 +269,7 @@ class DeepSpeedConfig(object):
             self.optimizer_name = self.optimizer_name.lower()

         self.optimizer_params = get_optimizer_params(param_dict)
+        self.optimizer_legacy_fusion = get_optimizer_legacy_fusion(param_dict)

         self.scheduler_name = get_scheduler_name(param_dict)
         self.scheduler_params = get_scheduler_params(param_dict)
...
@@ -24,6 +24,8 @@ OPTIMIZER = "optimizer"
 OPTIMIZER_TYPE_DEFAULT = None
 OPTIMIZER_PARAMS = "params"
 TYPE = "type"
+LEGACY_FUSION = "legacy_fusion"
+LEGACY_FUSION_DEFAULT = True
 SCHEDULER = "scheduler"
 SCHEDULER_TYPE_DEFAULT = None
 SCHEDULER_PARAMS = "params"
...
@@ -211,6 +211,9 @@ class DeepSpeedLight(Module):
     def optimizer_params(self):
         return self._config.optimizer_params

+    def optimizer_legacy_fusion(self):
+        return self._config.optimizer_legacy_fusion
+
     def scheduler_name(self):
         return self._config.scheduler_name

@@ -411,21 +414,23 @@ class DeepSpeedLight(Module):
         if self.optimizer_name() == ADAM_OPTIMIZER:
             if self.dynamic_loss_scale():
                 logging.info('Creating fp16 optimizer with dynamic loss scale')
-                optimizer = FP16_Optimizer(optimizer,
-                                           dynamic_loss_scale=True,
-                                           initial_dynamic_scale=initial_dynamic_scale,
-                                           dynamic_loss_args=dynamic_loss_args,
-                                           mpu=self.mpu,
-                                           clip_grad=clip_grad,
-                                           fused_adam_legacy=True)
+                optimizer = FP16_Optimizer(
+                    optimizer,
+                    dynamic_loss_scale=True,
+                    initial_dynamic_scale=initial_dynamic_scale,
+                    dynamic_loss_args=dynamic_loss_args,
+                    mpu=self.mpu,
+                    clip_grad=clip_grad,
+                    fused_adam_legacy=self.optimizer_legacy_fusion())
             else:
                 logging.info('Creating fp16 optimizer with static loss scale: {}'.format(
                     self.loss_scale()))
-                optimizer = FP16_Optimizer(optimizer,
-                                           static_loss_scale=self.loss_scale(),
-                                           mpu=self.mpu,
-                                           clip_grad=clip_grad,
-                                           fused_adam_legacy=True)
+                optimizer = FP16_Optimizer(
+                    optimizer,
+                    static_loss_scale=self.loss_scale(),
+                    mpu=self.mpu,
+                    clip_grad=clip_grad,
+                    fused_adam_legacy=self.optimizer_legacy_fusion())
         else:
             logging.info('Creating fp16 unfused optimizer with dynamic loss scale')
             optimizer = FP16_UnfusedOptimizer(
@@ -434,7 +439,7 @@ class DeepSpeedLight(Module):
                 dynamic_loss_args=dynamic_loss_args,
                 mpu=self.mpu,
                 clip_grad=clip_grad,
-                fused_lamb_legacy=True
+                fused_lamb_legacy=self.optimizer_legacy_fusion()
                 if self.optimizer_name() == LAMB_OPTIMIZER else False)

             return optimizer
...
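For reference, a self-contained sketch of the defaulting behaviour introduced above, mirroring get_optimizer_legacy_fusion and the new constants (names are copied from the diff; the example dicts passed to it are made up):

    OPTIMIZER = "optimizer"
    LEGACY_FUSION = "legacy_fusion"
    LEGACY_FUSION_DEFAULT = True


    def get_optimizer_legacy_fusion(param_dict):
        # Fall back to the default (True) when the optimizer section or the key is absent.
        if OPTIMIZER in param_dict.keys() and \
                LEGACY_FUSION in param_dict[OPTIMIZER].keys():
            return param_dict[OPTIMIZER][LEGACY_FUSION]
        else:
            return LEGACY_FUSION_DEFAULT


    print(get_optimizer_legacy_fusion({}))                                       # True  (no optimizer section)
    print(get_optimizer_legacy_fusion({"optimizer": {"type": "Adam"}}))          # True  (key omitted)
    print(get_optimizer_legacy_fusion({"optimizer": {"legacy_fusion": False}}))  # False (explicit opt-out)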