提交 91f72dce 编写于 作者: W weishengyu

dbg label

上级 ce43150f
...@@ -53,7 +53,7 @@ class RecModel(nn.Layer): ...@@ -53,7 +53,7 @@ class RecModel(nn.Layer):
else: else:
self.head = None self.head = None
def forward(self, x, label): def forward(self, x, label=None):
x = self.backbone(x) x = self.backbone(x)
if self.neck is not None: if self.neck is not None:
x = self.neck(x) x = self.neck(x)
......
...@@ -39,7 +39,7 @@ class ArcMargin(nn.Layer): ...@@ -39,7 +39,7 @@ class ArcMargin(nn.Layer):
weight_attr=weight_attr, weight_attr=weight_attr,
bias_attr=False) bias_attr=False)
def forward(self, input, label): def forward(self, input, label=None):
input_norm = paddle.sqrt( input_norm = paddle.sqrt(
paddle.sum(paddle.square(input), axis=1, keepdim=True)) paddle.sum(paddle.square(input), axis=1, keepdim=True))
input = paddle.divide(input, input_norm) input = paddle.divide(input, input_norm)
...@@ -50,7 +50,7 @@ class ArcMargin(nn.Layer): ...@@ -50,7 +50,7 @@ class ArcMargin(nn.Layer):
weight = paddle.divide(weight, weight_norm) weight = paddle.divide(weight, weight_norm)
cos = paddle.matmul(input, weight) cos = paddle.matmul(input, weight)
if not self.training: if not self.training or label is None:
return cos return cos
sin = paddle.sqrt(1.0 - paddle.square(cos) + 1e-6) sin = paddle.sqrt(1.0 - paddle.square(cos) + 1e-6)
cos_m = math.cos(self.margin) cos_m = math.cos(self.margin)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册