未验证 提交 beef5eff 编写于 作者: M Meiyim 提交者: GitHub

ErnieForTokenClassficifation: py2 compatibility (#480)

上级 899dbfc2
......@@ -407,7 +407,7 @@ class ErnieModelForTokenClassification(ErnieModel):
self.dropout = lambda i: L.dropout(i, dropout_prob=prob, dropout_implementation="upscale_in_train",) if self.training else i
@add_docstring(ErnieModel.forward.__doc__)
def forward(self, *args, ignore_index=-100, labels=None, loss_weights=None, **kwargs, ):
def forward(self, *args, **kwargs):
"""
Args:
labels (optional, `Variable` of shape [batch_size, seq_len]):
......@@ -418,8 +418,14 @@ class ErnieModelForTokenClassification(ErnieModel):
if labels not set, returns None
logits (`Variable` of shape [batch_size, seq_len, hidden_size]):
output logits of classifier
loss_weights (`Variable` of shape [batch_size, seq_len]):
weigths of loss for each tokens.
ignore_index (int):
when label == `ignore_index`, this token will not contribute to loss
"""
ignore_index = kwargs.pop('ignore_index', -100)
labels = kwargs.pop('labels', None)
loss_weights = kwargs.pop('loss_weights', None)
pooled, encoded = super(ErnieModelForTokenClassification, self).forward(*args, **kwargs)
hidden = self.dropout(encoded) # maybe not?
logits = self.classifier(hidden)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册