提交 b6e38605 编写于 作者: S ShawnXuan

add reduce mean for loss

上级 a741e132
......@@ -51,7 +51,7 @@ def TrainNet():
logits = model_dict[args.model](images)
loss = flow.nn.sparse_softmax_cross_entropy_with_logits(labels, logits, name="softmax_loss")
#loss = flow.math.reduce_mean(loss)
loss = flow.math.reduce_mean(loss)
flow.losses.add_loss(loss)
predictions = flow.nn.softmax(logits)
outputs = {"loss": loss, "predictions":predictions, "labels": labels}
......@@ -84,8 +84,8 @@ def main():
snapshot = Snapshot(args.model_save_dir, args.model_load_dir)
for epoch in range(args.num_epochs):
metric = Metric(desc='train', calculate_batches=args.loss_print_every_n_iter,
summary=summary, save_summary_steps=epoch_size,
metric = Metric(desc='train', calculate_batches=args.loss_print_every_n_iter,
summary=summary, save_summary_steps=epoch_size,
batch_size=train_batch_size, loss_key='loss')
for i in range(epoch_size):
TrainNet().async_get(metric.metric_cb(epoch, i))
......@@ -93,7 +93,7 @@ def main():
# break
#break
if args.val_data_dir:
metric = Metric(desc='validataion', calculate_batches=num_val_steps, summary=summary,
metric = Metric(desc='validataion', calculate_batches=num_val_steps, summary=summary,
save_summary_steps=num_val_steps, batch_size=val_batch_size)
for i in range(num_val_steps):
InferenceNet().async_get(metric.metric_cb(epoch, i))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册