未验证 提交 49382181 编写于 作者: D Double_V 提交者: GitHub

Merge pull request #3717 from littletomatodonkey/dyg/fix_cls_type

fix cls type
...@@ -25,6 +25,6 @@ class ClsLoss(nn.Layer): ...@@ -25,6 +25,6 @@ class ClsLoss(nn.Layer):
self.loss_func = nn.CrossEntropyLoss(reduction='mean') self.loss_func = nn.CrossEntropyLoss(reduction='mean')
def forward(self, predicts, batch): def forward(self, predicts, batch):
label = batch[1] label = batch[1].astype("int64")
loss = self.loss_func(input=predicts, label=label) loss = self.loss_func(input=predicts, label=label)
return {'loss': loss} return {'loss': loss}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册