From c0e98bd4510be50e30b4e9d2b93944566bf6e4e2 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Thu, 2 Apr 2020 20:46:28 +0800 Subject: [PATCH] modify --- cnn_e2e/ofrecord_util.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/cnn_e2e/ofrecord_util.py b/cnn_e2e/ofrecord_util.py index 90fbdd1..2587562 100644 --- a/cnn_e2e/ofrecord_util.py +++ b/cnn_e2e/ofrecord_util.py @@ -34,8 +34,8 @@ def load_imagenet(args, batch_size, data_dir, data_part_num, codec): batch_size=batch_size, data_part_num=data_part_num, part_name_suffix_length=5, - shuffle = True, - buffer_size=32768, + #shuffle = True, + #buffer_size=32768, name="decode", ) @@ -45,7 +45,7 @@ def load_imagenet_for_training(args): train_batch_size = total_device_num * args.batch_size_per_device codec=flow.data.ImageCodec([ #flow.data.ImagePreprocessor('bgr2rgb'), - flow.data.ImageCropWithRandomSizePreprocessor(area=(0.08, 1)), + #flow.data.ImageCropWithRandomSizePreprocessor(area=(0.08, 1)), flow.data.ImageResizePreprocessor(args.image_size, args.image_size), flow.data.ImagePreprocessor('mirror'), ]) @@ -101,13 +101,10 @@ def load_imagenet_for_training2(args): rsz = flow.image.Resize(image, resize_x=float(args.image_size), resize_y=float(args.image_size), color_space=color_space) - print(rsz.shape) - print(label.shape) rng = flow.image.CoinFlip(batch_size=train_batch_size, seed=seed) normal = flow.image.CropMirrorNormalize(rsz, mirror_blob=rng, color_space=color_space, mean=args.rgb_mean, std=args.rgb_std, output_dtype = flow.float) - print(normal.shape) return label, normal @@ -132,8 +129,7 @@ if __name__ == "__main__": else: print("Loading synthetic data.") (labels, images) = load_synthetic(args) - predictions = labels - outputs = {"predictions":predictions, "labels": labels} + outputs = {"images":images, "labels": labels} return outputs total_device_num = args.num_nodes * args.gpu_num_per_node @@ -141,6 +137,6 @@ if __name__ == "__main__": summary = Summary(args.log_dir, args, filename='io_test.csv') metric = Metric(desc='io_test', calculate_batches=args.loss_print_every_n_iter, summary=summary, save_summary_steps=args.loss_print_every_n_iter, - batch_size=train_batch_size) + batch_size=train_batch_size, prediction_key=None) for i in range(1000): IOTest().async_get(metric.metric_cb(0, i)) -- GitLab