提交 91f72dce 编写于 作者: W weishengyu

dbg label

上级 ce43150f
......@@ -53,7 +53,7 @@ class RecModel(nn.Layer):
else:
self.head = None
def forward(self, x, label):
def forward(self, x, label=None):
x = self.backbone(x)
if self.neck is not None:
x = self.neck(x)
......
......@@ -39,7 +39,7 @@ class ArcMargin(nn.Layer):
weight_attr=weight_attr,
bias_attr=False)
def forward(self, input, label):
def forward(self, input, label=None):
input_norm = paddle.sqrt(
paddle.sum(paddle.square(input), axis=1, keepdim=True))
input = paddle.divide(input, input_norm)
......@@ -50,7 +50,7 @@ class ArcMargin(nn.Layer):
weight = paddle.divide(weight, weight_norm)
cos = paddle.matmul(input, weight)
if not self.training:
if not self.training or label is None:
return cos
sin = paddle.sqrt(1.0 - paddle.square(cos) + 1e-6)
cos_m = math.cos(self.margin)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册