module.py 2.5 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
# coding:utf-8
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from paddlehub import TransformerModule
from paddlehub.module.module import moduleinfo

from rbtl3.model.bert import BertConfig, BertModel


@moduleinfo(
    name="rbtl3",
    version="1.0.0",
    summary="rbtl3, 3-layer, 1024-hidden, 16-heads, 61M parameters ",
    author="ymcui",
    author_email="ymcui@ir.hit.edu.cn",
    type="nlp/semantic_model",
)
class BertWwm(TransformerModule):
    """PaddleHub transformer module registered under the name "rbtl3".

    Asset paths are resolved relative to ``self.directory`` and the network
    is built with ``BertModel`` from the loaded ``BertConfig``.
    """

    def _initialize(self):
        """Record the asset locations and load the BERT configuration.

        NOTE(review): presumably invoked by the ``TransformerModule`` base
        class during construction, which is also expected to provide
        ``self.directory`` -- confirm against paddlehub.
        """
        self.MAX_SEQ_LEN = 512

        # Everything this module ships lives under "<module dir>/assets".
        assets_dir = os.path.join(self.directory, "assets")
        self.params_path = os.path.join(assets_dir, "params")
        self.vocab_path = os.path.join(assets_dir, "vocab.txt")
        self.bert_config = BertConfig(
            os.path.join(assets_dir, "bert_config_rbtl3.json"))

    def net(self, input_ids, position_ids, segment_ids, input_mask):
        """
        Build the BERT network and return its two outputs.

        Args:
            input_ids (tensor): the word ids.
            position_ids (tensor): the position ids.
            segment_ids (tensor): the segment ids.
            input_mask (tensor): the padding mask.

        Returns:
            pooled_output (tensor):  sentence-level output for classification task.
            sequence_output (tensor): token-level output for sequence task.
        """
        model = BertModel(
            src_ids=input_ids,
            position_ids=position_ids,
            sentence_ids=segment_ids,
            input_mask=input_mask,
            config=self.bert_config,
            use_fp16=False)
        # Tuple elements evaluate left to right, so the pooled output is
        # still requested before the sequence output.
        return model.get_pooled_output(), model.get_sequence_output()


# Smoke check: instantiate the module when this file is run as a script.
if __name__ == "__main__":
    test_module = BertWwm()