提交 9beb154b 编写于 作者: G gaotingquan 提交者: Wei Shengyu

support ShiTu

上级 a41b201e
...@@ -88,14 +88,15 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): ...@@ -88,14 +88,15 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random.seed(worker_seed) random.seed(worker_seed)
def build_dataloader(config, mode, seed=None): def build_dataloader(config, *mode, seed=None):
assert mode in [ dataloader_config = config["DataLoader"]
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain' for m in mode:
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain" assert m in [
assert mode in config["DataLoader"].keys(), "{} config not in yaml".format( 'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
mode) ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
assert m in dataloader_config.keys(), "{} config not in yaml".format(m)
dataloader_config = config["DataLoader"][mode] dataloader_config = dataloader_config[m]
class_num = config["Arch"].get("class_num", None) class_num = config["Arch"].get("class_num", None)
epochs = config["Global"]["epochs"] epochs = config["Global"]["epochs"]
use_dali = config["Global"].get("use_dali", False) use_dali = config["Global"].get("use_dali", False)
......
...@@ -22,6 +22,7 @@ from paddle import nn ...@@ -22,6 +22,7 @@ from paddle import nn
import numpy as np import numpy as np
import random import random
from ..utils.amp import AMPForwardDecorator
from ppcls.utils import logger from ppcls.utils import logger
from ppcls.utils.logger import init_logger from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config from ppcls.utils.config import print_config
......
...@@ -13,17 +13,17 @@ ...@@ -13,17 +13,17 @@
# limitations under the License. # limitations under the License.
from .classification import ClassEval from .classification import ClassEval
from .retrieval import retrieval_eval from .retrieval import RetrievalEval
from .adaface import adaface_eval from .adaface import adaface_eval
def build_eval_func(config, mode, model): def build_eval_func(config, mode, model):
if mode not in ["eval", "train"]: if mode not in ["eval", "train"]:
return None return None
eval_mode = config["Global"].get("eval_mode", None) task = config["Global"].get("task", "classification")
if eval_mode is None: if task == "classification":
config["Global"]["eval_mode"] = "classification"
return ClassEval(config, mode, model) return ClassEval(config, mode, model)
elif task == "retrieval":
return RetrievalEval(config, mode, model)
else: else:
return getattr(sys.modules[__name__], eval_mode + "_eval")(config, raise Exception()
mode, model)
...@@ -21,182 +21,202 @@ import numpy as np ...@@ -21,182 +21,202 @@ import numpy as np
import paddle import paddle
import scipy import scipy
from ppcls.utils import all_gather, logger from ...utils.misc import AverageMeter
from ...utils import all_gather, logger
from ...data import build_dataloader
def retrieval_eval(engine, epoch_id=0): from ...loss import build_loss
engine.model.eval() from ...metric import build_metrics
# step1. prepare query and gallery features
if engine.gallery_query_dataloader is not None:
gallery_feat, gallery_label, gallery_camera = compute_feature( class RetrievalEval(object):
engine, "gallery_query") def __init__(self, config, mode, model):
query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera self.config = config
else: self.model = model
gallery_feat, gallery_label, gallery_camera = compute_feature( self.print_batch_step = self.config["Global"]["print_batch_step"]
engine, "gallery") self.use_dali = self.config["Global"].get("use_dali", False)
query_feat, query_label, query_camera = compute_feature(engine, self.eval_metric_func = build_metrics(self.config, "Eval")
"query") self.eval_loss_func = build_loss(self.config, "Eval")
self.output_info = dict()
# step2. split features into feature blocks for saving memory
num_query = len(query_feat) self.gallery_query_dataloader = None
block_size = engine.config["Global"].get("sim_block_size", 64) if len(self.config["DataLoader"]["Eval"].keys()) == 1:
sections = [block_size] * (num_query // block_size) self.gallery_query_dataloader = build_dataloader(self.config,
if num_query % block_size > 0: "Eval")
sections.append(num_query % block_size) else:
self.gallery_dataloader = build_dataloader(self.config, "Eval",
query_feat_blocks = paddle.split(query_feat, sections) "Gallery")
query_label_blocks = paddle.split(query_label, sections) self.query_dataloader = build_dataloader(self.config, "Eval",
query_camera_blocks = paddle.split( "Query")
query_camera, sections) if query_camera is not None else None
metric_key = None def __call__(self, epoch_id=0):
self.model.eval()
# step3. compute metric
if engine.eval_loss_func is None: # step1. prepare query and gallery features
metric_dict = {metric_key: 0.0} if self.gallery_query_dataloader is not None:
else: gallery_feat, gallery_label, gallery_camera = self.compute_feature(
use_reranking = engine.config["Global"].get("re_ranking", False) "gallery_query")
logger.info(f"re_ranking={use_reranking}") query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera
if use_reranking:
# compute distance matrix
distmat = compute_re_ranking_dist(
query_feat, gallery_feat, engine.config["Global"].get(
"feature_normalize", True), 20, 6, 0.3)
# exclude illegal distance
if query_camera is not None:
camera_mask = query_camera != gallery_camera.t()
label_mask = query_label != gallery_label.t()
keep_mask = label_mask | camera_mask
distmat = keep_mask.astype(query_feat.dtype) * distmat + (
~keep_mask).astype(query_feat.dtype) * (distmat.max() + 1)
else:
keep_mask = None
# compute metric with all samples
metric_dict = engine.eval_metric_func(-distmat, query_label,
gallery_label, keep_mask)
else: else:
metric_dict = defaultdict(float) gallery_feat, gallery_label, gallery_camera = self.compute_feature(
for block_idx, block_feat in enumerate(query_feat_blocks): "gallery")
query_feat, query_label, query_camera = self.compute_feature(
"query")
# step2. split features into feature blocks for saving memory
num_query = len(query_feat)
block_size = self.config["Global"].get("sim_block_size", 64)
sections = [block_size] * (num_query // block_size)
if num_query % block_size > 0:
sections.append(num_query % block_size)
query_feat_blocks = paddle.split(query_feat, sections)
query_label_blocks = paddle.split(query_label, sections)
query_camera_blocks = paddle.split(
query_camera, sections) if query_camera is not None else None
metric_key = None
# step3. compute metric
if self.eval_loss_func is None:
metric_dict = {metric_key: 0.0}
else:
use_reranking = self.config["Global"].get("re_ranking", False)
logger.info(f"re_ranking={use_reranking}")
if use_reranking:
# compute distance matrix # compute distance matrix
distmat = paddle.matmul( distmat = compute_re_ranking_dist(
block_feat, gallery_feat, transpose_y=True) query_feat, gallery_feat, self.config["Global"].get(
"feature_normalize", True), 20, 6, 0.3)
# exclude illegal distance # exclude illegal distance
if query_camera is not None: if query_camera is not None:
camera_mask = query_camera_blocks[ camera_mask = query_camera != gallery_camera.t()
block_idx] != gallery_camera.t() label_mask = query_label != gallery_label.t()
label_mask = query_label_blocks[
block_idx] != gallery_label.t()
keep_mask = label_mask | camera_mask keep_mask = label_mask | camera_mask
distmat = keep_mask.astype(query_feat.dtype) * distmat distmat = keep_mask.astype(query_feat.dtype) * distmat + (
~keep_mask).astype(query_feat.dtype) * (distmat.max() +
1)
else: else:
keep_mask = None keep_mask = None
# compute metric by block # compute metric with all samples
metric_block = engine.eval_metric_func( metric_dict = self.eval_metric_func(-distmat, query_label,
distmat, query_label_blocks[block_idx], gallery_label, gallery_label, keep_mask)
keep_mask) else:
# accumulate metric metric_dict = defaultdict(float)
for key in metric_block: for block_idx, block_feat in enumerate(query_feat_blocks):
metric_dict[key] += metric_block[key] * block_feat.shape[ # compute distance matrix
0] / num_query distmat = paddle.matmul(
block_feat, gallery_feat, transpose_y=True)
metric_info_list = [] # exclude illegal distance
for key, value in metric_dict.items(): if query_camera is not None:
metric_info_list.append(f"{key}: {value:.5f}") camera_mask = query_camera_blocks[
if metric_key is None: block_idx] != gallery_camera.t()
metric_key = key label_mask = query_label_blocks[
metric_msg = ", ".join(metric_info_list) block_idx] != gallery_label.t()
logger.info(f"[Eval][Epoch {epoch_id}][Avg]{metric_msg}") keep_mask = label_mask | camera_mask
distmat = keep_mask.astype(query_feat.dtype) * distmat
return metric_dict[metric_key] else:
keep_mask = None
# compute metric by block
def compute_feature(engine, name="gallery"): metric_block = self.eval_metric_func(
if name == "gallery": distmat, query_label_blocks[block_idx], gallery_label,
dataloader = engine.gallery_dataloader keep_mask)
elif name == "query": # accumulate metric
dataloader = engine.query_dataloader for key in metric_block:
elif name == "gallery_query": metric_dict[key] += metric_block[
dataloader = engine.gallery_query_dataloader key] * block_feat.shape[0] / num_query
else:
raise ValueError( metric_info_list = []
f"Only support gallery or query or gallery_query dataset, but got {name}" for key, value in metric_dict.items():
) metric_info_list.append(f"{key}: {value:.5f}")
if metric_key is None:
all_feat = [] metric_key = key
all_label = [] metric_msg = ", ".join(metric_info_list)
all_camera = [] logger.info(f"[Eval][Epoch {epoch_id}][Avg]{metric_msg}")
has_camera = False
for idx, batch in enumerate(dataloader): # load is very time-consuming return metric_dict[metric_key]
if idx % engine.config["Global"]["print_batch_step"] == 0:
logger.info( def compute_feature(self, name="gallery"):
f"{name} feature calculation process: [{idx}/{len(dataloader)}]" if name == "gallery":
dataloader = self.gallery_dataloader
elif name == "query":
dataloader = self.query_dataloader
elif name == "gallery_query":
dataloader = self.gallery_query_dataloader
else:
raise ValueError(
f"Only support gallery or query or gallery_query dataset, but got {name}"
) )
batch = [paddle.to_tensor(x) for x in batch] all_feat = []
batch[1] = batch[1].reshape([-1, 1]).astype("int64") all_label = []
if len(batch) >= 3: all_camera = []
has_camera = True has_camera = False
batch[2] = batch[2].reshape([-1, 1]).astype("int64") for idx, batch in enumerate(dataloader): # load is very time-consuming
if engine.amp and engine.amp_eval: if idx % self.print_batch_step == 0:
with paddle.amp.auto_cast( logger.info(
custom_black_list={ f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
"flatten_contiguous_range", "greater_than" )
},
level=engine.amp_level): batch = [paddle.to_tensor(x) for x in batch]
out = engine.model(batch[0]) batch[1] = batch[1].reshape([-1, 1]).astype("int64")
else: if len(batch) >= 3:
out = engine.model(batch[0]) has_camera = True
if "Student" in out: batch[2] = batch[2].reshape([-1, 1]).astype("int64")
out = out["Student"]
out = self.model(batch)
# get features
if engine.config["Global"].get("retrieval_feature_from", if "Student" in out:
"features") == "features": out = out["Student"]
# use output from neck as feature
batch_feat = out["features"] # get features
else: if self.config["Global"].get("retrieval_feature_from",
# use output from backbone as feature "features") == "features":
batch_feat = out["backbone"] # use output from neck as feature
batch_feat = out["features"]
# do norm(optional) else:
if engine.config["Global"].get("feature_normalize", True): # use output from backbone as feature
batch_feat = paddle.nn.functional.normalize(batch_feat, p=2) batch_feat = out["backbone"]
# do binarize(optional) # do norm(optional)
if engine.config["Global"].get("feature_binarize") == "round": if self.config["Global"].get("feature_normalize", True):
batch_feat = paddle.round(batch_feat).astype("float32") * 2.0 - 1.0 batch_feat = paddle.nn.functional.normalize(batch_feat, p=2)
elif engine.config["Global"].get("feature_binarize") == "sign":
batch_feat = paddle.sign(batch_feat).astype("float32") # do binarize(optional)
if self.config["Global"].get("feature_binarize") == "round":
if paddle.distributed.get_world_size() > 1: batch_feat = paddle.round(batch_feat).astype(
all_feat.append(all_gather(batch_feat)) "float32") * 2.0 - 1.0
all_label.append(all_gather(batch[1])) elif self.config["Global"].get("feature_binarize") == "sign":
if has_camera: batch_feat = paddle.sign(batch_feat).astype("float32")
all_camera.append(all_gather(batch[2]))
if paddle.distributed.get_world_size() > 1:
all_feat.append(all_gather(batch_feat))
all_label.append(all_gather(batch[1]))
if has_camera:
all_camera.append(all_gather(batch[2]))
else:
all_feat.append(batch_feat)
all_label.append(batch[1])
if has_camera:
all_camera.append(batch[2])
if self.use_dali:
dataloader.reset()
all_feat = paddle.concat(all_feat)
all_label = paddle.concat(all_label)
if has_camera:
all_camera = paddle.concat(all_camera)
else: else:
all_feat.append(batch_feat) all_camera = None
all_label.append(batch[1]) # discard redundant padding sample(s) at the end
if has_camera: total_samples = dataloader.size if self.use_dali else len(
all_camera.append(batch[2]) dataloader.dataset)
all_feat = all_feat[:total_samples]
if engine.use_dali: all_label = all_label[:total_samples]
dataloader.reset() if has_camera:
all_camera = all_camera[:total_samples]
all_feat = paddle.concat(all_feat)
all_label = paddle.concat(all_label) logger.info(f"Build {name} done, all feat shape: {all_feat.shape}")
if has_camera: return all_feat, all_label, all_camera
all_camera = paddle.concat(all_camera)
else:
all_camera = None
# discard redundant padding sample(s) at the end
total_samples = dataloader.size if engine.use_dali else len(
dataloader.dataset)
all_feat = all_feat[:total_samples]
all_label = all_label[:total_samples]
if has_camera:
all_camera = all_camera[:total_samples]
logger.info(f"Build {name} done, all feat shape: {all_feat.shape}")
return all_feat, all_label, all_camera
def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray: def k_reciprocal_neighbor(rank: np.ndarray, p: int, k: int) -> np.ndarray:
......
...@@ -22,9 +22,8 @@ from .train_progressive import train_epoch_progressive ...@@ -22,9 +22,8 @@ from .train_progressive import train_epoch_progressive
def build_train_func(config, mode, model, eval_func): def build_train_func(config, mode, model, eval_func):
if mode != "train": if mode != "train":
return None return None
train_mode = config["Global"].get("task", None) task = config["Global"].get("task", "classification")
if train_mode is None: if task == "classification" or task == "retrieval":
config["Global"]["task"] = "classification"
return ClassTrainer(config, model, eval_func) return ClassTrainer(config, model, eval_func)
else: else:
return getattr(sys.modules[__name__], "train_epoch_" + train_mode)( return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
......
...@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function ...@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
from ppcls.data import build_dataloader from ppcls.data import build_dataloader
from ppcls.utils import logger, type_name from ppcls.utils import logger, type_name
from .regular_train_epoch import regular_train_epoch from .classification import ClassTrainer
def train_epoch_progressive(engine, epoch_id, print_batch_step): def train_epoch_progressive(engine, epoch_id, print_batch_step):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册