提交 b542416d 编写于 作者: H HydrogenSulfate

add random shuffle in PKSampler

上级 d8f28bf3
......@@ -18,7 +18,9 @@ from __future__ import division
from collections import defaultdict
import numpy as np
import paddle.distributed as dist
from paddle.io import DistributedBatchSampler
from ppcls.utils import logger
......@@ -94,6 +96,12 @@ class PKSampler(DistributedBatchSampler):
format(diff))
def __iter__(self):
# shuffle manually, same as DistributedBatchSampler.__iter__
if self.shuffle:
np.random.RandomState(self.epoch + dist.get_rank()).shuffle(
self.label_list)
self.epoch += 1
label_per_batch = self.batch_size // self.sample_per_id
for _ in range(len(self)):
batch_index = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册