From 9e39d7463c285121819454ab36d204259a240ff8 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Thu, 4 Aug 2022 08:09:00 +0800
Subject: [PATCH] Set the lr var's dtype to fp32 when creating an fp16 lr var
 in the optimizer if the user does not intend to use global fp16. (#44840)

---
 python/paddle/optimizer/optimizer.py | 21 ++++++++++++---------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index 4534c39b00..dccbb21f5d 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -385,19 +385,23 @@ class Optimizer(object):
         return self._opti_name_list
 
     def _create_global_learning_rate(self):
+        # The lr var can't be float16; for pure fp16 training, the lr dtype needs extra handling.
+        _lr_dtype = paddle.get_default_dtype(
+        ) if self._dtype is None else self._dtype
+        _lr_dtype = paddle.float32 if (
+            paddle.get_default_dtype() != "float16"
+            and _lr_dtype == paddle.float16) else _lr_dtype
         if isinstance(self._learning_rate, LRScheduler):
             lr_var = self._global_learning_rate()
             # only create global lr_var once
             if not isinstance(lr_var, framework.Variable):
                 lr_name = unique_name.generate('learning_rate')
                 self._learning_rate._var_name = lr_name
-                lr_var = self.helper.create_global_variable(
-                    name=lr_name,
-                    shape=[1],
-                    persistable=True,
-                    stop_gradient=True,
-                    dtype=paddle.get_default_dtype()
-                    if self._dtype is None else self._dtype)
+                lr_var = self.helper.create_global_variable(name=lr_name,
+                                                            shape=[1],
+                                                            persistable=True,
+                                                            stop_gradient=True,
+                                                            dtype=_lr_dtype)
                 main_prog = framework.default_main_program()
                 main_prog.lr_sheduler = self._learning_rate
                 main_prog.lr_var = lr_var
@@ -419,8 +423,7 @@ class Optimizer(object):
                 name=unique_name.generate("learning_rate"),
                 shape=[1],
                 value=float(self._learning_rate),
-                dtype=paddle.get_default_dtype()
-                if self._dtype is None else self._dtype,
+                dtype=_lr_dtype,
                 persistable=True)
 
     @framework.dygraph_only
-- 
GitLab
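
The dtype-selection rule the patch introduces can be summarized standalone. The sketch below is not part of the Paddle API; resolve_lr_dtype is a hypothetical helper, and it uses plain strings as a simplification, with default_dtype standing in for paddle.get_default_dtype() and optimizer_dtype for Optimizer._dtype.

    def resolve_lr_dtype(default_dtype, optimizer_dtype=None):
        # Start from the optimizer's dtype, falling back to the global default.
        lr_dtype = default_dtype if optimizer_dtype is None else optimizer_dtype
        # Keep float16 for the lr var only when the user opted into global fp16;
        # otherwise force the learning-rate variable to float32.
        if default_dtype != "float16" and lr_dtype == "float16":
            lr_dtype = "float32"
        return lr_dtype

    # fp16 optimizer dtype but fp32 global default: lr var stays fp32.
    assert resolve_lr_dtype("float32", "float16") == "float32"
    # User explicitly set the global default to fp16: lr var may be fp16.
    assert resolve_lr_dtype("float16", "float16") == "float16"
    # No optimizer dtype given: fall back to the global default.
    assert resolve_lr_dtype("float32", None) == "float32"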