diff --git a/ppcls/arch/gears/cosmargin.py b/ppcls/arch/gears/cosmargin.py index 51db550868352e2ef20c3accd1fa4dc92d64321d..378e102a215664e33bb8d79016673778ed8a4221 100644 --- a/ppcls/arch/gears/cosmargin.py +++ b/ppcls/arch/gears/cosmargin.py @@ -46,6 +46,9 @@ class CosMargin(paddle.nn.Layer): weight = paddle.divide(weight, weight_norm) cos = paddle.matmul(input, weight) + if not self.training or label is None: + return cos + cos_m = cos - self.margin one_hot = paddle.nn.functional.one_hot(label, self.class_num)