From 6f2959af3aeee4b089a0e1d82f74b82686000d45 Mon Sep 17 00:00:00 2001 From: Bin Lu Date: Mon, 31 May 2021 20:45:30 +0800 Subject: [PATCH] Update arcmargin.py --- ppcls/arch/head/arcmargin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppcls/arch/head/arcmargin.py b/ppcls/arch/head/arcmargin.py index 82da7f09..4f27acbd 100644 --- a/ppcls/arch/head/arcmargin.py +++ b/ppcls/arch/head/arcmargin.py @@ -30,7 +30,7 @@ class ArcMargin(nn.Layer): self.easy_margin = easy_margin weight_attr = paddle.ParamAttr(initializer = paddle.nn.initializer.XavierNormal()) - self.fc = nn.Linear(self.embedding_size, self.class_dim, weight_attr=weight_attr, bias_attr=False) + self.fc = nn.Linear(self.embedding_size, self.class_num, weight_attr=weight_attr, bias_attr=False) def forward(self, input, label): input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True)) @@ -53,7 +53,7 @@ class ArcMargin(nn.Layer): else: phi = self._paddle_where_more_than(cos, th, phi, cos - mm) - one_hot = paddle.nn.functional.one_hot(label, self.class_dim) + one_hot = paddle.nn.functional.one_hot(label, self.class_num) one_hot = paddle.squeeze(one_hot, axis=[1]) output = paddle.multiply(one_hot, phi) + paddle.multiply((1.0 - one_hot), cos) output = output * self.scale -- GitLab