未验证 提交 9e39d746 编写于 作者: Y Yuang Liu 提交者: GitHub

Set the lr var's dtype to fp32 when creating an fp16 lr var in the optimizer if the user...

Set the lr var's dtype to fp32 when creating an fp16 lr var in the optimizer if the user did not intend to use global fp16. (#44840)
上级 9a17f05f
...@@ -385,19 +385,23 @@ class Optimizer(object): ...@@ -385,19 +385,23 @@ class Optimizer(object):
return self._opti_name_list return self._opti_name_list
def _create_global_learning_rate(self): def _create_global_learning_rate(self):
# lr var can't be float16, for pure fp16 training, should extra handle the dtype for lr
_lr_dtype = paddle.get_default_dtype(
) if self._dtype is None else self._dtype
_lr_dtype = paddle.float32 if (
paddle.get_default_dtype() != "float16"
and _lr_dtype == paddle.float16) else _lr_dtype
if isinstance(self._learning_rate, LRScheduler): if isinstance(self._learning_rate, LRScheduler):
lr_var = self._global_learning_rate() lr_var = self._global_learning_rate()
# only create global lr_var once # only create global lr_var once
if not isinstance(lr_var, framework.Variable): if not isinstance(lr_var, framework.Variable):
lr_name = unique_name.generate('learning_rate') lr_name = unique_name.generate('learning_rate')
self._learning_rate._var_name = lr_name self._learning_rate._var_name = lr_name
lr_var = self.helper.create_global_variable( lr_var = self.helper.create_global_variable(name=lr_name,
name=lr_name, shape=[1],
shape=[1], persistable=True,
persistable=True, stop_gradient=True,
stop_gradient=True, dtype=_lr_dtype)
dtype=paddle.get_default_dtype()
if self._dtype is None else self._dtype)
main_prog = framework.default_main_program() main_prog = framework.default_main_program()
main_prog.lr_sheduler = self._learning_rate main_prog.lr_sheduler = self._learning_rate
main_prog.lr_var = lr_var main_prog.lr_var = lr_var
...@@ -419,8 +423,7 @@ class Optimizer(object): ...@@ -419,8 +423,7 @@ class Optimizer(object):
name=unique_name.generate("learning_rate"), name=unique_name.generate("learning_rate"),
shape=[1], shape=[1],
value=float(self._learning_rate), value=float(self._learning_rate),
dtype=paddle.get_default_dtype() dtype=_lr_dtype,
if self._dtype is None else self._dtype,
persistable=True) persistable=True)
@framework.dygraph_only @framework.dygraph_only
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册