From beef5eff9928e3d19dbbe38e11dc4897cbce58f4 Mon Sep 17 00:00:00 2001
From: Meiyim
Date: Mon, 1 Jun 2020 12:09:24 +0800
Subject: [PATCH] ErnieForTokenClassification: py2 compatibility (#480)

---
 ernie/modeling_ernie.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/ernie/modeling_ernie.py b/ernie/modeling_ernie.py
index c7655c2..b3b3bbe 100644
--- a/ernie/modeling_ernie.py
+++ b/ernie/modeling_ernie.py
@@ -407,7 +407,7 @@ class ErnieModelForTokenClassification(ErnieModel):
         self.dropout = lambda i: L.dropout(i, dropout_prob=prob, dropout_implementation="upscale_in_train",) if self.training else i
 
     @add_docstring(ErnieModel.forward.__doc__)
-    def forward(self, *args, ignore_index=-100, labels=None, loss_weights=None, **kwargs, ):
+    def forward(self, *args, **kwargs):
         """
         Args:
             labels (optional, `Variable` of shape [batch_size, seq_len]):
@@ -418,8 +418,14 @@ class ErnieModelForTokenClassification(ErnieModel):
                 if labels not set, returns None
             logits (`Variable` of shape [batch_size, seq_len, hidden_size]):
                 output logits of classifier
+            loss_weights (`Variable` of shape [batch_size, seq_len]):
+                weights of the loss for each token.
+            ignore_index (int):
+                when label == `ignore_index`, this token will not contribute to loss
         """
-
+        ignore_index = kwargs.pop('ignore_index', -100)
+        labels = kwargs.pop('labels', None)
+        loss_weights = kwargs.pop('loss_weights', None)
         pooled, encoded = super(ErnieModelForTokenClassification, self).forward(*args, **kwargs)
         hidden = self.dropout(encoded)  # maybe not?
         logits = self.classifier(hidden)
-- 
GitLab
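
For context, the reason for this change is that keyword-only arguments after *args (as in the removed signature) are Python-3-only syntax and raise a SyntaxError when the module is merely imported under Python 2; popping the optional keywords out of **kwargs keeps existing call sites working while staying valid Python 2 syntax. The following is a minimal standalone sketch of that pattern, not the actual ERNIE code: the class name and return value are purely illustrative, and only the default values mirror the patch.

# sketch_py2_compat_kwargs.py -- illustrative only, assumes nothing beyond the stdlib
class TokenClassifierSketch(object):
    def forward(self, *args, **kwargs):
        # Optional keywords are pulled out of **kwargs instead of being declared
        # as keyword-only parameters, so this parses under Python 2 as well.
        ignore_index = kwargs.pop('ignore_index', -100)
        labels = kwargs.pop('labels', None)
        loss_weights = kwargs.pop('loss_weights', None)
        # In the real model the remaining args/kwargs are forwarded to the parent
        # forward(); here we just return what was parsed to show the behavior.
        return ignore_index, labels, loss_weights, args, kwargs

# Callers can still pass the optional keywords exactly as before the patch:
m = TokenClassifierSketch()
print(m.forward('src_ids', labels=[1, 2], ignore_index=0))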