Commit 9beb154b authored by gaotingquan, committed by Wei Shengyu

support ShiTu

Parent a41b201e
......@@ -88,14 +88,15 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random.seed(worker_seed)
def build_dataloader(config, mode, seed=None):
assert mode in [
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
assert mode in config["DataLoader"].keys(), "{} config not in yaml".format(
mode)
dataloader_config = config["DataLoader"][mode]
def build_dataloader(config, *mode, seed=None):
dataloader_config = config["DataLoader"]
for m in mode:
assert m in [
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
assert m in dataloader_config.keys(), "{} config not in yaml".format(m)
dataloader_config = dataloader_config[m]
class_num = config["Arch"].get("class_num", None)
epochs = config["Global"]["epochs"]
use_dali = config["Global"].get("use_dali", False)
......
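Note on the build_dataloader change above: the single mode argument becomes a variadic *mode, and each key is walked into the nested DataLoader config in turn, so classification configs keep passing one key while the retrieval evaluator added later in this commit passes two ("Eval", "Gallery" or "Eval", "Query"). A minimal sketch of that lookup follows; the config contents are placeholders, not taken from any shipped YAML.

# Sketch only: illustrates how the variadic mode keys index the nested
# DataLoader config; the dataset entries are placeholders.
config = {
    "DataLoader": {
        "Train": {"dataset": "placeholder"},        # one key, as before
        "Eval": {
            "Gallery": {"dataset": "placeholder"},  # two keys, new retrieval layout
            "Query": {"dataset": "placeholder"},
        },
    },
}

def lookup(config, *mode):
    dataloader_config = config["DataLoader"]
    for m in mode:
        assert m in dataloader_config.keys(), "{} config not in yaml".format(m)
        dataloader_config = dataloader_config[m]
    return dataloader_config

print(lookup(config, "Train"))
print(lookup(config, "Eval", "Gallery"))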
......@@ -22,6 +22,7 @@ from paddle import nn
import numpy as np
import random
from ..utils.amp import AMPForwardDecorator
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
......
......@@ -13,17 +13,17 @@
# limitations under the License.
from .classification import ClassEval
from .retrieval import retrieval_eval
from .retrieval import RetrievalEval
from .adaface import adaface_eval
def build_eval_func(config, mode, model):
if mode not in ["eval", "train"]:
return None
eval_mode = config["Global"].get("eval_mode", None)
if eval_mode is None:
config["Global"]["eval_mode"] = "classification"
task = config["Global"].get("task", "classification")
if task == "classification":
return ClassEval(config, mode, model)
elif task == "retrieval":
return RetrievalEval(config, mode, model)
else:
return getattr(sys.modules[__name__], eval_mode + "_eval")(config,
mode, model)
raise Exception()
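Note on the dispatch above: the evaluator is now selected by Global.task (defaulting to "classification") instead of the old Global.eval_mode, and only "classification" and "retrieval" are handled; anything else raises. A short sketch of the selection logic, with an illustrative config dict standing in for the loaded YAML:

# Sketch only: mirrors the task-based dispatch added in this commit;
# the config dict here is illustrative, not a shipped configuration.
config = {"Global": {"task": "retrieval"}}  # omit "task" to fall back to classification

task = config["Global"].get("task", "classification")
if task == "classification":
    evaluator_cls = "ClassEval"       # stands in for ClassEval(config, mode, model)
elif task == "retrieval":
    evaluator_cls = "RetrievalEval"   # stands in for RetrievalEval(config, mode, model)
else:
    raise Exception(f"unsupported task: {task}")
print(evaluator_cls)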
......@@ -21,182 +21,202 @@ import numpy as np
import paddle
import scipy
from ppcls.utils import all_gather, logger
def retrieval_eval(engine, epoch_id=0):
engine.model.eval()
# step1. prepare query and gallery features
if engine.gallery_query_dataloader is not None:
gallery_feat, gallery_label, gallery_camera = compute_feature(
engine, "gallery_query")
query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera
else:
gallery_feat, gallery_label, gallery_camera = compute_feature(
engine, "gallery")
query_feat, query_label, query_camera = compute_feature(engine,
"query")
# step2. split features into feature blocks for saving memory
num_query = len(query_feat)
block_size = engine.config["Global"].get("sim_block_size", 64)
sections = [block_size] * (num_query // block_size)
if num_query % block_size > 0:
sections.append(num_query % block_size)
query_feat_blocks = paddle.split(query_feat, sections)
query_label_blocks = paddle.split(query_label, sections)
query_camera_blocks = paddle.split(
query_camera, sections) if query_camera is not None else None
metric_key = None
# step3. compute metric
if engine.eval_loss_func is None:
metric_dict = {metric_key: 0.0}
else:
use_reranking = engine.config["Global"].get("re_ranking", False)
logger.info(f"re_ranking={use_reranking}")
if use_reranking:
# compute distance matrix
distmat = compute_re_ranking_dist(
query_feat, gallery_feat, engine.config["Global"].get(
"feature_normalize", True), 20, 6, 0.3)
# exclude illegal distance
if query_camera is not None:
camera_mask = query_camera != gallery_camera.t()
label_mask = query_label != gallery_label.t()
keep_mask = label_mask | camera_mask
distmat = keep_mask.astype(query_feat.dtype) * distmat + (
~keep_mask).astype(query_feat.dtype) * (distmat.max() + 1)
else:
keep_mask = None
# compute metric with all samples
metric_dict = engine.eval_metric_func(-distmat, query_label,
gallery_label, keep_mask)
from ...utils.misc import AverageMeter
from ...utils import all_gather, logger
from ...data import build_dataloader
from ...loss import build_loss
from ...metric import build_metrics
class RetrievalEval(object):
def __init__(self, config, mode, model):
self.config = config
self.model = model
self.print_batch_step = self.config["Global"]["print_batch_step"]
self.use_dali = self.config["Global"].get("use_dali", False)
self.eval_metric_func = build_metrics(self.config, "Eval")
self.eval_loss_func = build_loss(self.config, "Eval")
self.output_info = dict()
self.gallery_query_dataloader = None
if len(self.config["DataLoader"]["Eval"].keys()) == 1:
self.gallery_query_dataloader = build_dataloader(self.config,
"Eval")
else:
self.gallery_dataloader = build_dataloader(self.config, "Eval",
"Gallery")
self.query_dataloader = build_dataloader(self.config, "Eval",
"Query")
def __call__(self, epoch_id=0):
self.model.eval()
# step1. prepare query and gallery features
if self.gallery_query_dataloader is not None:
gallery_feat, gallery_label, gallery_camera = self.compute_feature(
"gallery_query")
query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera
else:
metric_dict = defaultdict(float)
for block_idx, block_feat in enumerate(query_feat_blocks):
gallery_feat, gallery_label, gallery_camera = self.compute_feature(
"gallery")
query_feat, query_label, query_camera = self.compute_feature(
"query")
# step2. split features into feature blocks for saving memory
num_query = len(query_feat)
block_size = self.config["Global"].get("sim_block_size", 64)
sections = [block_size] * (num_query // block_size)
if num_query % block_size > 0:
sections.append(num_query % block_size)
query_feat_blocks = paddle.split(query_feat, sections)
query_label_blocks = paddle.split(query_label, sections)
query_camera_blocks = paddle.split(
query_camera, sections) if query_camera is not None else None
metric_key = None
# step3. compute metric
if self.eval_loss_func is None:
metric_dict = {metric_key: 0.0}
else:
use_reranking = self.config["Global"].get("re_ranking", False)
logger.info(f"re_ranking={use_reranking}")
if use_reranking:
# compute distance matrix
distmat = paddle.matmul(
block_feat, gallery_feat, transpose_y=True)
distmat = compute_re_ranking_dist(
query_feat, gallery_feat, self.config["Global"].get(
"feature_normalize", True), 20, 6, 0.3)
# exclude illegal distance
if query_camera is not None:
camera_mask = query_camera_blocks[
block_idx] != gallery_camera.t()
label_mask = query_label_blocks[
block_idx] != gallery_label.t()
camera_mask = query_camera != gallery_camera.t()
label_mask = query_label != gallery_label.t()
keep_mask = label_mask | camera_mask
distmat = keep_mask.astype(query_feat.dtype) * distmat
distmat = keep_mask.astype(query_feat.dtype) * distmat + (
~keep_mask).astype(query_feat.dtype) * (distmat.max() +
1)
else:
keep_mask = None
# compute metric by block
metric_block = engine.eval_metric_func(
distmat, query_label_blocks[block_idx], gallery_label,
keep_mask)
# accumulate metric
for key in metric_block:
metric_dict[key] += metric_block[key] * block_feat.shape[
0] / num_query
metric_info_list = []
for key, value in metric_dict.items():
metric_info_list.append(f"{key}: {value:.5f}")
if metric_key is None:
metric_key = key
metric_msg = ", ".join(metric_info_list)
logger.info(f"[Eval][Epoch {epoch_id}][Avg]{metric_msg}")
return metric_dict[metric_key]
def compute_feature(engine, name="gallery"):
if name == "gallery":
dataloader = engine.gallery_dataloader
elif name == "query":
dataloader = engine.query_dataloader
elif name == "gallery_query":
dataloader = engine.gallery_query_dataloader
else:
raise ValueError(
f"Only support gallery or query or gallery_query dataset, but got {name}"
)
all_feat = []
all_label = []
all_camera = []
has_camera = False
for idx, batch in enumerate(dataloader): # load is very time-consuming
if idx % engine.config["Global"]["print_batch_step"] == 0:
logger.info(
f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
# compute metric with all samples
metric_dict = self.eval_metric_func(-distmat, query_label,
gallery_label, keep_mask)
else:
metric_dict = defaultdict(float)
for block_idx, block_feat in enumerate(query_feat_blocks):
# compute distance matrix
distmat = paddle.matmul(
block_feat, gallery_feat, transpose_y=True)
# exclude illegal distance
if query_camera is not None:
camera_mask = query_camera_blocks[
block_idx] != gallery_camera.t()
label_mask = query_label_blocks[
block_idx] != gallery_label.t()
keep_mask = label_mask | camera_mask
distmat = keep_mask.astype(query_feat.dtype) * distmat
else:
keep_mask = None
# compute metric by block
metric_block = self.eval_metric_func(
distmat, query_label_blocks[block_idx], gallery_label,
keep_mask)
# accumulate metric
for key in metric_block:
metric_dict[key] += metric_block[
key] * block_feat.shape[0] / num_query
metric_info_list = []
for key, value in metric_dict.items():
metric_info_list.append(f"{key}: {value:.5f}")
if metric_key is None:
metric_key = key
metric_msg = ", ".join(metric_info_list)
logger.info(f"[Eval][Epoch {epoch_id}][Avg]{metric_msg}")
return metric_dict[metric_key]
def compute_feature(self, name="gallery"):
if name == "gallery":
dataloader = self.gallery_dataloader
elif name == "query":
dataloader = self.query_dataloader
elif name == "gallery_query":
dataloader = self.gallery_query_dataloader
else:
raise ValueError(
f"Only support gallery or query or gallery_query dataset, but got {name}"
)
batch = [paddle.to_tensor(x) for x in batch]
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
if len(batch) >= 3:
has_camera = True
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
if engine.amp and engine.amp_eval:
with paddle.amp.auto_cast(
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=engine.amp_level):
out = engine.model(batch[0])
else:
out = engine.model(batch[0])
if "Student" in out:
out = out["Student"]
# get features
if engine.config["Global"].get("retrieval_feature_from",
"features") == "features":
# use output from neck as feature
batch_feat = out["features"]
else:
# use output from backbone as feature
batch_feat = out["backbone"]
# do norm(optional)
if engine.config["Global"].get("feature_normalize", True):
batch_feat = paddle.nn.functional.normalize(batch_feat, p=2)
# do binarize(optional)
if engine.config["Global"].get("feature_binarize") == "round":
batch_feat = paddle.round(batch_feat).astype("float32") * 2.0 - 1.0
elif engine.config["Global"].get("feature_binarize") == "sign":
batch_feat = paddle.sign(batch_feat).astype("float32")
if paddle.distributed.get_world_size() > 1:
all_feat.append(all_gather(batch_feat))
all_label.append(all_gather(batch[1]))
if has_camera:
all_camera.append(all_gather(batch[2]))
all_feat = []
all_label = []
all_camera = []
has_camera = False
for idx, batch in enumerate(dataloader): # load is very time-consuming
if idx % self.print_batch_step == 0:
logger.info(
f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
)
batch = [paddle.to_tensor(x) for x in batch]
batch[1] = batch[1].reshape([-1, 1]).astype("int64")
if len(batch) >= 3:
has_camera = True
batch[2] = batch[2].reshape([-1, 1]).astype("int64")
out = self.model(batch)
if "Student" in out:
out = out["Student"]
# get features
if self.config["Global"].get("retrieval_feature_from",
"features") == "features":
# use output from neck as feature
batch_feat = out["features"]
else:
# use output from backbone as feature
batch_feat = out["backbone"]
# do norm(optional)
if self.config["Global"].get("feature_normalize", True):
batch_feat = paddle.nn.functional.normalize(batch_feat, p=2)
# do binarize(optional)
if self.config["Global"].get("feature_binarize") == "round":
batch_feat = paddle.round(batch_feat).astype(
"float32") * 2.0 - 1.0
elif self.config["Global"].get("feature_binarize") == "sign":
batch_feat = paddle.sign(batch_feat).astype("float32")
if paddle.distributed.get_world_size() > 1:
all_feat.append(all_gather(batch_feat))
all_label.append(all_gather(batch[1]))
if has_camera:
all_camera.append(all_gather(batch[2]))
else:
all_feat.append(batch_feat)
all_label.append(batch[1])
if has_camera:
all_camera.append(batch[2])
if self.use_dali:
dataloader.reset()
all_feat = paddle.concat(all_feat)
all_label = paddle.concat(all_label)
if has_camera:
all_camera = paddle.concat(all_camera)
else:
all_feat.append(batch_feat)
all_label.append(batch[1])
if has_camera:
all_camera.append(batch[2])
if engine.use_dali:
dataloader.reset()
all_feat = paddle.concat(all_feat)
all_label = paddle.concat(all_label)
if has_camera:
all_camera = paddle.concat(all_camera)
else:
all_camera = None
# discard redundant padding sample(s) at the end
total_samples = dataloader.size if engine.use_dali else len(
dataloader.dataset)
all_feat = all_feat[:total_samples]
all_label = all_label[:total_samples]
if has_camera:
all_camera = all_camera[:total_samples]
logger.info(f"Build {name} done, all feat shape: {all_feat.shape}")
return all_feat, all_label, all_camera
all_camera = None
# discard redundant padding sample(s) at the end
total_samples = dataloader.size if self.use_dali else len(
dataloader.dataset)
all_feat = all_feat[:total_samples]
all_label = all_label[:total_samples]
if has_camera:
all_camera = all_camera[:total_samples]
logger.info(f"Build {name} done, all feat shape: {all_feat.shape}")
return all_feat, all_label, all_camera
def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray:
......
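Note on the block-wise path in RetrievalEval.__call__ above: when re-ranking is disabled, the query features are split into blocks of Global.sim_block_size and the per-block metrics are accumulated as a weighted average, each block weighted by its number of queries. A minimal sketch of that accumulation with made-up metric values:

# Sketch only (assumed behavior, not lifted verbatim from the commit):
# per-block metrics are averaged with weights n / num_query.
from collections import defaultdict

num_query = 150
block_size = 64
sections = [block_size] * (num_query // block_size)
if num_query % block_size > 0:
    sections.append(num_query % block_size)      # -> [64, 64, 22]

metric_dict = defaultdict(float)
for n in sections:
    metric_block = {"recall1": 0.9, "mAP": 0.8}  # placeholder per-block metrics
    for key in metric_block:
        metric_dict[key] += metric_block[key] * n / num_query

print(dict(metric_dict))  # weighted average over all queries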
......@@ -22,9 +22,8 @@ from .train_progressive import train_epoch_progressive
def build_train_func(config, mode, model, eval_func):
if mode != "train":
return None
train_mode = config["Global"].get("task", None)
if train_mode is None:
config["Global"]["task"] = "classification"
task = config["Global"].get("task", "classification")
if task == "classification" or task == "retrieval":
return ClassTrainer(config, model, eval_func)
else:
return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
......
......@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
from ppcls.data import build_dataloader
from ppcls.utils import logger, type_name
from .regular_train_epoch import regular_train_epoch
from .classification import ClassTrainer
def train_epoch_progressive(engine, epoch_id, print_batch_step):
......