提交 be99e10f 编写于 作者: S ShawnXuan

fix

上级 e20f16db
......@@ -57,13 +57,14 @@ def main():
train_data_iter, val_data_iter = get_rec_iter(args, True)
for epoch in range(args.num_epochs):
model_load_dir = os.path.join(args.model_load_dir, 'snapshot_epoch_{}'.format(epoch+1))
model_load_dir = os.path.join(args.model_load_dir, 'snapshot_epoch_{}'.format(epoch))
snapshot = Snapshot(args.model_save_dir, model_load_dir)
metric = Metric(desc='validation', calculate_batches=num_val_steps, summary=summary,
save_summary_steps=num_val_steps, batch_size=val_batch_size)
val_data_iter.reset()
for i, batches in enumerate(val_data_iter):
images, labels = batches
images = images[:,:,:,::-1]
InferenceNet(images, labels).async_get(metric.metric_cb(epoch, i))
summary.save()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册