未验证 提交 d5ccd221 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix video get learning_rate error (#2172)

* fix video get learning_rate error

* refine comment
上级 9952b253
...@@ -9,6 +9,29 @@ import shutil ...@@ -9,6 +9,29 @@ import shutil
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def log_lr_and_step():
try:
# In optimizers, if learning_rate is set as constant, lr_var
# name is 'learning_rate_0', and iteration counter is not
# recorded. If learning_rate is set as decayed values from
# learning_rate_scheduler, lr_var name is 'learning_rate',
# and iteration counter is recorded with name '@LR_DECAY_COUNTER@',
# better impliment is required here
lr_var = fluid.global_scope().find_var("learning_rate")
if not lr_var:
lr_var = fluid.global_scope().find_var("learning_rate_0")
lr = np.array(lr_var.get_tensor())
lr_count = '[-]'
lr_count_var = fluid.global_scope().find_var("@LR_DECAY_COUNTER@")
if lr_count_var:
lr_count = np.array(lr_count_var.get_tensor())
logger.info("------- learning rate {}, learning rate counter {} -----"
.format(np.array(lr), np.array(lr_count)))
except:
logger.warn("Unable to get learning_rate and LR_DECAY_COUNTER.")
def test_without_pyreader(test_exe, def test_without_pyreader(test_exe,
test_reader, test_reader,
test_feeder, test_feeder,
...@@ -61,11 +84,7 @@ def train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feede ...@@ -61,11 +84,7 @@ def train_without_pyreader(exe, train_prog, train_exe, train_reader, train_feede
save_model_name = 'model', test_exe = None, test_reader = None, \ save_model_name = 'model', test_exe = None, test_reader = None, \
test_feeder = None, test_fetch_list = None, test_metrics = None): test_feeder = None, test_fetch_list = None, test_metrics = None):
for epoch in range(epochs): for epoch in range(epochs):
lr = fluid.global_scope().find_var("learning_rate").get_tensor() log_lr_and_step()
lr_count = fluid.global_scope().find_var(
"@LR_DECAY_COUNTER@").get_tensor()
logger.info("------- learning rate {}, learning rate counter {} -----"
.format(np.array(lr), np.array(lr_count)))
epoch_periods = [] epoch_periods = []
for train_iter, data in enumerate(train_reader()): for train_iter, data in enumerate(train_reader()):
cur_time = time.time() cur_time = time.time()
...@@ -101,11 +120,7 @@ def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \ ...@@ -101,11 +120,7 @@ def train_with_pyreader(exe, train_prog, train_exe, train_pyreader, \
if not train_pyreader: if not train_pyreader:
logger.error("[TRAIN] get pyreader failed.") logger.error("[TRAIN] get pyreader failed.")
for epoch in range(epochs): for epoch in range(epochs):
lr = fluid.global_scope().find_var("learning_rate").get_tensor() log_lr_and_step()
lr_count = fluid.global_scope().find_var(
"@LR_DECAY_COUNTER@").get_tensor()
logger.info("------- learning rate {}, learning rate counter {} -----"
.format(np.array(lr), np.array(lr_count)))
train_pyreader.start() train_pyreader.start()
train_metrics.reset() train_metrics.reset()
try: try:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册