未验证 提交 f61a64aa 编写于 作者: B Bin Lu 提交者: GitHub

Update arcmargin.py

上级 77a2c457
...@@ -36,7 +36,7 @@ class ArcMargin(nn.Layer): ...@@ -36,7 +36,7 @@ class ArcMargin(nn.Layer):
input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True)) input_norm = paddle.sqrt(paddle.sum(paddle.square(input), axis=1, keepdim=True))
input = paddle.divide(input, input_norm) input = paddle.divide(input, input_norm)
weight = self.fc0.weight weight = self.fc.weight
weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True)) weight_norm = paddle.sqrt(paddle.sum(paddle.square(weight), axis=0, keepdim=True))
weight = paddle.divide(weight, weight_norm) weight = paddle.divide(weight, weight_norm)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册