Unverified commit 5b70d442, authored by Bin Lu, committed by GitHub

Update circlemargin.py

Parent c9912f0a
@@ -28,7 +28,7 @@ class CircleMargin(nn.Layer):
         weight_attr = paddle.ParamAttr(
             initializer=paddle.nn.initializer.XavierNormal())
-        self.fc0 = paddle.nn.Linear(
+        self.fc = paddle.nn.Linear(
             self.embedding_size, self.class_num, weight_attr=weight_attr)

     def forward(self, input, label):
@@ -36,19 +36,22 @@ class CircleMargin(nn.Layer):
             paddle.sum(paddle.square(input), axis=1, keepdim=True))
         input = paddle.divide(input, feat_norm)

-        weight = self.fc0.weight
+        weight = self.fc.weight
         weight_norm = paddle.sqrt(
             paddle.sum(paddle.square(weight), axis=0, keepdim=True))
         weight = paddle.divide(weight, weight_norm)

         logits = paddle.matmul(input, weight)
+        if not self.training or label is None:
+            return logits
+
         alpha_p = paddle.clip(-logits.detach() + 1 + self.margin, min=0.)
         alpha_n = paddle.clip(logits.detach() + self.margin, min=0.)
         delta_p = 1 - self.margin
         delta_n = self.margin
-        index = paddle.fluid.layers.where(label != -1).reshape([-1])
+
         m_hot = F.one_hot(label.reshape([-1]), num_classes=logits.shape[1])

         logits_p = alpha_p * (logits - delta_p)
         logits_n = alpha_n * (logits - delta_n)
         pre_logits = logits_p * m_hot + logits_n * (1 - m_hot)
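For context: the updated forward pass L2-normalizes the features and the classifier weights, takes their product as cosine logits, and with this change returns those raw logits directly when the layer is in eval mode or no label is given; only during training does it apply the Circle-style re-weighting. The snippet below is a minimal standalone sketch, not part of the commit, that reproduces the margin re-weighting on a toy batch. The values margin=0.25, embedding_size=4, class_num=3, batch=2 and the tensor names are illustrative assumptions; the final scale factor applied in the full layer is omitted here.

import paddle
import paddle.nn.functional as F

paddle.seed(0)

# Illustrative hyperparameters (assumptions, not taken from the commit)
margin = 0.25
embedding_size, class_num, batch = 4, 3, 2

feat = paddle.randn([batch, embedding_size])
weight = paddle.randn([embedding_size, class_num])
label = paddle.to_tensor([0, 2], dtype='int64')

# L2-normalize features per row and class weights per column, as in forward()
feat = paddle.divide(
    feat, paddle.sqrt(paddle.sum(paddle.square(feat), axis=1, keepdim=True)))
weight = paddle.divide(
    weight, paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True)))

# Cosine-similarity logits in [-1, 1]
logits = paddle.matmul(feat, weight)

# Circle-style adaptive scaling (alpha) and fixed margins (delta)
alpha_p = paddle.clip(-logits.detach() + 1 + margin, min=0.)
alpha_n = paddle.clip(logits.detach() + margin, min=0.)
delta_p = 1 - margin
delta_n = margin

# One-hot mask selects the target class: positives are pulled toward delta_p,
# negatives are pushed below delta_n
m_hot = F.one_hot(label, num_classes=class_num)
pre_logits = alpha_p * (logits - delta_p) * m_hot \
    + alpha_n * (logits - delta_n) * (1 - m_hot)

print(pre_logits.shape)  # [2, 3]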