提交 3ef99930 编写于 作者: D dongshuilong

fix ReID trainer bugs

上级 10c0fc4b
...@@ -17,6 +17,7 @@ Global: ...@@ -17,6 +17,7 @@ Global:
image_shape: [3, 224, 224] image_shape: [3, 224, 224]
save_inference_dir: "./inference" save_inference_dir: "./inference"
num_split: 1 num_split: 1
feature_normalize: True
# model architecture # model architecture
......
...@@ -27,8 +27,7 @@ from ppcls.data.dataloader.common_dataset import create_operators ...@@ -27,8 +27,7 @@ from ppcls.data.dataloader.common_dataset import create_operators
from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild
# sampler # sampler
from ppcls.data.dataloader import DistributedRandomIdentitySampler from ppcls.data.dataloader.DistributedRandomIdentitySampler import DistributedRandomIdentitySampler
from ppcls.data.preprocess import transform from ppcls.data.preprocess import transform
......
...@@ -169,9 +169,11 @@ class TrainerReID(Trainer): ...@@ -169,9 +169,11 @@ class TrainerReID(Trainer):
batch_feas = out["features"] batch_feas = out["features"]
# do norm # do norm
feas_norm = paddle.sqrt( if self.config["Global"].get("feature_normalize", True):
paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True)) feas_norm = paddle.sqrt(
batch_feas = paddle.divide(batch_feas, feas_norm) paddle.sum(paddle.square(batch_feas), axis=1,
keepdim=True))
batch_feas = paddle.divide(batch_feas, feas_norm)
batch_feas = batch_feas batch_feas = batch_feas
batch_image_labels = batch[1] batch_image_labels = batch[1]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册