提交 ca44df94 编写于 作者: D dengkaipeng

change loss smooth to history loss mean

上级 4d9eeb5d
......@@ -110,7 +110,7 @@ def train():
def train_loop_pyreader():
py_reader.start()
smoothed_loss = SmoothedValue(cfg.log_window)
smoothed_loss = SmoothedValue()
try:
start_time = time.time()
prev_start_time = start_time
......@@ -127,7 +127,7 @@ def train():
.get_tensor())
print("Iter {:d}, lr {:.6f}, loss {:.6f}, time {:.5f}".format(
iter_id, lr[0],
smoothed_loss.get_median_value(), start_time - prev_start_time))
smoothed_loss.get_mean_value(), start_time - prev_start_time))
sys.stdout.flush()
if (iter_id + 1) % cfg.snapshot_iter == 0:
save_model("model_iter{}".format(iter_id))
......@@ -143,7 +143,7 @@ def train():
start_time = time.time()
prev_start_time = start_time
start = start_time
smoothed_loss = SmoothedValue(cfg.log_window)
smoothed_loss = SmoothedValue()
snapshot_loss = 0
snapshot_time = 0
for iter_id, data in enumerate(train_reader()):
......@@ -158,7 +158,7 @@ def train():
lr = np.array(fluid.global_scope().find_var('learning_rate')
.get_tensor())
print("Iter {:d}, lr: {:.6f}, loss: {:.4f}, time {:.5f}".format(
iter_id, lr[0], smoothed_loss.get_median_value(), start_time - prev_start_time))
iter_id, lr[0], smoothed_loss.get_mean_value(), start_time - prev_start_time))
sys.stdout.flush()
if (iter_id + 1) % cfg.snapshot_iter == 0:
......
......@@ -75,14 +75,16 @@ class SmoothedValue(object):
window or the global series average.
"""
def __init__(self, window_size):
self.deque = deque(maxlen=window_size)
def __init__(self):
self.loss_sum = 0.0
self.iter_cnt = 0
def add_value(self, value):
self.deque.append(value)
self.loss_sum += np.mean(value)
self.iter_cnt += 1
def get_median_value(self):
return np.median(self.deque)
def get_mean_value(self):
return self.loss_sum / self.iter_cnt
def parse_args():
......@@ -109,7 +111,7 @@ def parse_args():
add_arg('learning_rate', float, 0.001, "Learning rate.")
add_arg('max_iter', int, 500200, "Iter number.")
add_arg('snapshot_iter', int, 2000, "Save model every snapshot stride.")
add_arg('log_window', int, 20, "Log smooth window, set 1 for debug, set 20 for train.")
# add_arg('log_window', int, 20, "Log smooth window, set 1 for debug, set 20 for train.")
# TRAIN TEST INFER
add_arg('input_size', int, 608, "Image input size of YOLOv3.")
add_arg('random_shape', bool, True, "Resize to random shape for train reader.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册