未验证 提交 76c3f1c6 编写于 作者: H HydrogenSulfate 提交者: GitHub

Merge pull request #2213 from HydrogenSulfate/refine_pksampler

add assertion for DistributedRandomIdentitySampler
...@@ -14,24 +14,27 @@ ...@@ -14,24 +14,27 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from collections import defaultdict
import numpy as np
import copy import copy
import random import random
from collections import defaultdict
import numpy as np
from paddle.io import DistributedBatchSampler, Sampler from paddle.io import DistributedBatchSampler, Sampler
class DistributedRandomIdentitySampler(DistributedBatchSampler): class DistributedRandomIdentitySampler(DistributedBatchSampler):
""" """Randomly sample N identities, then for each identity,
Randomly sample N identities, then for each identity, randomly sample K instances, therefore batch size equals N * K.
randomly sample K instances, therefore batch size is N*K.
Args: Args:
- data_source (list): list of (img_path, pid, camid). dataset(Dataset): Dataset which contains a list of (img_path, pid, camid)
- num_instances (int): number of instances per identity in a batch. batch_size (int): batch size
- batch_size (int): number of examples in a batch. num_instances (int): number of instance(s) within a class
drop_last (bool): whether to discard the data at the end
""" """
def __init__(self, dataset, batch_size, num_instances, drop_last, **args): def __init__(self, dataset, batch_size, num_instances, drop_last, **args):
assert batch_size % num_instances == 0, \
f"batch_size({batch_size}) must be divisible by num_instances({num_instances}) when using DistributedRandomIdentitySampler"
self.dataset = dataset self.dataset = dataset
self.batch_size = batch_size self.batch_size = batch_size
self.num_instances = num_instances self.num_instances = num_instances
......
...@@ -14,27 +14,27 @@ ...@@ -14,27 +14,27 @@
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from collections import defaultdict from collections import defaultdict
import numpy as np import numpy as np
import random
from paddle.io import DistributedBatchSampler from paddle.io import DistributedBatchSampler
from ppcls.utils import logger from ppcls.utils import logger
class PKSampler(DistributedBatchSampler): class PKSampler(DistributedBatchSampler):
""" """First, randomly sample P identities.
First, randomly sample P identities. Then for each identity randomly sample K instances.
Then for each identity randomly sample K instances. Therefore batch size equals to P * K, and the sampler called PKSampler.
Therefore batch size is P*K, and the sampler called PKSampler.
Args: Args:
dataset (paddle.io.Dataset): list of (img_path, pid, cam_id). dataset (Dataset): Dataset which contains a list of (img_path, pid, camid)
sample_per_id(int): number of instances per identity in a batch. batch_size (int): batch size
batch_size (int): number of examples in a batch. sample_per_id (int): number of instance(s) within a class
shuffle(bool): whether to shuffle indices order before generating shuffle (bool, optional): whether to shuffle indices before generating batch indices. Defaults to True.
batch indices. Default False. drop_last (bool, optional): whether to discard the data at the end. Defaults to True.
sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob".
""" """
def __init__(self, def __init__(self,
dataset, dataset,
batch_size, batch_size,
...@@ -42,10 +42,9 @@ class PKSampler(DistributedBatchSampler): ...@@ -42,10 +42,9 @@ class PKSampler(DistributedBatchSampler):
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
sample_method="sample_avg_prob"): sample_method="sample_avg_prob"):
super().__init__( super().__init__(dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
assert batch_size % sample_per_id == 0, \ assert batch_size % sample_per_id == 0, \
"PKSampler configs error, Sample_per_id must be a divisor of batch_size." f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})."
assert hasattr(self.dataset, assert hasattr(self.dataset,
"labels"), "Dataset must have labels attribute." "labels"), "Dataset must have labels attribute."
self.sample_per_label = sample_per_id self.sample_per_label = sample_per_id
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册