未验证 提交 3ded7ccf 编写于 作者: C cuicheng01 提交者: GitHub

Merge pull request #792 from weisy11/develop_reg

export model support rec model
......@@ -53,11 +53,12 @@ class RecModel(nn.Layer):
else:
self.head = None
def forward(self, x, label):
def forward(self, x, label=None):
x = self.backbone(x)
if self.neck is not None:
x = self.neck(x)
y = x
if self.head is not None:
y = self.head(x, label)
else:
y = None
return {"features": x, "logits": y}
......@@ -39,7 +39,7 @@ class ArcMargin(nn.Layer):
weight_attr=weight_attr,
bias_attr=False)
def forward(self, input, label):
def forward(self, input, label=None):
input_norm = paddle.sqrt(
paddle.sum(paddle.square(input), axis=1, keepdim=True))
input = paddle.divide(input, input_norm)
......@@ -50,7 +50,7 @@ class ArcMargin(nn.Layer):
weight = paddle.divide(weight, weight_norm)
cos = paddle.matmul(input, weight)
if not self.training:
if not self.training or label is None:
return cos
sin = paddle.sqrt(1.0 - paddle.square(cos) + 1e-6)
cos_m = math.cos(self.margin)
......
......@@ -329,22 +329,22 @@ class Trainer(object):
self.model.eval()
cum_similarity_matrix = None
# step1. build gallery
gallery_feas, gallery_img_id, gallery_camera_id = self._cal_feature(
gallery_feas, gallery_img_id, gallery_unique_id = self._cal_feature(
name='gallery')
query_feas, query_img_id, query_camera_id = self._cal_feature(
query_feas, query_img_id, query_query_id = self._cal_feature(
name='query')
gallery_img_id = gallery_img_id
# if gallery_camera_id is not None:
# gallery_camera_id = gallery_camera_id
# if gallery_unique_id is not None:
# gallery_unique_id = gallery_unique_id
# step2. do evaluation
sim_block_size = self.config["Global"].get("sim_block_size", 64)
sections = [sim_block_size] * (len(query_feas) // sim_block_size)
if len(query_feas) % sim_block_size:
sections.append(len(query_feas) % sim_block_size)
fea_blocks = paddle.split(query_feas, num_or_sections=sections)
if query_camera_id is not None:
camera_id_blocks = paddle.split(
query_camera_id, num_or_sections=sections)
if query_query_id is not None:
query_id_blocks = paddle.split(
query_query_id, num_or_sections=sections)
image_id_blocks = paddle.split(
query_img_id, num_or_sections=sections)
metric_key = None
......@@ -352,14 +352,14 @@ class Trainer(object):
for block_idx, block_fea in enumerate(fea_blocks):
similarity_matrix = paddle.matmul(
block_fea, gallery_feas, transpose_y=True)
if query_camera_id is not None:
camera_id_block = camera_id_blocks[block_idx]
camera_id_mask = (camera_id_block != gallery_camera_id.t())
if query_query_id is not None:
query_id_block = query_id_blocks[block_idx]
query_id_mask = (query_id_block != gallery_unique_id.t())
image_id_block = image_id_blocks[block_idx]
image_id_mask = (image_id_block != gallery_img_id.t())
keep_mask = paddle.logical_or(camera_id_mask, image_id_mask)
keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
similarity_matrix = similarity_matrix * keep_mask.astype(
"float32")
if cum_similarity_matrix is None:
......@@ -388,7 +388,7 @@ class Trainer(object):
def _cal_feature(self, name='gallery'):
all_feas = None
all_image_id = None
all_camera_id = None
all_unique_id = None
if name == 'gallery':
dataloader = self.gallery_dataloader
elif name == 'query':
......@@ -396,13 +396,13 @@ class Trainer(object):
else:
raise RuntimeError("Only support gallery or query dataset")
has_cam_id = False
has_unique_id = False
for idx, batch in enumerate(dataloader(
)): # load is very time-consuming
batch = [paddle.to_tensor(x) for x in batch]
batch[1] = batch[1].reshape([-1, 1])
if len(batch) == 3:
has_cam_id = True
has_unique_id = True
batch[2] = batch[2].reshape([-1, 1])
out = self.model(batch[0], batch[1])
batch_feas = out["features"]
......@@ -416,30 +416,30 @@ class Trainer(object):
if all_feas is None:
all_feas = batch_feas
if has_cam_id:
all_camera_id = batch[2]
if has_unique_id:
all_unique_id = batch[2]
all_image_id = batch[1]
else:
all_feas = paddle.concat([all_feas, batch_feas])
all_image_id = paddle.concat([all_image_id, batch[1]])
if has_cam_id:
all_camera_id = paddle.concat([all_camera_id, batch[2]])
if has_unique_id:
all_unique_id = paddle.concat([all_unique_id, batch[2]])
if paddle.distributed.get_world_size() > 1:
feat_list = []
img_id_list = []
cam_id_list = []
unique_id_list = []
paddle.distributed.all_gather(feat_list, all_feas)
paddle.distributed.all_gather(img_id_list, all_image_id)
all_feas = paddle.concat(feat_list, axis=0)
all_image_id = paddle.concat(img_id_list, axis=0)
if has_cam_id:
paddle.distributed.all_gather(cam_id_list, all_camera_id)
all_camera_id = paddle.concat(cam_id_list, axis=0)
if has_unique_id:
paddle.distributed.all_gather(unique_id_list, all_unique_id)
all_unique_id = paddle.concat(unique_id_list, axis=0)
logger.info("Build {} done, all feat shape: {}, begin to eval..".
format(name, all_feas.shape))
return all_feas, all_image_id, all_camera_id
return all_feas, all_image_id, all_unique_id
@paddle.no_grad()
def infer(self, ):
......
......@@ -29,7 +29,7 @@ from ppcls.arch import build_model
from ppcls.utils.save_load import load_dygraph_pretrain
class ClasModel(nn.Layer):
class ExportModel(nn.Layer):
"""
ClasModel: add softmax onto the model
"""
......@@ -37,7 +37,11 @@ class ClasModel(nn.Layer):
def __init__(self, config):
super().__init__()
self.base_model = build_model(config)
self.softmax = nn.Softmax(axis=-1)
self.infer_output_key = config.get("infer_output_key")
if config.get("infer_add_softmax", True):
self.softmax = nn.Softmax(axis=-1)
else:
self.softmax = None
def eval(self):
self.training = False
......@@ -47,7 +51,10 @@ class ClasModel(nn.Layer):
def forward(self, x):
x = self.base_model(x)
x = self.softmax(x)
if self.infer_output_key is not None:
x = x[self.infer_output_key]
if self.softmax is not None:
x = self.softmax(x)
return x
......@@ -57,8 +64,7 @@ if __name__ == "__main__":
# set device
assert config["Global"]["device"] in ["cpu", "gpu", "xpu"]
device = paddle.set_device(config["Global"]["device"])
model = ClasModel(config["Arch"])
model = ExportModel(config["Arch"])
if config["Global"]["pretrained_model"] is not None:
load_dygraph_pretrain(model.base_model,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册