diff --git a/ppcls/loss/deephashloss.py b/ppcls/loss/deephashloss.py index 93a7e2e5221bbd9d8a5eeb0881f964bfa7300d7a..c9a58dc78db48a0eef83a5aec2efb4b99d44ea91 100644 --- a/ppcls/loss/deephashloss.py +++ b/ppcls/loss/deephashloss.py @@ -23,38 +23,42 @@ class DSHSDLoss(nn.Layer): # [DSHSD] epoch:250, bit:48, dataset:nuswide_21, MAP:0.809, Best MAP: 0.815 # [DSHSD] epoch:135, bit:48, dataset:imagenet, MAP:0.647, Best MAP: 0.647 """ - def __init__(self, n_class, bit, alpha, multi_label=False): - super(DSHSDLoss, self).__init__() + def __init__(self, alpha, multi_label=False): + super(DSHSDLoss, self).__init__() self.alpha = alpha self.multi_label = multi_label - - def forward(self, input, label): + + def forward(self, input, label): feature = input["features"] - logits = input["logits"] - - dist = paddle.sum( - paddle.square((paddle.unsqueeze(feature, 1) - paddle.unsqueeze(feature, 0))), - axis=2) - + logits = input["logits"] + + dist = paddle.sum(paddle.square( + (paddle.unsqueeze(feature, 1) - paddle.unsqueeze(feature, 0))), + axis=2) + # label to ont-hot label = paddle.flatten(label) n_class = logits.shape[1] - label = paddle.nn.functional.one_hot(label, n_class).astype("float32") + label = paddle.nn.functional.one_hot(label, n_class).astype("float32") - s = (paddle.matmul(label, label, transpose_y=True) == 0).astype("float32") + s = (paddle.matmul( + label, label, transpose_y=True) == 0).astype("float32") margin = 2 * feature.shape[1] Ld = (1 - s) / 2 * dist + s / 2 * (margin - dist).clip(min=0) Ld = Ld.mean() - + if self.multi_label: # multiple labels classification loss - Lc = (logits - label * logits + ((1 + (-logits).exp()).log())).sum(axis=1).mean() + Lc = (logits - label * logits + ( + (1 + (-logits).exp()).log())).sum(axis=1).mean() else: # single labels classification loss - Lc = (-paddle.nn.functional.softmax(logits).log() * label).sum(axis=1).mean() + Lc = (-paddle.nn.functional.softmax(logits).log() * label).sum( + axis=1).mean() return {"dshsdloss": Lc + Ld * self.alpha} + class LCDSHLoss(nn.Layer): """ # paper [Locality-Constrained Deep Supervised Hashing for Image Retrieval](https://www.ijcai.org/Proceedings/2017/0499.pdf)