module.py 2.7 KB
Newer Older
W
wuzewu 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
# coding:utf-8
# Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from paddlehub import TransformerModule
from paddlehub.module.module import moduleinfo

from ernie.model.ernie import ErnieModel, ErnieConfig


@moduleinfo(
    name="ernie",
    version="1.2.0",
    summary=(
        "Baidu's ERNIE, Enhanced Representation through kNowledge IntEgration, max_seq_len=512 when predtrained"
    ),
    author="baidu-nlp",
    author_email="",
    type="nlp/semantic_model",
)
class Ernie(TransformerModule):
    """PaddleHub transformer module that wraps the ERNIE network.

    The module reads its configuration, parameters and vocabulary from the
    ``assets`` sub-directory of ``self.directory``.
    """

    def _initialize(self):
        """Load the ERNIE config and record the asset paths.

        NOTE(review): presumably invoked by the PaddleHub base class during
        construction -- the caller is not visible in this file.
        """
        # Every bundled resource lives under <module directory>/assets.
        assets_dir = os.path.join(self.directory, "assets")

        self.ernie_config = ErnieConfig(
            os.path.join(assets_dir, "ernie_config.json"))
        self.MAX_SEQ_LEN = 512
        self.params_path = os.path.join(assets_dir, "params")
        self.vocab_path = os.path.join(assets_dir, "vocab.txt")

    def net(self, input_ids, position_ids, segment_ids, input_mask):
        """
        Build the ERNIE network over the supplied inputs.

        Args:
            input_ids (tensor): the word ids.
            position_ids (tensor): the position ids.
            segment_ids (tensor): the segment ids.
            input_mask (tensor): the padding mask.

        Returns:
            pooled_output (tensor):  sentence-level output for classification task.
            sequence_output (tensor): token-level output for sequence task.
        """
        # Force `use_task_id` off on the stored config before the model is
        # constructed; note this mutates the config held on the instance.
        self.ernie_config._config_dict['use_task_id'] = False

        model_kwargs = dict(
            src_ids=input_ids,
            position_ids=position_ids,
            sentence_ids=segment_ids,
            input_mask=input_mask,
            config=self.ernie_config,
            use_fp16=False,
        )
        ernie = ErnieModel(**model_kwargs)

        # Pooled output is fetched first, then the per-token output.
        return ernie.get_pooled_output(), ernie.get_sequence_output()

    def param_prefix(self):
        """Return the fixed prefix string associated with this module's parameters."""
        prefix = "@HUB_ernie-stable@"
        return prefix

if __name__ == "__main__":
    # When executed directly as a script, just construct the module once
    # (presumably a quick smoke test -- confirm what construction loads).
    test_module = Ernie()