提交 de298b1b 编写于 作者: W weishengyu

dbg

上级 4dc175c9
...@@ -26,9 +26,13 @@ from ppcls.data.dataloader.common_dataset import create_operators ...@@ -26,9 +26,13 @@ from ppcls.data.dataloader.common_dataset import create_operators
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
from ppcls.data.dataloader.logo_dataset import LogoDataset from ppcls.data.dataloader.logo_dataset import LogoDataset
from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset from ppcls.data.dataloader.icartoon_dataset import ICartoonDataset
from ppcls.data.dataloader.mix_dataset import MixDataset
from ppcls.data.dataloader.writer_hard_dataset import WriterHardDataset
# sampler # sampler
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
from ppcls.data.dataloader.writer_hard_sampler import WriterHardSampler
from ppcls.data.dataloader.mix_sampler import MixSampler
from ppcls.data import preprocess from ppcls.data import preprocess
from ppcls.data.preprocess import transform from ppcls.data.preprocess import transform
......
...@@ -34,10 +34,11 @@ class MixSampler(DistributedBatchSampler): ...@@ -34,10 +34,11 @@ class MixSampler(DistributedBatchSampler):
batch_size_left = self.batch_size batch_size_left = self.batch_size
self.iter_list = [] self.iter_list = []
for i, config_i in enumerate(sample_configs): for i, config_i in enumerate(sample_configs):
self.start_list.append(dataset_list[i][1])
sample_method = config_i.pop("name") sample_method = config_i.pop("name")
ratio_i = config_i.pop("ratio") ratio_i = config_i.pop("ratio")
if i < len(sample_configs) - 1: if i < len(sample_configs) - 1:
batch_size_i = self.batch_size * ratio_i batch_size_i = int(self.batch_size * ratio_i)
batch_size_left -= batch_size_i batch_size_left -= batch_size_i
else: else:
batch_size_i = batch_size_left batch_size_i = batch_size_left
......
...@@ -33,10 +33,11 @@ class WriterHardSampler(DistributedBatchSampler): ...@@ -33,10 +33,11 @@ class WriterHardSampler(DistributedBatchSampler):
- batch_size (int): number of examples in a batch. - batch_size (int): number of examples in a batch.
""" """
def __init__(self, dataset, batch_size, **args): def __init__(self, dataset, batch_size, shuffle=True, **args):
super(WriterHardSampler, self).__init__(dataset, batch_size) super(WriterHardSampler, self).__init__(dataset, batch_size)
self.dataset = dataset self.dataset = dataset
self.batch_size = batch_size self.batch_size = batch_size
self.shuffle = shuffle
assert not self.batch_size % 4, "bs of WriterHardSampler should be 3*N" assert not self.batch_size % 4, "bs of WriterHardSampler should be 3*N"
assert isinstance(dataset, WriterHardDataset), "WriterHardSampler only support WriterHardDataset" assert isinstance(dataset, WriterHardDataset), "WriterHardSampler only support WriterHardDataset"
self.num_pids_per_batch = self.batch_size // 4 self.num_pids_per_batch = self.batch_size // 4
...@@ -45,7 +46,7 @@ class WriterHardSampler(DistributedBatchSampler): ...@@ -45,7 +46,7 @@ class WriterHardSampler(DistributedBatchSampler):
self.text_id_map = {} self.text_id_map = {}
anno_list = dataset.anno_list anno_list = dataset.anno_list
for i, anno_i in enumerate(anno_list): for i, anno_i in enumerate(anno_list):
_, person_id, text_id = anno_i.split(" ") _, person_id, text_id = anno_i.strip().split(" ")
if text_id != "-1": if text_id != "-1":
if random.random() < 0.5: if random.random() < 0.5:
self.anchor_list.append([i, person_id, text_id]) self.anchor_list.append([i, person_id, text_id])
...@@ -59,10 +60,10 @@ class WriterHardSampler(DistributedBatchSampler): ...@@ -59,10 +60,10 @@ class WriterHardSampler(DistributedBatchSampler):
self.person_id_map[person_id].append(i) self.person_id_map[person_id].append(i)
else: else:
self.person_id_map[person_id] = [i] self.person_id_map[person_id] = [i]
assert len(self.anchor_list) > self.batch_size, "anchor should be larger than batch_size"
assert len(self.anchor_list) < self.batch_size, "anchor should be larger than batch_size"
def __iter__(self): def __iter__(self):
if self.shuffle:
random.shuffle(self.anchor_list) random.shuffle(self.anchor_list)
for i in range(len(self)): for i in range(len(self)):
batch_indices = [] batch_indices = []
...@@ -79,4 +80,4 @@ class WriterHardSampler(DistributedBatchSampler): ...@@ -79,4 +80,4 @@ class WriterHardSampler(DistributedBatchSampler):
yield batch_indices yield batch_indices
def __len__(self): def __len__(self):
len(self.anchor_list) * 4 // self.batch_size return len(self.anchor_list) * 4 // self.batch_size
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册