diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py
index 069ade544587ed43906debe03d6d13aa6c2e0447..9c642712d2acda2b61b815ef0ece26223cb9a988 100644
--- a/python/paddle/fluid/layers/learning_rate_scheduler.py
+++ b/python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -350,7 +350,7 @@ def cosine_decay(learning_rate, step_each_epoch, epochs):
     following cosine decay strategy.
 
     decayed_lr = learning_rate * 0.5 * (math.cos(epoch * math.pi / epochs) + 1)
-
+    
     Args:
         learning_rate(Variable|float): The initial learning rate.
         step_each_epoch(int): the number of steps in an epoch.
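For reference, the decay curve quoted in this docstring can be reproduced with a few lines of plain Python. This is a standalone sketch of the formula only; the function and argument names below are illustrative and not part of the Paddle API:

    import math

    def cosine_decayed_lr(learning_rate, global_step, step_each_epoch, epochs):
        # Epoch index derived from the global step, then the docstring formula:
        # decayed_lr = learning_rate * 0.5 * (math.cos(epoch * math.pi / epochs) + 1)
        epoch = global_step // step_each_epoch
        return learning_rate * 0.5 * (math.cos(epoch * math.pi / epochs) + 1)

    # Example: starting at 0.1, the rate falls to roughly zero by the last epoch.
    print(cosine_decayed_lr(0.1, global_step=0, step_each_epoch=100, epochs=120))      # 0.1
    print(cosine_decayed_lr(0.1, global_step=11900, step_each_epoch=100, epochs=120))  # ~1.7e-05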
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index cea182db036e6baae572a3e38ca56c0b46da00b0..8fdc7f33ab16df8aca1e9ff0b334aace7b710d62 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -94,13 +94,18 @@ class Optimizer(object):
         if imperative_base.enabled():
             # create learning rate Variable
             if isinstance(self._learning_rate, float):
-                self._learning_rate_map[framework.default_main_program(
-                )] = layers.create_global_var(
-                    name=unique_name.generate("learning_rate"),
-                    shape=[1],
-                    value=float(self._learning_rate),
-                    dtype='float32' if self._dtype is None else self._dtype,
-                    persistable=True)
+                lr = self._global_learning_rate()
+
+                if isinstance(lr, framework.Variable):
+                    return
+                else:
+                    self._learning_rate_map[framework.default_main_program(
+                    )] = layers.create_global_var(
+                        name=unique_name.generate("learning_rate"),
+                        shape=[1],
+                        value=float(self._learning_rate),
+                        dtype='float32' if self._dtype is None else self._dtype,
+                        persistable=True)
             # get learning rate Variable from LearningRateDecay
             elif isinstance(self._learning_rate, LearningRateDecay):
                 self._learning_rate_map[framework.default_main_program(
@@ -114,11 +119,12 @@ class Optimizer(object):
 
             if isinstance(lr, framework.Variable):
                 return
-
-            if not isinstance(self._learning_rate, float):
-                raise TypeError(
-                    "learning rate variable is create outside optimizer,"
-                    "can not create new learning rate variable for new program")
+            else:
+                if not isinstance(self._learning_rate, float):
+                    raise TypeError(
+                        "learning rate variable is create outside optimizer,"
+                        "can not create new learning rate variable for new program"
+                    )
 
             # create learning rate in the current main program
             self._learning_rate_map[framework.default_main_program(
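Taken together, the optimizer.py hunks make _create_global_learning_rate idempotent under imperative mode: if a learning rate Variable has already been registered for the current program it is reused, rather than a new one being created on every call, matching the existing static-graph branch. A minimal standalone sketch of that check-then-create pattern, using hypothetical names and no Paddle dependency:

    class OptimizerSketch(object):
        def __init__(self, learning_rate):
            self._learning_rate = learning_rate
            # maps each program to its learning rate entry, like _learning_rate_map
            self._learning_rate_map = {}

        def _create_global_learning_rate(self, program):
            lr = self._learning_rate_map.get(program)
            if lr is not None:
                # an entry already exists for this program: reuse it and return early
                return
            if not isinstance(self._learning_rate, float):
                raise TypeError(
                    "learning rate variable is created outside optimizer, "
                    "cannot create new learning rate variable for new program")
            # otherwise create one persistable entry holding the float value
            self._learning_rate_map[program] = [float(self._learning_rate)]

    opt = OptimizerSketch(0.01)
    opt._create_global_learning_rate("main_program")
    opt._create_global_learning_rate("main_program")  # second call is a no-op: the entry is reused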