提交 de298b1b 编写于 作者: W weishengyu

dbg

上级 4dc175c9
......@@ -26,9 +26,13 @@ from ppcls.data.dataloader.common_dataset import create_operators
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.writer_hard_dataset import WriterHardDataset
# sampler
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
from ppcls.data.dataloader.writer_hard_sampler import WriterHardSampler
from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data import preprocess
from ppcls.data.preprocess import transform
......
......@@ -34,10 +34,11 @@ class MixSampler(DistributedBatchSampler):
batch_size_left = self.batch_size
self.iter_list = []
for i, config_i in enumerate(sample_configs):
self.start_list.append(dataset_list[i][1])
sample_method = config_i.pop("name")
ratio_i = config_i.pop("ratio")
if i < len(sample_configs) - 1:
batch_size_i = self.batch_size * ratio_i
batch_size_i = int(self.batch_size * ratio_i)
batch_size_left -= batch_size_i
else:
batch_size_i = batch_size_left
......
......@@ -33,10 +33,11 @@ class WriterHardSampler(DistributedBatchSampler):
- batch_size (int): number of examples in a batch.
"""
def __init__(self, dataset, batch_size, **args):
def __init__(self, dataset, batch_size, shuffle=True, **args):
super(WriterHardSampler, self).__init__(dataset, batch_size)
self.dataset = dataset
self.batch_size = batch_size
self.shuffle = shuffle
assert not self.batch_size % 4, "bs of WriterHardSampler should be 3*N"
assert isinstance(dataset, WriterHardDataset), "WriterHardSampler only support WriterHardDataset"
self.num_pids_per_batch = self.batch_size // 4
......@@ -45,7 +46,7 @@ class WriterHardSampler(DistributedBatchSampler):
self.text_id_map = {}
anno_list = dataset.anno_list
for i, anno_i in enumerate(anno_list):
_, person_id, text_id = anno_i.split(" ")
_, person_id, text_id = anno_i.strip().split(" ")
if text_id != "-1":
if random.random() < 0.5:
self.anchor_list.append([i, person_id, text_id])
......@@ -59,11 +60,11 @@ class WriterHardSampler(DistributedBatchSampler):
self.person_id_map[person_id].append(i)
else:
self.person_id_map[person_id] = [i]
assert len(self.anchor_list) < self.batch_size, "anchor should be larger than batch_size"
assert len(self.anchor_list) > self.batch_size, "anchor should be larger than batch_size"
def __iter__(self):
random.shuffle(self.anchor_list)
if self.shuffle:
random.shuffle(self.anchor_list)
for i in range(len(self)):
batch_indices = []
for j in range(self.batch_size // 4):
......@@ -79,4 +80,4 @@ class WriterHardSampler(DistributedBatchSampler):
yield batch_indices
def __len__(self):
len(self.anchor_list) * 4 // self.batch_size
return len(self.anchor_list) * 4 // self.batch_size
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册