categorical_field_reader.py 1.6 KB
Newer Older
W
webyfdt 已提交
1 2 3 4 5 6
# -*- coding: utf-8 -*
"""
:py:class:`CategoricalField`

"""

K
Kennycao123 已提交
7 8 9 10
from erniekit.common.register import RegisterSet
from erniekit.data.field_reader.custom_text_field_reader import CustomTextFieldReader
from erniekit.data.util_helper import pad_batch_data
from erniekit.utils.util_helper import convert_to_unicode
W
webyfdt 已提交
11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45


@RegisterSet.field_reader.register
class CategoricalField(CustomTextFieldReader):
    """通用文本(string)类型的field_reader, 不进行分词,直接通过词表将明文转成id
    文本处理规则是,文本类型的数据会自动添加padding和mask,并返回length. 比较特殊的一点是,这个field_reader中的length全是1
    """
    def __init__(self, field_config):
        """
        :param field_config:
        """
        CustomTextFieldReader.__init__(self, field_config=field_config)

    def convert_texts_to_ids(self, batch_text):
        """将一个batch的明文text转成id
        :param batch_text:
        :return:
        """
        src_ids = []
        for text in batch_text:
            src_id = self.tokenizer.convert_tokens_to_ids(convert_to_unicode(text))
            src_ids.append(src_id)

        return_list = []
        padded_ids, mask_ids, batch_seq_lens = pad_batch_data(src_ids,
                                                              pad_idx=self.field_config.padding_id,
                                                              return_input_mask=True,
                                                              return_seq_lens=True)
        return_list.append(padded_ids)
        return_list.append(mask_ids)
        return_list.append(batch_seq_lens)

        return return_list