未验证 提交 beef5eff 编写于 作者: M Meiyim 提交者: GitHub

ErnieForTokenClassficifation: py2 compatibility (#480)

上级 899dbfc2
...@@ -407,7 +407,7 @@ class ErnieModelForTokenClassification(ErnieModel): ...@@ -407,7 +407,7 @@ class ErnieModelForTokenClassification(ErnieModel):
self.dropout = lambda i: L.dropout(i, dropout_prob=prob, dropout_implementation="upscale_in_train",) if self.training else i self.dropout = lambda i: L.dropout(i, dropout_prob=prob, dropout_implementation="upscale_in_train",) if self.training else i
@add_docstring(ErnieModel.forward.__doc__) @add_docstring(ErnieModel.forward.__doc__)
def forward(self, *args, ignore_index=-100, labels=None, loss_weights=None, **kwargs, ): def forward(self, *args, **kwargs):
""" """
Args: Args:
labels (optional, `Variable` of shape [batch_size, seq_len]): labels (optional, `Variable` of shape [batch_size, seq_len]):
...@@ -418,8 +418,14 @@ class ErnieModelForTokenClassification(ErnieModel): ...@@ -418,8 +418,14 @@ class ErnieModelForTokenClassification(ErnieModel):
if labels not set, returns None if labels not set, returns None
logits (`Variable` of shape [batch_size, seq_len, hidden_size]): logits (`Variable` of shape [batch_size, seq_len, hidden_size]):
output logits of classifier output logits of classifier
loss_weights (`Variable` of shape [batch_size, seq_len]):
weigths of loss for each tokens.
ignore_index (int):
when label == `ignore_index`, this token will not contribute to loss
""" """
ignore_index = kwargs.pop('ignore_index', -100)
labels = kwargs.pop('labels', None)
loss_weights = kwargs.pop('loss_weights', None)
pooled, encoded = super(ErnieModelForTokenClassification, self).forward(*args, **kwargs) pooled, encoded = super(ErnieModelForTokenClassification, self).forward(*args, **kwargs)
hidden = self.dropout(encoded) # maybe not? hidden = self.dropout(encoded) # maybe not?
logits = self.classifier(hidden) logits = self.classifier(hidden)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册