提交 cc28a460 编写于 作者: D dengkaipeng

spread train_loop.

上级 2d4b3459
...@@ -99,38 +99,36 @@ def train(): ...@@ -99,38 +99,36 @@ def train():
fetch_list = [loss] fetch_list = [loss]
def train_loop(): py_reader.start()
py_reader.start() smoothed_loss = SmoothedValue()
smoothed_loss = SmoothedValue() try:
try: start_time = time.time()
start_time = time.time() prev_start_time = start_time
snapshot_loss = 0
snapshot_time = 0
for iter_id in range(cfg.start_iter, cfg.max_iter):
prev_start_time = start_time prev_start_time = start_time
snapshot_loss = 0 start_time = time.time()
snapshot_time = 0 losses = exe.run(compile_program, fetch_list=[v.name for v in fetch_list])
for iter_id in range(cfg.start_iter, cfg.max_iter): smoothed_loss.add_value(np.mean(np.array(losses[0])))
prev_start_time = start_time snapshot_loss += np.mean(np.array(losses[0]))
start_time = time.time() snapshot_time += start_time - prev_start_time
losses = exe.run(compile_program, fetch_list=[v.name for v in fetch_list]) lr = np.array(fluid.global_scope().find_var('learning_rate')
smoothed_loss.add_value(np.mean(np.array(losses[0]))) .get_tensor())
snapshot_loss += np.mean(np.array(losses[0])) print("Iter {:d}, lr {:.6f}, loss {:.6f}, time {:.5f}".format(
snapshot_time += start_time - prev_start_time iter_id, lr[0],
lr = np.array(fluid.global_scope().find_var('learning_rate') smoothed_loss.get_mean_value(), start_time - prev_start_time))
.get_tensor()) sys.stdout.flush()
print("Iter {:d}, lr {:.6f}, loss {:.6f}, time {:.5f}".format( if (iter_id + 1) % cfg.snapshot_iter == 0:
iter_id, lr[0], save_model("model_iter{}".format(iter_id))
smoothed_loss.get_mean_value(), start_time - prev_start_time)) print("Snapshot {} saved, average loss: {}, average time: {}".format(
sys.stdout.flush() iter_id + 1, snapshot_loss / float(cfg.snapshot_iter),
if (iter_id + 1) % cfg.snapshot_iter == 0: snapshot_time / float(cfg.snapshot_iter)))
save_model("model_iter{}".format(iter_id)) snapshot_loss = 0
print("Snapshot {} saved, average loss: {}, average time: {}".format( snapshot_time = 0
iter_id + 1, snapshot_loss / float(cfg.snapshot_iter), except fluid.core.EOFException:
snapshot_time / float(cfg.snapshot_iter))) py_reader.reset()
snapshot_loss = 0
snapshot_time = 0
except fluid.core.EOFException:
py_reader.reset()
train_loop()
save_model('model_final') save_model('model_final')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册