未验证 提交 9e39d746 编写于 作者: Y Yuang Liu 提交者: GitHub

Set the lr var's dtype to fp32 when creating an fp16 lr var in the optimizer if the user...

Set the lr var's dtype to fp32 when creating an fp16 lr var in the optimizer if the user does not intend to use global fp16. (#44840)
上级 9a17f05f
......@@ -385,19 +385,23 @@ class Optimizer(object):
return self._opti_name_list
def _create_global_learning_rate(self):
# lr var can't be float16, for pure fp16 training, should extra handle the dtype for lr
_lr_dtype = paddle.get_default_dtype(
) if self._dtype is None else self._dtype
_lr_dtype = paddle.float32 if (
paddle.get_default_dtype() != "float16"
and _lr_dtype == paddle.float16) else _lr_dtype
if isinstance(self._learning_rate, LRScheduler):
lr_var = self._global_learning_rate()
# only create global lr_var once
if not isinstance(lr_var, framework.Variable):
lr_name = unique_name.generate('learning_rate')
self._learning_rate._var_name = lr_name
lr_var = self.helper.create_global_variable(
name=lr_name,
shape=[1],
persistable=True,
stop_gradient=True,
dtype=paddle.get_default_dtype()
if self._dtype is None else self._dtype)
lr_var = self.helper.create_global_variable(name=lr_name,
shape=[1],
persistable=True,
stop_gradient=True,
dtype=_lr_dtype)
main_prog = framework.default_main_program()
main_prog.lr_sheduler = self._learning_rate
main_prog.lr_var = lr_var
......@@ -419,8 +423,7 @@ class Optimizer(object):
name=unique_name.generate("learning_rate"),
shape=[1],
value=float(self._learning_rate),
dtype=paddle.get_default_dtype()
if self._dtype is None else self._dtype,
dtype=_lr_dtype,
persistable=True)
@framework.dygraph_only
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册