提交 9beb154b 编写于 作者: G gaotingquan 提交者: Wei Shengyu

support ShiTu

上级 a41b201e
...@@ -88,14 +88,15 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int): ...@@ -88,14 +88,15 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int, seed: int):
random.seed(worker_seed) random.seed(worker_seed)
def build_dataloader(config, mode, seed=None): def build_dataloader(config, *mode, seed=None):
assert mode in [ dataloader_config = config["DataLoader"]
for m in mode:
assert m in [
'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain' 'Train', 'Eval', 'Test', 'Gallery', 'Query', 'UnLabelTrain'
], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain" ], "Dataset mode should be Train, Eval, Test, Gallery, Query, UnLabelTrain"
assert mode in config["DataLoader"].keys(), "{} config not in yaml".format( assert m in dataloader_config.keys(), "{} config not in yaml".format(m)
mode) dataloader_config = dataloader_config[m]
dataloader_config = config["DataLoader"][mode]
class_num = config["Arch"].get("class_num", None) class_num = config["Arch"].get("class_num", None)
epochs = config["Global"]["epochs"] epochs = config["Global"]["epochs"]
use_dali = config["Global"].get("use_dali", False) use_dali = config["Global"].get("use_dali", False)
......
...@@ -22,6 +22,7 @@ from paddle import nn ...@@ -22,6 +22,7 @@ from paddle import nn
import numpy as np import numpy as np
import random import random
from ..utils.amp import AMPForwardDecorator
from ppcls.utils import logger from ppcls.utils import logger
from ppcls.utils.logger import init_logger from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config from ppcls.utils.config import print_config
......
...@@ -13,17 +13,17 @@ ...@@ -13,17 +13,17 @@
# limitations under the License. # limitations under the License.
from .classification import ClassEval from .classification import ClassEval
from .retrieval import retrieval_eval from .retrieval import RetrievalEval
from .adaface import adaface_eval from .adaface import adaface_eval
def build_eval_func(config, mode, model):
    """Build the evaluation callable selected by ``config["Global"]["task"]``.

    Args:
        config: Configuration mapping. ``config["Global"]`` is read for the
            optional ``"task"`` key, which defaults to ``"classification"``.
        mode: Run mode. An evaluator is only built for ``"eval"`` and
            ``"train"``.
        model: Model object passed through to the evaluator constructor.

    Returns:
        ``None`` when ``mode`` is neither ``"eval"`` nor ``"train"``;
        otherwise a ``ClassEval`` (task ``"classification"``) or a
        ``RetrievalEval`` (task ``"retrieval"``) instance.

    Raises:
        ValueError: If the configured task has no evaluator in this builder.
    """
    if mode not in ("eval", "train"):
        return None

    task = config["Global"].get("task", "classification")
    if task == "classification":
        return ClassEval(config, mode, model)
    if task == "retrieval":
        return RetrievalEval(config, mode, model)

    # Name the rejected value: a bare ``Exception()`` gave no clue which
    # config entry was wrong. ValueError is still an Exception subclass, so
    # existing ``except Exception`` handlers keep working.
    raise ValueError(
        f"Unsupported task {task!r} for evaluation; "
        "expected 'classification' or 'retrieval'.")
...@@ -21,25 +21,50 @@ import numpy as np ...@@ -21,25 +21,50 @@ import numpy as np
import paddle import paddle
import scipy import scipy
from ppcls.utils import all_gather, logger from ...utils.misc import AverageMeter
from ...utils import all_gather, logger
from ...data import build_dataloader
from ...loss import build_loss
from ...metric import build_metrics
class RetrievalEval(object):
def __init__(self, config, mode, model):
self.config = config
self.model = model
self.print_batch_step = self.config["Global"]["print_batch_step"]
self.use_dali = self.config["Global"].get("use_dali", False)
self.eval_metric_func = build_metrics(self.config, "Eval")
self.eval_loss_func = build_loss(self.config, "Eval")
self.output_info = dict()
self.gallery_query_dataloader = None
if len(self.config["DataLoader"]["Eval"].keys()) == 1:
self.gallery_query_dataloader = build_dataloader(self.config,
"Eval")
else:
self.gallery_dataloader = build_dataloader(self.config, "Eval",
"Gallery")
self.query_dataloader = build_dataloader(self.config, "Eval",
"Query")
def __call__(self, epoch_id=0):
self.model.eval()
def retrieval_eval(engine, epoch_id=0):
engine.model.eval()
# step1. prepare query and gallery features # step1. prepare query and gallery features
if engine.gallery_query_dataloader is not None: if self.gallery_query_dataloader is not None:
gallery_feat, gallery_label, gallery_camera = compute_feature( gallery_feat, gallery_label, gallery_camera = self.compute_feature(
engine, "gallery_query") "gallery_query")
query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera query_feat, query_label, query_camera = gallery_feat, gallery_label, gallery_camera
else: else:
gallery_feat, gallery_label, gallery_camera = compute_feature( gallery_feat, gallery_label, gallery_camera = self.compute_feature(
engine, "gallery") "gallery")
query_feat, query_label, query_camera = compute_feature(engine, query_feat, query_label, query_camera = self.compute_feature(
"query") "query")
# step2. split features into feature blocks for saving memory # step2. split features into feature blocks for saving memory
num_query = len(query_feat) num_query = len(query_feat)
block_size = engine.config["Global"].get("sim_block_size", 64) block_size = self.config["Global"].get("sim_block_size", 64)
sections = [block_size] * (num_query // block_size) sections = [block_size] * (num_query // block_size)
if num_query % block_size > 0: if num_query % block_size > 0:
sections.append(num_query % block_size) sections.append(num_query % block_size)
...@@ -51,15 +76,15 @@ def retrieval_eval(engine, epoch_id=0): ...@@ -51,15 +76,15 @@ def retrieval_eval(engine, epoch_id=0):
metric_key = None metric_key = None
# step3. compute metric # step3. compute metric
if engine.eval_loss_func is None: if self.eval_loss_func is None:
metric_dict = {metric_key: 0.0} metric_dict = {metric_key: 0.0}
else: else:
use_reranking = engine.config["Global"].get("re_ranking", False) use_reranking = self.config["Global"].get("re_ranking", False)
logger.info(f"re_ranking={use_reranking}") logger.info(f"re_ranking={use_reranking}")
if use_reranking: if use_reranking:
# compute distance matrix # compute distance matrix
distmat = compute_re_ranking_dist( distmat = compute_re_ranking_dist(
query_feat, gallery_feat, engine.config["Global"].get( query_feat, gallery_feat, self.config["Global"].get(
"feature_normalize", True), 20, 6, 0.3) "feature_normalize", True), 20, 6, 0.3)
# exclude illegal distance # exclude illegal distance
if query_camera is not None: if query_camera is not None:
...@@ -67,11 +92,12 @@ def retrieval_eval(engine, epoch_id=0): ...@@ -67,11 +92,12 @@ def retrieval_eval(engine, epoch_id=0):
label_mask = query_label != gallery_label.t() label_mask = query_label != gallery_label.t()
keep_mask = label_mask | camera_mask keep_mask = label_mask | camera_mask
distmat = keep_mask.astype(query_feat.dtype) * distmat + ( distmat = keep_mask.astype(query_feat.dtype) * distmat + (
~keep_mask).astype(query_feat.dtype) * (distmat.max() + 1) ~keep_mask).astype(query_feat.dtype) * (distmat.max() +
1)
else: else:
keep_mask = None keep_mask = None
# compute metric with all samples # compute metric with all samples
metric_dict = engine.eval_metric_func(-distmat, query_label, metric_dict = self.eval_metric_func(-distmat, query_label,
gallery_label, keep_mask) gallery_label, keep_mask)
else: else:
metric_dict = defaultdict(float) metric_dict = defaultdict(float)
...@@ -90,13 +116,13 @@ def retrieval_eval(engine, epoch_id=0): ...@@ -90,13 +116,13 @@ def retrieval_eval(engine, epoch_id=0):
else: else:
keep_mask = None keep_mask = None
# compute metric by block # compute metric by block
metric_block = engine.eval_metric_func( metric_block = self.eval_metric_func(
distmat, query_label_blocks[block_idx], gallery_label, distmat, query_label_blocks[block_idx], gallery_label,
keep_mask) keep_mask)
# accumulate metric # accumulate metric
for key in metric_block: for key in metric_block:
metric_dict[key] += metric_block[key] * block_feat.shape[ metric_dict[key] += metric_block[
0] / num_query key] * block_feat.shape[0] / num_query
metric_info_list = [] metric_info_list = []
for key, value in metric_dict.items(): for key, value in metric_dict.items():
...@@ -108,14 +134,13 @@ def retrieval_eval(engine, epoch_id=0): ...@@ -108,14 +134,13 @@ def retrieval_eval(engine, epoch_id=0):
return metric_dict[metric_key] return metric_dict[metric_key]
def compute_feature(self, name="gallery"):
def compute_feature(engine, name="gallery"):
if name == "gallery": if name == "gallery":
dataloader = engine.gallery_dataloader dataloader = self.gallery_dataloader
elif name == "query": elif name == "query":
dataloader = engine.query_dataloader dataloader = self.query_dataloader
elif name == "gallery_query": elif name == "gallery_query":
dataloader = engine.gallery_query_dataloader dataloader = self.gallery_query_dataloader
else: else:
raise ValueError( raise ValueError(
f"Only support gallery or query or gallery_query dataset, but got {name}" f"Only support gallery or query or gallery_query dataset, but got {name}"
...@@ -126,7 +151,7 @@ def compute_feature(engine, name="gallery"): ...@@ -126,7 +151,7 @@ def compute_feature(engine, name="gallery"):
all_camera = [] all_camera = []
has_camera = False has_camera = False
for idx, batch in enumerate(dataloader): # load is very time-consuming for idx, batch in enumerate(dataloader): # load is very time-consuming
if idx % engine.config["Global"]["print_batch_step"] == 0: if idx % self.print_batch_step == 0:
logger.info( logger.info(
f"{name} feature calculation process: [{idx}/{len(dataloader)}]" f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
) )
...@@ -136,20 +161,14 @@ def compute_feature(engine, name="gallery"): ...@@ -136,20 +161,14 @@ def compute_feature(engine, name="gallery"):
if len(batch) >= 3: if len(batch) >= 3:
has_camera = True has_camera = True
batch[2] = batch[2].reshape([-1, 1]).astype("int64") batch[2] = batch[2].reshape([-1, 1]).astype("int64")
if engine.amp and engine.amp_eval:
with paddle.amp.auto_cast( out = self.model(batch)
custom_black_list={
"flatten_contiguous_range", "greater_than"
},
level=engine.amp_level):
out = engine.model(batch[0])
else:
out = engine.model(batch[0])
if "Student" in out: if "Student" in out:
out = out["Student"] out = out["Student"]
# get features # get features
if engine.config["Global"].get("retrieval_feature_from", if self.config["Global"].get("retrieval_feature_from",
"features") == "features": "features") == "features":
# use output from neck as feature # use output from neck as feature
batch_feat = out["features"] batch_feat = out["features"]
...@@ -158,13 +177,14 @@ def compute_feature(engine, name="gallery"): ...@@ -158,13 +177,14 @@ def compute_feature(engine, name="gallery"):
batch_feat = out["backbone"] batch_feat = out["backbone"]
# do norm(optional) # do norm(optional)
if engine.config["Global"].get("feature_normalize", True): if self.config["Global"].get("feature_normalize", True):
batch_feat = paddle.nn.functional.normalize(batch_feat, p=2) batch_feat = paddle.nn.functional.normalize(batch_feat, p=2)
# do binarize(optional) # do binarize(optional)
if engine.config["Global"].get("feature_binarize") == "round": if self.config["Global"].get("feature_binarize") == "round":
batch_feat = paddle.round(batch_feat).astype("float32") * 2.0 - 1.0 batch_feat = paddle.round(batch_feat).astype(
elif engine.config["Global"].get("feature_binarize") == "sign": "float32") * 2.0 - 1.0
elif self.config["Global"].get("feature_binarize") == "sign":
batch_feat = paddle.sign(batch_feat).astype("float32") batch_feat = paddle.sign(batch_feat).astype("float32")
if paddle.distributed.get_world_size() > 1: if paddle.distributed.get_world_size() > 1:
...@@ -178,7 +198,7 @@ def compute_feature(engine, name="gallery"): ...@@ -178,7 +198,7 @@ def compute_feature(engine, name="gallery"):
if has_camera: if has_camera:
all_camera.append(batch[2]) all_camera.append(batch[2])
if engine.use_dali: if self.use_dali:
dataloader.reset() dataloader.reset()
all_feat = paddle.concat(all_feat) all_feat = paddle.concat(all_feat)
...@@ -188,7 +208,7 @@ def compute_feature(engine, name="gallery"): ...@@ -188,7 +208,7 @@ def compute_feature(engine, name="gallery"):
else: else:
all_camera = None all_camera = None
# discard redundant padding sample(s) at the end # discard redundant padding sample(s) at the end
total_samples = dataloader.size if engine.use_dali else len( total_samples = dataloader.size if self.use_dali else len(
dataloader.dataset) dataloader.dataset)
all_feat = all_feat[:total_samples] all_feat = all_feat[:total_samples]
all_label = all_label[:total_samples] all_label = all_label[:total_samples]
......
...@@ -22,9 +22,8 @@ from .train_progressive import train_epoch_progressive ...@@ -22,9 +22,8 @@ from .train_progressive import train_epoch_progressive
def build_train_func(config, mode, model, eval_func): def build_train_func(config, mode, model, eval_func):
if mode != "train": if mode != "train":
return None return None
train_mode = config["Global"].get("task", None) task = config["Global"].get("task", "classification")
if train_mode is None: if task == "classification" or task == "retrieval":
config["Global"]["task"] = "classification"
return ClassTrainer(config, model, eval_func) return ClassTrainer(config, model, eval_func)
else: else:
return getattr(sys.modules[__name__], "train_epoch_" + train_mode)( return getattr(sys.modules[__name__], "train_epoch_" + train_mode)(
......
...@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function ...@@ -15,7 +15,7 @@ from __future__ import absolute_import, division, print_function
from ppcls.data import build_dataloader from ppcls.data import build_dataloader
from ppcls.utils import logger, type_name from ppcls.utils import logger, type_name
from .regular_train_epoch import regular_train_epoch from .classification import ClassTrainer
def train_epoch_progressive(engine, epoch_id, print_batch_step): def train_epoch_progressive(engine, epoch_id, print_batch_step):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册