diff --git a/ppcls/data/preprocess/ops/random_erasing.py b/ppcls/data/preprocess/ops/random_erasing.py index d96ceda6e895520e0e4bd3036804276f864b769d..b395d5205bc7aab7aba7098f832e4c470638913d 100644 --- a/ppcls/data/preprocess/ops/random_erasing.py +++ b/ppcls/data/preprocess/ops/random_erasing.py @@ -42,9 +42,9 @@ class RandomErasing(object): h = int(round(math.sqrt(target_area * aspect_ratio))) w = int(round(math.sqrt(target_area / aspect_ratio))) - if w < img.shape[2] and h < img.shape[1]: - x1 = random.randint(0, img.shape[1] - h) - y1 = random.randint(0, img.shape[2] - w) + if w < img.shape[1] and h < img.shape[0]: + x1 = random.randint(0, img.shape[0] - h) + y1 = random.randint(0, img.shape[1] - w) if img.shape[0] == 3: img[x1:x1 + h, y1:y1 + w, 0] = self.mean[0] img[x1:x1 + h, y1:y1 + w, 1] = self.mean[1]