From 974d1ba7fec0d1c4c9d4455d068a49cae2caedfc Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Sat, 8 Feb 2020 19:47:18 +0800 Subject: [PATCH] fix --- cnn_benchmark/of_cnn_train_val.py | 11 ++--------- run.sh | 6 +++--- test.sh | 3 ++- 3 files changed, 7 insertions(+), 13 deletions(-) diff --git a/cnn_benchmark/of_cnn_train_val.py b/cnn_benchmark/of_cnn_train_val.py index 8fc7167..e64ac20 100755 --- a/cnn_benchmark/of_cnn_train_val.py +++ b/cnn_benchmark/of_cnn_train_val.py @@ -138,7 +138,6 @@ def train_callback(epoch, step): summary.scalar('train_accuracy', accuracy, step) main.correct = 0.0 main.total = 0.0 - #exit() return callback @@ -149,7 +148,8 @@ def do_predictions(epoch, predict_step, predictions): summary.scalar('top1_accuracy', main.correct/main.total, epoch) #summary.scalar('top1_correct', main.correct, epoch) #summary.scalar('total_val_images', main.total, epoch) - print("epoch {}, top 1 accuracy: {:.6f}".format(epoch, main.correct/main.total)) + print("epoch {}, top 1 accuracy: {:.6f}, time: {:.2f}".format(epoch, + main.correct/main.total, timer.split())) def predict_callback(epoch, predict_step): @@ -190,13 +190,6 @@ def main(): InferenceNet(images, labels.astype(np.int32)).async_get(predict_callback(epoch, i)) #acc_acc(i, InferenceNet(images, labels.astype(np.int32)).get()) - assert main.total > 0 - top1_accuracy = main.correct/main.total - summary.scalar('top1_accuracy', top1_accuracy, epoch) - print("epoch {}, top 1 accuracy: {:.6f}, val_time: {:.2f}".format(epoch, top1_accuracy, - time.time()-tic)) - - snapshot.save('epoch_{}'.format(epoch+1)) summary.save() diff --git a/run.sh b/run.sh index 47c96e8..eb3d14c 100755 --- a/run.sh +++ b/run.sh @@ -1,8 +1,8 @@ rm -rf core.* #gdb --args \ -#DATA_ROOT=/mnt/13_nfs/ImageNet -DATA_ROOT=/dataset/imagenet-mxnet -#nvprof -of resnet.nvvp \ +DATA_ROOT=/mnt/13_nfs/xuan/ImageNet/mxnet +#DATA_ROOT=/dataset/imagenet-mxnet +#nvprof -f -o resnet.nvvp \ python3 cnn_benchmark/of_cnn_train_val.py \ --data_train=$DATA_ROOT/train.rec \ --data_train_idx=$DATA_ROOT/train.idx \ diff --git a/test.sh b/test.sh index ddcfbe7..d72d2ef 100755 --- a/test.sh +++ b/test.sh @@ -1,7 +1,8 @@ rm -rf core.* #gdb --args \ #DATA_ROOT=/mnt/13_nfs/xuan/ImageNet -DATA_ROOT=/dataset/imagenet-mxnet +DATA_ROOT=/mnt/13_nfs/xuan/ImageNet/mxnet +#DATA_ROOT=/dataset/imagenet-mxnet python cnn_benchmark/dali.py \ --data_train=$DATA_ROOT/train.rec \ --data_train_idx=$DATA_ROOT/train.idx \ -- GitLab