提交 b542416d 编写于 作者: H HydrogenSulfate

add random shuffle in PKSampler

上级 d8f28bf3
...@@ -18,7 +18,9 @@ from __future__ import division ...@@ -18,7 +18,9 @@ from __future__ import division
from collections import defaultdict from collections import defaultdict
import numpy as np import numpy as np
import paddle.distributed as dist
from paddle.io import DistributedBatchSampler from paddle.io import DistributedBatchSampler
from ppcls.utils import logger from ppcls.utils import logger
...@@ -94,6 +96,12 @@ class PKSampler(DistributedBatchSampler): ...@@ -94,6 +96,12 @@ class PKSampler(DistributedBatchSampler):
format(diff)) format(diff))
def __iter__(self): def __iter__(self):
# shuffle manually, same as DistributedBatchSampler.__iter__
if self.shuffle:
np.random.RandomState(self.epoch + dist.get_rank()).shuffle(
self.label_list)
self.epoch += 1
label_per_batch = self.batch_size // self.sample_per_id label_per_batch = self.batch_size // self.sample_per_id
for _ in range(len(self)): for _ in range(len(self)):
batch_index = [] batch_index = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册