# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
import os
import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddlenlp.transformers.bert.modeling import BertForSequenceClassification, BertModel, BertForTokenClassification
from paddlenlp.transformers.bert.tokenizer import BertTokenizer
from paddlenlp.metrics import ChunkEvaluator
from paddlehub.module.module import moduleinfo
from paddlehub.module.nlp_module import TransformerModule
from paddlehub.utils.log import logger


@moduleinfo(
    name="bert-base-uncased",
    version="2.0.1",
    summary=
    "bert_uncased_L-12_H-768_A-12, 12-layer, 768-hidden, 12-heads, 110M parameters. The module is executed as paddle.dygraph.",
    author="paddlepaddle",
    author_email="",
    type="nlp/semantic_model",
    meta=TransformerModule)
class Bert(nn.Layer):
    """
    BERT model built on the pretrained ``bert-base-uncased`` weights.

    Depending on ``task`` the wrapped ``self.model`` is a sequence classifier
    (``'seq-cls'``), a token classifier (``'token-cls'``) or the bare
    ``BertModel`` (``task=None``).
    """

    def __init__(
            self,
            task: str = None,
            load_checkpoint: str = None,
            label_map: Dict = None,
            num_classes: int = 2,
            **kwargs,
    ):
        """
        Build the task-specific BERT network and optionally restore a checkpoint.

        Args:
            task: One of ``'seq-cls'``, ``'token-cls'`` or ``None``. The
                deprecated alias ``'sequence_classification'`` is mapped to
                ``'seq-cls'`` with a warning.
            load_checkpoint: Path of a parameter file to load into this layer.
                A path that is not an existing file is skipped with a warning.
            label_map: Mapping used to derive the label list for the
                ``'token-cls'`` metric; when non-empty, its length overrides
                ``num_classes``. Required for ``'token-cls'``.
            num_classes: Number of classes, used only when ``label_map`` is
                empty or ``None``.
            **kwargs: Forwarded unchanged to the ``from_pretrained`` call of
                the selected model class.

        Raises:
            RuntimeError: If ``task`` is not a supported value, or if
                ``task='token-cls'`` is requested without a ``label_map``.
        """
        super(Bert, self).__init__()
        if label_map:
            self.label_map = label_map
            self.num_classes = len(label_map)
        else:
            self.num_classes = num_classes

        if task == 'sequence_classification':
            task = 'seq-cls'
            logger.warning(
                "current task name 'sequence_classification' was renamed to 'seq-cls', "
                "'sequence_classification' has been deprecated and will be removed in the future.",
            )
        if task == 'seq-cls':
            self.model = BertForSequenceClassification.from_pretrained(
                pretrained_model_name_or_path='bert-base-uncased', num_classes=self.num_classes, **kwargs)
            self.criterion = paddle.nn.loss.CrossEntropyLoss()
            self.metric = paddle.metric.Accuracy()
        elif task == 'token-cls':
            # The chunk metric needs the label list; without a label_map the
            # attribute lookup below would fail with an opaque AttributeError,
            # so fail early with an explicit message instead.
            if not hasattr(self, 'label_map'):
                raise RuntimeError(
                    "A non-empty label_map is required when task is 'token-cls'.")
            self.model = BertForTokenClassification.from_pretrained(
                pretrained_model_name_or_path='bert-base-uncased', num_classes=self.num_classes, **kwargs)
            self.criterion = paddle.nn.loss.CrossEntropyLoss()
            # Labels are ordered by their sorted label_map keys.
            self.metric = ChunkEvaluator(
                label_list=[self.label_map[i] for i in sorted(self.label_map.keys())]
            )
        elif task is None:
            self.model = BertModel.from_pretrained(pretrained_model_name_or_path='bert-base-uncased', **kwargs)
        else:
            # NOTE(review): _tasks_supported is not defined in this class; it is
            # presumably provided through the TransformerModule meta class - confirm.
            raise RuntimeError("Unknown task {}, task should be one in {}".format(
                task, self._tasks_supported))

        self.task = task

        if load_checkpoint is not None:
            if os.path.isfile(load_checkpoint):
                state_dict = paddle.load(load_checkpoint)
                self.set_state_dict(state_dict)
                logger.info('Loaded parameters from %s' % os.path.abspath(load_checkpoint))
            else:
                # Still best-effort (no exception), but no longer silent: the
                # caller would otherwise run on the pretrained weights only.
                logger.warning(
                    "load_checkpoint {!r} is not an existing file, no parameters were loaded from it.".format(
                        load_checkpoint))

    def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None, seq_lengths=None, labels=None):
        """
        Run the wrapped model and post-process its output according to ``self.task``.

        Args:
            input_ids, token_type_ids, position_ids, attention_mask: Passed
                positionally, in this order, to ``self.model``.
            seq_lengths: Only used for ``'token-cls'`` with ``labels``; passed
                to the chunk metric.
            labels: Optional ground-truth labels. When given, loss and metric
                are computed as well.

        Returns:
            * ``'seq-cls'``: ``probs``, or ``(probs, loss, {'acc': acc})`` when
              ``labels`` is given.
            * ``'token-cls'``: ``token_level_probs``, or
              ``(token_level_probs, loss, {'f1_score': f1_score})`` when
              ``labels`` is given.
            * otherwise: ``(sequence_output, pooled_output)`` as returned by
              the bare model.
        """
        result = self.model(input_ids, token_type_ids, position_ids, attention_mask)
        if self.task == 'seq-cls':
            logits = result
            # assumes logits are (batch, num_classes), hence softmax over axis 1 - TODO confirm
            probs = F.softmax(logits, axis=1)
            if labels is not None:
                loss = self.criterion(logits, labels)
                correct = self.metric.compute(probs, labels)
                acc = self.metric.update(correct)
                return probs, loss, {'acc': acc}
            return probs
        elif self.task == 'token-cls':
            logits = result
            token_level_probs = F.softmax(logits, axis=-1)
            preds = token_level_probs.argmax(axis=-1)
            if labels is not None:
                # A trailing axis is added to the labels before computing the loss.
                loss = self.criterion(logits, labels.unsqueeze(-1))
                num_infer_chunks, num_label_chunks, num_correct_chunks = \
                    self.metric.compute(None, seq_lengths, preds, labels)
                self.metric.update(
                    num_infer_chunks.numpy(), num_label_chunks.numpy(), num_correct_chunks.numpy())
                # accumulate() yields three values; the third is reported as the
                # F1 score (the first two are presumably precision and recall).
                _, _, f1_score = map(float, self.metric.accumulate())
                return token_level_probs, loss, {'f1_score': f1_score}
            return token_level_probs
        else:
            sequence_output, pooled_output = result
            return sequence_output, pooled_output

    @staticmethod
    def get_tokenizer(*args, **kwargs):
        """
        Gets the tokenizer that is customized for this module.

        Extra positional and keyword arguments are forwarded to
        ``BertTokenizer.from_pretrained`` after the model name.
        """
        # The model name is passed positionally: combining it as a keyword with
        # *args would raise "got multiple values" for any positional argument.
        return BertTokenizer.from_pretrained('bert-base-uncased', *args, **kwargs)