diff --git a/ernie/modeling_ernie.py b/ernie/modeling_ernie.py
index c7655c280184536e4b637a5f437da9fdc4c1e44b..b3b3bbe21c5b81899a9843beca96019262a38271 100644
--- a/ernie/modeling_ernie.py
+++ b/ernie/modeling_ernie.py
@@ -407,7 +407,7 @@ class ErnieModelForTokenClassification(ErnieModel):
         self.dropout = lambda i: L.dropout(i, dropout_prob=prob, dropout_implementation="upscale_in_train",) if self.training else i
 
     @add_docstring(ErnieModel.forward.__doc__)
-    def forward(self, *args, ignore_index=-100, labels=None, loss_weights=None, **kwargs, ):
+    def forward(self, *args, **kwargs):
         """
         Args:
             labels (optional, `Variable` of shape [batch_size, seq_len]):
@@ -418,8 +418,14 @@ class ErnieModelForTokenClassification(ErnieModel):
                 if labels not set, returns None
             logits (`Variable` of shape [batch_size, seq_len, hidden_size]):
                 output logits of classifier
+            loss_weights (`Variable` of shape [batch_size, seq_len]):
+                weights of loss for each token.
+            ignore_index (int):
+                when label == `ignore_index`, this token will not contribute to loss
         """
-
+        ignore_index = kwargs.pop('ignore_index', -100)
+        labels = kwargs.pop('labels', None)
+        loss_weights = kwargs.pop('loss_weights', None)
         pooled, encoded = super(ErnieModelForTokenClassification, self).forward(*args, **kwargs)
         hidden = self.dropout(encoded) # maybe not?
         logits = self.classifier(hidden)
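
For context, a minimal usage sketch of the call site after this change: `labels`, `loss_weights`, and `ignore_index` now travel through `**kwargs` and are popped inside `forward()` before the base `ErnieModel.forward` runs, so callers keep passing them as keyword arguments. The `'ernie-1.0'` checkpoint name, `num_labels=7`, and the random inputs below are illustrative assumptions, not part of this patch.

```python
import numpy as np
import paddle.fluid.dygraph as D
from ernie.modeling_ernie import ErnieModelForTokenClassification

D.guard().__enter__()  # activate paddle dygraph mode, as in the repo README

# Assumption: from_pretrained forwards extra kwargs (e.g. num_labels) into the model cfg.
model = ErnieModelForTokenClassification.from_pretrained('ernie-1.0', num_labels=7)

# Fake batch of 2 sequences, 16 tokens each; label value -100 is excluded from the loss.
src_ids = D.to_variable(np.random.randint(1, 1000, (2, 16)).astype('int64'))
labels = D.to_variable(np.random.randint(0, 7, (2, 16)).astype('int64'))
weights = D.to_variable(np.ones((2, 16), dtype='float32'))

# labels / loss_weights / ignore_index are popped from **kwargs inside forward(),
# so the call looks the same as it did with the old explicit keyword arguments.
loss, logits = model(src_ids, labels=labels, loss_weights=weights, ignore_index=-100)
print(loss.numpy(), logits.shape)
```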