diff --git a/ppcls/arch/head/arcmargin.py b/ppcls/arch/head/arcmargin.py index 82da7f093fde2949d1d5c15195318737550d4cac..4f27acbdd684d5634e8af3d94a429a36b6986953 100644 --- a/ppcls/arch/head/arcmargin.py +++ b/ppcls/arch/head/arcmargin.py @@ -30,7 +30,7 @@ class ArcMargin(nn.Layer): self.easy_margin = easy_margin weight_attr = paddle.ParamAttr(initializer = paddle.nn.initializer.XavierNormal()) - self.fc = nn.Linear(self.embedding_size, self.class_dim, weight_attr=weight_attr, bias_attr=False) + self.fc = nn.Linear(self.embedding_size, self.class_num, weight_attr=weight_attr, bias_attr=False) def forward(self, input, label): input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True)) @@ -53,7 +53,7 @@ class ArcMargin(nn.Layer): else: phi = self._paddle_where_more_than(cos, th, phi, cos - mm) - one_hot = paddle.nn.functional.one_hot(label, self.class_dim) + one_hot = paddle.nn.functional.one_hot(label, self.class_num) one_hot = paddle.squeeze(one_hot, axis=[1]) output = paddle.multiply(one_hot, phi) + paddle.multiply((1.0 - one_hot), cos) output = output * self.scale