diff --git a/plsc/utils/jpeg_reader.py b/plsc/utils/jpeg_reader.py index 554c6e0540bbb528d75cfbbe2a8236c1e9964a8f..dfca8eb61fc9626359fa8f6dbd9ad344370c2a9c 100644 --- a/plsc/utils/jpeg_reader.py +++ b/plsc/utils/jpeg_reader.py @@ -252,7 +252,11 @@ def load_bin(path, image_size, data_format ='NCHW'): bins, issame_list = pickle.load(open(path, 'rb'), encoding='bytes') data_list = [] for flip in [0, 1]: - data = np.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) + if data_format == 'NCHW': + data = np.empty((len(issame_list) * 2, 3, image_size[0], image_size[1])) + else: + #NHWC + data = np.empty((len(issame_list) * 2, image_size[0], image_size[1], 3)) data_list.append(data) for i in range(len(issame_list) * 2): _bin = bins[i]