Commit 84870b77 authored by HydrogenSulfate

add assert for DistributedRandomIdentitySampler and refine docstring for DistributedRandomIdentitySampler & PKSampler
Parent da70ef9c
@@ -14,24 +14,27 @@
 from __future__ import absolute_import
 from __future__ import division
-from collections import defaultdict
-import numpy as np
 import copy
 import random
+from collections import defaultdict
+import numpy as np
 from paddle.io import DistributedBatchSampler, Sampler


 class DistributedRandomIdentitySampler(DistributedBatchSampler):
-    """
-    Randomly sample N identities, then for each identity,
-    randomly sample K instances, therefore batch size is N*K.
+    """Randomly sample N identities, then for each identity,
+    randomly sample K instances, therefore batch size equals N * K.
+
     Args:
-    - data_source (list): list of (img_path, pid, camid).
-    - num_instances (int): number of instances per identity in a batch.
-    - batch_size (int): number of examples in a batch.
+        dataset (Dataset): Dataset which contains a list of (img_path, pid, camid) records.
+        batch_size (int): batch size.
+        num_instances (int): number of instances within each class (identity).
+        drop_last (bool): whether to discard the incomplete batch at the end.
     """

     def __init__(self, dataset, batch_size, num_instances, drop_last, **args):
+        assert batch_size % num_instances == 0, \
+            f"batch_size({batch_size}) must be divisible by num_instances({num_instances}) when using DistributedRandomIdentitySampler"
         self.dataset = dataset
         self.batch_size = batch_size
         self.num_instances = num_instances
......
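For context, a minimal usage sketch of the sampler above (not part of this commit; ReIDDataset is an assumed placeholder for any paddle.io.Dataset of (img_path, pid, camid) records):

from paddle.io import DataLoader

dataset = ReIDDataset()  # hypothetical dataset of (img_path, pid, camid) samples
# batch_size must be divisible by num_instances, or the new assert fires:
# 64 images per batch = N=16 identities x K=4 instances each.
sampler = DistributedRandomIdentitySampler(
    dataset, batch_size=64, num_instances=4, drop_last=True)
loader = DataLoader(dataset, batch_sampler=sampler)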
@@ -14,27 +14,27 @@
 from __future__ import absolute_import
 from __future__ import division
 from collections import defaultdict
 import numpy as np
 import random
 from paddle.io import DistributedBatchSampler
 from ppcls.utils import logger


 class PKSampler(DistributedBatchSampler):
-    """
-    First, randomly sample P identities.
-    Then for each identity randomly sample K instances.
-    Therefore batch size is P*K, and the sampler called PKSampler.
+    """First, randomly sample P identities.
+    Then for each identity randomly sample K instances.
+    Therefore batch size equals P * K, and the sampler is called PKSampler.
+
     Args:
-        dataset (paddle.io.Dataset): list of (img_path, pid, cam_id).
-        sample_per_id(int): number of instances per identity in a batch.
-        batch_size (int): number of examples in a batch.
-        shuffle(bool): whether to shuffle indices order before generating
-            batch indices. Default False.
+        dataset (Dataset): Dataset which contains a list of (img_path, pid, camid) records.
+        batch_size (int): batch size.
+        sample_per_id (int): number of instances within a class.
+        shuffle (bool, optional): whether to shuffle indices before generating batch indices. Defaults to True.
+        drop_last (bool, optional): whether to discard the incomplete batch at the end. Defaults to True.
+        sample_method (str, optional): sample method when generating prob_list. Defaults to "sample_avg_prob".
     """

     def __init__(self,
                  dataset,
                  batch_size,
@@ -42,10 +42,9 @@ class PKSampler(DistributedBatchSampler):
                  shuffle=True,
                  drop_last=True,
                  sample_method="sample_avg_prob"):
-        super().__init__(
-            dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
+        super().__init__(dataset, batch_size, shuffle=shuffle, drop_last=drop_last)
         assert batch_size % sample_per_id == 0, \
-            "PKSampler configs error, Sample_per_id must be a divisor of batch_size."
+            f"PKSampler configs error, sample_per_id({sample_per_id}) must be a divisor of batch_size({batch_size})."
         assert hasattr(self.dataset,
                        "labels"), "Dataset must have labels attribute."
         self.sample_per_label = sample_per_id
......
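And a self-contained sketch for PKSampler, exercising the labels attribute the new hasattr assert checks (the toy dataset is an assumption for illustration, not code from this commit):

import numpy as np
from paddle.io import Dataset, DataLoader

class ToyReIDDataset(Dataset):
    # Minimal dataset exposing the `labels` attribute PKSampler requires.
    def __init__(self, num_ids=8, per_id=10):
        self.labels = np.repeat(np.arange(num_ids), per_id)  # pid of every sample
        self.data = np.random.rand(len(self.labels), 3, 32, 32).astype("float32")

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

    def __len__(self):
        return len(self.labels)

dataset = ToyReIDDataset()
# batch_size (16) is a multiple of sample_per_id (4): P=4 identities, K=4 instances.
sampler = PKSampler(dataset, batch_size=16, sample_per_id=4,
                    shuffle=True, drop_last=True)
loader = DataLoader(dataset, batch_sampler=sampler)

Since PKSampler subclasses DistributedBatchSampler, under multi-GPU training each rank draws its own P * K batches.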