diff --git a/python/paddle/fluid/tests/unittests/test_batch_sampler.py b/python/paddle/fluid/tests/unittests/test_batch_sampler.py index 4c323a2511f5b61fc0f82971b1b09248bb2dd367..4c5338314afb1a41eb05336c80bfa9e8f20f34f0 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_sampler.py +++ b/python/paddle/fluid/tests/unittests/test_batch_sampler.py @@ -22,6 +22,8 @@ from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, \ RandomSampler, WeightedRandomSampler from paddle.io import DistributedBatchSampler +IMAGE_SIZE = 32 + class RandomDataset(Dataset): def __init__(self, sample_num, class_num): @@ -31,7 +33,7 @@ class RandomDataset(Dataset): def __getitem__(self, idx): np.random.seed(idx) image = np.random.random([IMAGE_SIZE]).astype('float32') - label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64') return image, label def __len__(self):