From 686d999b804f93c86c2e128dd92fd31dbb51a609 Mon Sep 17 00:00:00 2001 From: dengkaipeng Date: Wed, 22 Sep 2021 11:19:33 +0000 Subject: [PATCH] fix undefined var in test_batch_sampler. test=develop --- python/paddle/fluid/tests/unittests/test_batch_sampler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_batch_sampler.py b/python/paddle/fluid/tests/unittests/test_batch_sampler.py index 4c323a2511f..4c5338314af 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_sampler.py +++ b/python/paddle/fluid/tests/unittests/test_batch_sampler.py @@ -22,6 +22,8 @@ from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, \ RandomSampler, WeightedRandomSampler from paddle.io import DistributedBatchSampler +IMAGE_SIZE = 32 + class RandomDataset(Dataset): def __init__(self, sample_num, class_num): @@ -31,7 +33,7 @@ class RandomDataset(Dataset): def __getitem__(self, idx): np.random.seed(idx) image = np.random.random([IMAGE_SIZE]).astype('float32') - label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64') return image, label def __len__(self): -- GitLab