diff --git a/ppcls/arch/head/arcmargin.py b/ppcls/arch/head/arcmargin.py index 4f27acbdd684d5634e8af3d94a429a36b6986953..c7a79a1fb89bac90f2b80f6703504592eb6754a3 100644 --- a/ppcls/arch/head/arcmargin.py +++ b/ppcls/arch/head/arcmargin.py @@ -36,7 +36,7 @@ class ArcMargin(nn.Layer): input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True)) input = paddle.divide(input, input_norm) - weight = self.fc0.weight + weight = self.fc.weight weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True)) weight = paddle.divide(weight, weight_norm)