From be99e10f184b7c2099e29da55335ee3ae5653a8d Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Tue, 24 Mar 2020 10:26:25 +0800 Subject: [PATCH] fix --- cnn_e2e/dali_cnn_val.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cnn_e2e/dali_cnn_val.py b/cnn_e2e/dali_cnn_val.py index 54c49a6..1768310 100755 --- a/cnn_e2e/dali_cnn_val.py +++ b/cnn_e2e/dali_cnn_val.py @@ -57,13 +57,14 @@ def main(): train_data_iter, val_data_iter = get_rec_iter(args, True) for epoch in range(args.num_epochs): - model_load_dir = os.path.join(args.model_load_dir, 'snapshot_epoch_{}'.format(epoch+1)) + model_load_dir = os.path.join(args.model_load_dir, 'snapshot_epoch_{}'.format(epoch)) snapshot = Snapshot(args.model_save_dir, model_load_dir) metric = Metric(desc='validation', calculate_batches=num_val_steps, summary=summary, save_summary_steps=num_val_steps, batch_size=val_batch_size) val_data_iter.reset() for i, batches in enumerate(val_data_iter): images, labels = batches + images = images[:,:,:,::-1] InferenceNet(images, labels).async_get(metric.metric_cb(epoch, i)) summary.save() -- GitLab