From 81379888c9a0d36b8a0ca3963330b48517ebd4f3 Mon Sep 17 00:00:00 2001 From: mir-of Date: Mon, 8 Jun 2020 14:02:28 +0800 Subject: [PATCH] fix data loader transpose logic --- cnn_benchmark/model_util.py | 2 +- cnn_e2e/of_cnn_train_val.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cnn_benchmark/model_util.py b/cnn_benchmark/model_util.py index 773a61b..8cc4c91 100644 --- a/cnn_benchmark/model_util.py +++ b/cnn_benchmark/model_util.py @@ -17,7 +17,7 @@ def conv2d_layer( weight_initializer=flow.random_uniform_initializer(), bias_initializer=flow.constant_initializer(), ): - weight_shape = (filters, input.static_shape[1], kernel_size, kernel_size) + weight_shape = (filters, input.shape[1], kernel_size, kernel_size) weight = flow.get_variable( name + "-weight", shape=weight_shape, diff --git a/cnn_e2e/of_cnn_train_val.py b/cnn_e2e/of_cnn_train_val.py index 5af722b..b59888d 100755 --- a/cnn_e2e/of_cnn_train_val.py +++ b/cnn_e2e/of_cnn_train_val.py @@ -60,8 +60,9 @@ def TrainNet(): print("Loading synthetic data.") (labels, images) = ofrecord_util.load_synthetic(args) + logits = model_dict[args.model]( - images, need_transpose=not args.use_new_dataloader) + images, need_transpose=False if (args.use_new_dataloader and args.train_data_dir) else True) loss = flow.nn.sparse_softmax_cross_entropy_with_logits( labels, logits, name="softmax_loss") loss = flow.math.reduce_mean(loss) @@ -85,7 +86,7 @@ def InferenceNet(): (labels, images) = ofrecord_util.load_synthetic(args) logits = model_dict[args.model]( - images, need_transpose=not args.use_new_dataloader) + images, need_transpose=False if (args.use_new_dataloader and args.train_data_dir) else True) predictions = flow.nn.softmax(logits) outputs = {"predictions": predictions, "labels": labels} return outputs -- GitLab