# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import os

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlenlp.transformers.electra.modeling import ElectraForSequenceClassification, ElectraForTokenClassification, ElectraModel
from paddlenlp.transformers.electra.tokenizer import ElectraTokenizer
from paddlenlp.metrics import ChunkEvaluator
from paddlehub.module.module import moduleinfo
from paddlehub.module.nlp_module import TransformerModule
from paddlehub.utils.log import logger


@moduleinfo(
    name="chinese-electra-base",
    version="2.0.0",
    summary=
    "chinese-electra-base, 12-layer, 768-hidden, 12-heads, 102M parameters. The module is executed as paddle.dygraph.",
    author="ymcui",
    author_email="ymcui@ir.hit.edu.cn",
    type="nlp/semantic_model",
    meta=TransformerModule,
)
class Electra(nn.Layer):
    """
    Chinese ELECTRA-base module.

    Depending on ``task``, wraps one of the paddlenlp ELECTRA models loaded
    from the ``'chinese-electra-base'`` pretrained weights:

    * ``'seq-cls'`` (old alias ``'sequence_classification'``, deprecated):
      ``ElectraForSequenceClassification`` with a cross-entropy loss and
      ``paddle.metric.Accuracy``.
    * ``'token-cls'``: ``ElectraForTokenClassification`` with a
      cross-entropy loss and a ``ChunkEvaluator`` (F1) metric.
    * ``None``: the bare ``ElectraModel`` backbone.
    """

    def __init__(
            self,
            task: str = None,
            load_checkpoint: str = None,
            label_map: Dict = None,
            num_classes: int = 2,
            **kwargs,
    ):
        """
        Build the model for ``task`` and optionally restore a checkpoint.

        Args:
            task: ``'seq-cls'``, ``'token-cls'``, the deprecated alias
                ``'sequence_classification'``, or ``None`` for the backbone.
            load_checkpoint: path to a parameter file readable by
                ``paddle.load``. It is loaded into this layer when it exists.
                Otherwise a warning is logged and loading is skipped.
            label_map: label mapping. When non-empty, its length overrides
                ``num_classes``. For ``'token-cls'`` its values, ordered by
                sorted key, become the ``ChunkEvaluator`` label list, so it is
                required for that task.
            num_classes: number of classes used when ``label_map`` is empty.
            **kwargs: forwarded to the paddlenlp ``from_pretrained`` call.

        Raises:
            ValueError: ``task`` is ``'token-cls'`` and ``label_map`` is empty.
            RuntimeError: ``task`` is not one of the supported values.
        """
        super(Electra, self).__init__()
        # Always define the attribute so later code never hits an
        # AttributeError when no label_map was given.
        self.label_map = label_map
        if label_map:
            self.num_classes = len(label_map)
        else:
            self.num_classes = num_classes

        if task == 'sequence_classification':
            task = 'seq-cls'
            logger.warning(
                "current task name 'sequence_classification' was renamed to 'seq-cls', "
                "'sequence_classification' has been deprecated and will be removed in the future.",
            )
        if task == 'seq-cls':
            self.model = ElectraForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path='chinese-electra-base',
                num_classes=self.num_classes,
                **kwargs
            )
            self.criterion = paddle.nn.loss.CrossEntropyLoss()
            self.metric = paddle.metric.Accuracy()
        elif task == 'token-cls':
            # ChunkEvaluator needs the ordered label names. Fail fast (before
            # any pretrained weights are fetched) when they are missing.
            if not label_map:
                raise ValueError("label_map is required when task is 'token-cls'.")
            self.model = ElectraForTokenClassification.from_pretrained(
                pretrained_model_name_or_path='chinese-electra-base',
                num_classes=self.num_classes,
                **kwargs
            )
            self.criterion = paddle.nn.loss.CrossEntropyLoss()
            self.metric = ChunkEvaluator(
                label_list=[self.label_map[i] for i in sorted(self.label_map.keys())]
            )
        elif task is None:
            self.model = ElectraModel.from_pretrained(pretrained_model_name_or_path='chinese-electra-base', **kwargs)
        else:
            # NOTE(review): _tasks_supported is presumably provided by the
            # TransformerModule meta class. Fall back to the tasks handled
            # here so the intended RuntimeError is raised either way.
            supported = getattr(self, '_tasks_supported', ['seq-cls', 'token-cls'])
            raise RuntimeError("Unknown task {}, task should be one in {}".format(
                task, supported))

        self.task = task

        if load_checkpoint is not None:
            if os.path.isfile(load_checkpoint):
                state_dict = paddle.load(load_checkpoint)
                self.set_state_dict(state_dict)
                logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
            else:
                # Previously this case was skipped silently. Keep skipping,
                # but tell the user their checkpoint was not used.
                logger.warning('Checkpoint file %s does not exist, skip loading parameters.' % load_checkpoint)

    def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None, seq_lengths=None, labels=None):
        """
        Run the wrapped ELECTRA model.

        Args:
            input_ids: input token ids.
            token_type_ids: optional segment ids, forwarded to the model.
            position_ids: optional position ids, forwarded to the model.
            attention_mask: optional attention mask, forwarded to the model.
            seq_lengths: per-example sequence lengths. Only used by the
                ``'token-cls'`` ChunkEvaluator.
            labels: optional gold labels. When given for ``'seq-cls'`` or
                ``'token-cls'``, the loss and a metric dict are returned too.
                For ``'token-cls'`` they are unsqueezed on the last axis
                before the loss (presumably ``(batch, seq_len)`` ids; confirm
                against the data pipeline).

        Returns:
            ``'seq-cls'``: class probabilities, or
            ``(probs, loss, {'acc': acc})`` when ``labels`` is given.
            ``'token-cls'``: per-token probabilities, or
            ``(token_level_probs, loss, {'f1_score': f1_score})``.
            ``None``: ``(sequence_output, pooled_output)``.
        """
        result = self.model(input_ids, token_type_ids, position_ids, attention_mask)
        if self.task == 'seq-cls':
            logits = result
            probs = F.softmax(logits, axis=1)
            if labels is not None:
                loss = self.criterion(logits, labels)
                correct = self.metric.compute(probs, labels)
                # update() accumulates the counts and returns this batch's accuracy.
                acc = self.metric.update(correct)
                return probs, loss, {'acc': acc}
            return probs
        elif self.task == 'token-cls':
            logits = result
            token_level_probs = F.softmax(logits, axis=-1)
            preds = token_level_probs.argmax(axis=-1)
            if labels is not None:
                loss = self.criterion(logits, labels.unsqueeze(-1))
                num_infer_chunks, num_label_chunks, num_correct_chunks = \
                    self.metric.compute(None, seq_lengths, preds, labels)
                self.metric.update(
                    num_infer_chunks.numpy(), num_label_chunks.numpy(), num_correct_chunks.numpy())
                # accumulate() reports (precision, recall, f1) over all chunk
                # counts passed to update() so far; only f1 is exposed.
                _, _, f1_score = map(float, self.metric.accumulate())
                return token_level_probs, loss, {'f1_score': f1_score}
            return token_level_probs
        else:
            # Bare backbone. NOTE(review): paddlenlp's ElectraModel has no
            # pooler and appears to return only the last-layer hidden states
            # as a single tensor. Unpacking that into two names would split
            # it along the batch axis. Accept both a tuple and a tensor.
            if isinstance(result, (tuple, list)):
                sequence_output, pooled_output = result[0], result[1]
            else:
                sequence_output = result
                # First-token ([CLS]) hidden state as the pooled vector,
                # assuming a (batch, seq_len, hidden) layout. TODO confirm.
                pooled_output = sequence_output[:, 0]
            return sequence_output, pooled_output

    @staticmethod
    def get_tokenizer(*args, **kwargs):
        """
        Gets the tokenizer that is customized for this module.

        Extra positional and keyword arguments are forwarded to
        ``ElectraTokenizer.from_pretrained`` after the model name.
        """
        # Pass the model name positionally. Combined with *args, the keyword
        # form made any positional argument collide with
        # pretrained_model_name_or_path (TypeError: multiple values).
        return ElectraTokenizer.from_pretrained('chinese-electra-base', *args, **kwargs)