提交 3ef99930 编写于 作者: D dongshuilong

fix ReID trainer bugs

上级 10c0fc4b
......@@ -17,6 +17,7 @@ Global:
image_shape: [3, 224, 224]
save_inference_dir: "./inference"
num_split: 1
feature_normalize: True
# model architecture
......
......@@ -27,8 +27,7 @@ from ppcls.data.dataloader.common_dataset import create_operators
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
# sampler
from ppcls.data.dataloader import DistributedRandomIdentitySampler
from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
from ppcls.data.preprocess import transform
......
......@@ -169,9 +169,11 @@ class TrainerReID(Trainer):
batch_feas = out["features"]
# do norm
feas_norm = paddle.sqrt(
paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True))
batch_feas = paddle.divide(batch_feas, feas_norm)
if self.config["Global"].get("feature_normalize", True):
feas_norm = paddle.sqrt(
paddle.sum(paddle.square(batch_feas), axis=1,
keepdim=True))
batch_feas = paddle.divide(batch_feas, feas_norm)
batch_feas = batch_feas
batch_image_labels = batch[1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册