diff --git a/ppocr/losses/cls_loss.py b/ppocr/losses/cls_loss.py index ecca5d2e1739631716123d4a793f5ece09d7f9ab..abc5e5b72cb055716715345105b59089f0a96edc 100755 --- a/ppocr/losses/cls_loss.py +++ b/ppocr/losses/cls_loss.py @@ -25,6 +25,6 @@ class ClsLoss(nn.Layer): self.loss_func = nn.CrossEntropyLoss(reduction='mean') def forward(self, predicts, batch): - label = batch[1] + label = batch[1].astype("int64") loss = self.loss_func(input=predicts, label=label) return {'loss': loss}