diff --git a/ppcls/data/preprocess/ops/operators.py b/ppcls/data/preprocess/ops/operators.py index 1971e5ea25e01fa40f9f87922fe877b80048b36c..9cdc58b2b5f6c4987c898a6549f5652797d1c845 100644 --- a/ppcls/data/preprocess/ops/operators.py +++ b/ppcls/data/preprocess/ops/operators.py @@ -352,7 +352,7 @@ class AugMix(object): np.random.beta(self.aug_prob_coeff, self.aug_prob_coeff)) # image = Image.fromarray(image) - mix = np.zeros([image.shape[1], image.shape[0], 3]) + mix = np.zeros(image.shape) for i in range(self.mixture_width): image_aug = image.copy() image_aug = Image.fromarray(image_aug)