未验证 提交 9fecdbaf 编写于 作者: B Bin Lu 提交者: GitHub

Update center_loss.py

上级 1ac84b07
...@@ -27,7 +27,6 @@ class CenterLoss(nn.Layer): ...@@ -27,7 +27,6 @@ class CenterLoss(nn.Layer):
""" """
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016. Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
""" """
def __init__(self, def __init__(self,
num_classes=6625, num_classes=6625,
feat_dim=96, feat_dim=96,
...@@ -37,8 +36,7 @@ class CenterLoss(nn.Layer): ...@@ -37,8 +36,7 @@ class CenterLoss(nn.Layer):
self.num_classes = num_classes self.num_classes = num_classes
self.feat_dim = feat_dim self.feat_dim = feat_dim
self.centers = paddle.randn( self.centers = paddle.randn(
shape=[self.num_classes, self.feat_dim]).astype( shape=[self.num_classes, self.feat_dim]).astype("float64")
"float64") #random center
if init_center: if init_center:
assert os.path.exists( assert os.path.exists(
...@@ -60,22 +58,23 @@ class CenterLoss(nn.Layer): ...@@ -60,22 +58,23 @@ class CenterLoss(nn.Layer):
batch_size = feats_reshape.shape[0] batch_size = feats_reshape.shape[0]
#calc feat * feat #calc l2 distance between feats and centers
dist1 = paddle.sum(paddle.square(feats_reshape), axis=1, keepdim=True) square_feat = paddle.sum(paddle.square(feats_reshape),
dist1 = paddle.expand(dist1, [batch_size, self.num_classes]) axis=1,
keepdim=True)
square_feat = paddle.expand(square_feat, [batch_size, self.num_classes])
#dist2 of centers square_center = paddle.sum(paddle.square(self.centers),
dist2 = paddle.sum(paddle.square(self.centers), axis=1, axis=1,
keepdim=True) #num_classes keepdim=True)
dist2 = paddle.expand(dist2, square_center = paddle.expand(
[self.num_classes, batch_size]).astype("float64") square_center, [self.num_classes, batch_size]).astype("float64")
dist2 = paddle.transpose(dist2, [1, 0]) square_center = paddle.transpose(square_center, [1, 0])
#first x * x + y * y distmat = paddle.add(square_feat, square_center)
distmat = paddle.add(dist1, dist2) feat_dot_center = paddle.matmul(feats_reshape,
tmp = paddle.matmul(feats_reshape, paddle.transpose(self.centers, [1, 0]))
paddle.transpose(self.centers, [1, 0])) distmat = distmat - 2.0 * feat_dot_center
distmat = distmat - 2.0 * tmp
#generate the mask #generate the mask
classes = paddle.arange(self.num_classes).astype("int64") classes = paddle.arange(self.num_classes).astype("int64")
...@@ -83,7 +82,8 @@ class CenterLoss(nn.Layer): ...@@ -83,7 +82,8 @@ class CenterLoss(nn.Layer):
paddle.unsqueeze(label, 1), (batch_size, self.num_classes)) paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
mask = paddle.equal( mask = paddle.equal(
paddle.expand(classes, [batch_size, self.num_classes]), paddle.expand(classes, [batch_size, self.num_classes]),
label).astype("float64") #get mask label).astype("float64")
dist = paddle.multiply(distmat, mask) dist = paddle.multiply(distmat, mask)
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
return {'loss_center': loss} return {'loss_center': loss}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册