提交 686d999b 编写于 作者: D dengkaipeng

fix undefined var in test_batch_sampler. test=develop

上级 c67cf85d
......@@ -22,6 +22,8 @@ from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, \
RandomSampler, WeightedRandomSampler
from paddle.io import DistributedBatchSampler
IMAGE_SIZE = 32
class RandomDataset(Dataset):
def __init__(self, sample_num, class_num):
......@@ -31,7 +33,7 @@ class RandomDataset(Dataset):
def __getitem__(self, idx):
np.random.seed(idx)
image = np.random.random([IMAGE_SIZE]).astype('float32')
label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64')
label = np.random.randint(0, self.class_num - 1, (1, )).astype('int64')
return image, label
def __len__(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册