diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index 4534c39b0082af5e952c4cdd92e9e6361b39b077..dccbb21f5d2a2117c7ecef52a9a66cd1556c7d58 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -385,19 +385,23 @@ class Optimizer(object):
         return self._opti_name_list
 
     def _create_global_learning_rate(self):
+        # The lr var can't be float16; for pure fp16 training, the lr dtype needs extra handling.
+        _lr_dtype = paddle.get_default_dtype(
+        ) if self._dtype is None else self._dtype
+        _lr_dtype = paddle.float32 if (
+            paddle.get_default_dtype() != "float16"
+            and _lr_dtype == paddle.float16) else _lr_dtype
         if isinstance(self._learning_rate, LRScheduler):
             lr_var = self._global_learning_rate()
             # only create global lr_var once
             if not isinstance(lr_var, framework.Variable):
                 lr_name = unique_name.generate('learning_rate')
                 self._learning_rate._var_name = lr_name
-                lr_var = self.helper.create_global_variable(
-                    name=lr_name,
-                    shape=[1],
-                    persistable=True,
-                    stop_gradient=True,
-                    dtype=paddle.get_default_dtype()
-                    if self._dtype is None else self._dtype)
+                lr_var = self.helper.create_global_variable(name=lr_name,
+                                                            shape=[1],
+                                                            persistable=True,
+                                                            stop_gradient=True,
+                                                            dtype=_lr_dtype)
                 main_prog = framework.default_main_program()
                 main_prog.lr_sheduler = self._learning_rate
                 main_prog.lr_var = lr_var
@@ -419,8 +423,7 @@ class Optimizer(object):
                     name=unique_name.generate("learning_rate"),
                     shape=[1],
                     value=float(self._learning_rate),
-                    dtype=paddle.get_default_dtype()
-                    if self._dtype is None else self._dtype,
+                    dtype=_lr_dtype,
                     persistable=True)
 
     @framework.dygraph_only
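
For reference, the dtype-selection rule the diff introduces can be sketched as a standalone helper. This is only a minimal sketch mirroring the `_lr_dtype` logic added above; the function name `choose_lr_dtype` is hypothetical and not part of Paddle's API.

```python
# Hypothetical standalone sketch of the lr-dtype rule added in this diff.
import paddle


def choose_lr_dtype(optimizer_dtype=None):
    # Start from the optimizer's dtype, falling back to the global default dtype.
    lr_dtype = paddle.get_default_dtype() if optimizer_dtype is None else optimizer_dtype
    # The lr variable itself must not be float16: when the optimizer is set up
    # for pure fp16 but the global default dtype is not float16, keep lr in float32.
    if paddle.get_default_dtype() != "float16" and lr_dtype == paddle.float16:
        lr_dtype = paddle.float32
    return lr_dtype


# Example: with the default dtype left at float32, an optimizer configured
# with a float16 dtype still gets a float32 learning-rate variable.
assert choose_lr_dtype(paddle.float16) == paddle.float32
```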