未验证 提交 0939de92 编写于 作者: W Wei Shengyu 提交者: GitHub

Merge pull request #1268 from weisy11/develop

fix bug of pk sampler
...@@ -68,18 +68,20 @@ class PKSampler(DistributedBatchSampler): ...@@ -68,18 +68,20 @@ class PKSampler(DistributedBatchSampler):
logger.error( logger.error(
"PKSampler only support id_avg_prob and sample_avg_prob sample method, " "PKSampler only support id_avg_prob and sample_avg_prob sample method, "
"but receive {}.".format(self.sample_method)) "but receive {}.".format(self.sample_method))
if sum(np.abs(self.prob_list - 1) > 0.00000001): diff = np.abs(sum(self.prob_list) - 1)
if diff > 0.00000001:
self.prob_list[-1] = 1 - sum(self.prob_list[:-1]) self.prob_list[-1] = 1 - sum(self.prob_list[:-1])
if self.prob_list[-1] > 1 or self.prob_list[-1] < 0: if self.prob_list[-1] > 1 or self.prob_list[-1] < 0:
logger.error("PKSampler prob list error") logger.error("PKSampler prob list error")
else: else:
logger.info( logger.info(
"PKSampler: sum of prob list not equal to 1, change the last prob" "PKSampler: sum of prob list not equal to 1, diff is {}, change the last prob".format(diff)
) )
def __iter__(self): def __iter__(self):
label_per_batch = self.batch_size // self.sample_per_label label_per_batch = self.batch_size // self.sample_per_label
if self.shuffle: if self.shuffle:
# It's not accurate literally, but it helps in some dataset.
np.random.RandomState(self.epoch).shuffle(self.label_list) np.random.RandomState(self.epoch).shuffle(self.label_list)
for i in range(len(self)): for i in range(len(self)):
batch_index = [] batch_index = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册